# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: skip-file
import sys
sys.path.insert(0, '../../python')
import mxnet as mx
from mxnet.test_utils import get_mnist_ubyte
import numpy as np
import os, pickle, gzip, argparse
import logging


def get_model(use_gpu):
    """Build a small two-layer CNN for MNIST and wrap it in a FeedForward model.

    Architecture: two (conv -> batchnorm -> relu -> maxpool) stages followed
    by a flatten, a 10-way fully-connected layer, and a softmax output named
    'sm' (so the iterators below must supply labels under 'sm_label').

    Parameters
    ----------
    use_gpu : bool
        If True, train on mx.gpu(); otherwise on mx.cpu().

    Returns
    -------
    mx.model.FeedForward
        An un-fitted model configured for 1 epoch of SGD
        (lr=0.1, wd=0.0001, momentum=0.9).
    """
    # symbol net
    data = mx.symbol.Variable('data')
    conv1 = mx.symbol.Convolution(data=data, name='conv1', num_filter=32,
                                  kernel=(3, 3), stride=(2, 2))
    bn1 = mx.symbol.BatchNorm(data=conv1, name="bn1")
    act1 = mx.symbol.Activation(data=bn1, name='relu1', act_type="relu")
    mp1 = mx.symbol.Pooling(data=act1, name='mp1', kernel=(2, 2),
                            stride=(2, 2), pool_type='max')

    conv2 = mx.symbol.Convolution(data=mp1, name='conv2', num_filter=32,
                                  kernel=(3, 3), stride=(2, 2))
    bn2 = mx.symbol.BatchNorm(data=conv2, name="bn2")
    act2 = mx.symbol.Activation(data=bn2, name='relu2', act_type="relu")
    mp2 = mx.symbol.Pooling(data=act2, name='mp2', kernel=(2, 2),
                            stride=(2, 2), pool_type='max')

    fl = mx.symbol.Flatten(data=mp2, name="flatten")
    fc2 = mx.symbol.FullyConnected(data=fl, name='fc2', num_hidden=10)
    softmax = mx.symbol.SoftmaxOutput(data=fc2, name='sm')

    num_epoch = 1
    ctx = mx.gpu() if use_gpu else mx.cpu()
    model = mx.model.FeedForward(softmax, ctx, num_epoch=num_epoch,
                                 learning_rate=0.1, wd=0.0001, momentum=0.9)
    return model


def get_iters():
    """Download MNIST (if needed) and return (train_iter, val_iter).

    Both iterators yield (1, 28, 28) images in batches of 100 with the
    label entry named 'sm_label' to match the SoftmaxOutput in get_model.
    The training iterator is seeded (seed=10) for reproducibility.
    """
    # check data
    get_mnist_ubyte()
    batch_size = 100

    train_dataiter = mx.io.MNISTIter(
        image="data/train-images-idx3-ubyte",
        label="data/train-labels-idx1-ubyte",
        data_shape=(1, 28, 28),
        label_name='sm_label',
        batch_size=batch_size, shuffle=True, flat=False, silent=False, seed=10)
    val_dataiter = mx.io.MNISTIter(
        image="data/t10k-images-idx3-ubyte",
        label="data/t10k-labels-idx1-ubyte",
        data_shape=(1, 28, 28),
        label_name='sm_label',
        batch_size=batch_size, shuffle=True, flat=False, silent=False)

    return train_dataiter, val_dataiter


# run default with unit test framework
def test_mnist():
    """Entry point for the unit-test framework: train on CPU and check accuracy."""
    iters = get_iters()
    exec_mnist(get_model(False), iters[0], iters[1])


def exec_mnist(model, train_dataiter, val_dataiter):
    """Fit `model` on `train_dataiter`, then assert >94% accuracy on `val_dataiter`.

    Raises
    ------
    AssertionError
        If final validation accuracy is not above 0.94.
    """
    # print logging by default; basicConfig installs a root StreamHandler,
    # so adding a second handler here (as the original did) would emit
    # every log record twice.
    logging.basicConfig(level=logging.DEBUG)

    model.fit(X=train_dataiter, eval_data=val_dataiter)
    logging.info('Finish fit...')

    prob = model.predict(val_dataiter)
    logging.info('Finish predict...')

    # Reset so we re-read the labels in the same order predict() consumed them.
    val_dataiter.reset()
    y = np.concatenate(
        [batch.label[0].asnumpy() for batch in val_dataiter]).astype('int')
    py = np.argmax(prob, axis=1)
    acc1 = float(np.sum(py == y)) / len(y)
    logging.info('final accuracy = %f', acc1)
    assert acc1 > 0.94


# run as a script
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', action='store_true',
                        help='use gpu to train')
    args = parser.parse_args()
    iters = get_iters()
    exec_mnist(get_model(args.gpu), iters[0], iters[1])