/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! */ #include #include #include #include #include #include "utils.h" #include "mxnet-cpp/MxNetCpp.h" using namespace mxnet::cpp; Symbol AlexnetSymbol(int num_classes) { auto input_data = Symbol::Variable("data"); auto target_label = Symbol::Variable("label"); /*stage 1*/ auto conv1 = Operator("Convolution") .SetParam("kernel", Shape(11, 11)) .SetParam("num_filter", 96) .SetParam("stride", Shape(4, 4)) .SetParam("dilate", Shape(1, 1)) .SetParam("pad", Shape(0, 0)) .SetParam("num_group", 1) .SetParam("workspace", 512) .SetParam("no_bias", false) .SetInput("data", input_data) .CreateSymbol("conv1"); auto relu1 = Operator("Activation") .SetParam("act_type", "relu") /*relu,sigmoid,softrelu,tanh */ .SetInput("data", conv1) .CreateSymbol("relu1"); auto pool1 = Operator("Pooling") .SetParam("kernel", Shape(3, 3)) .SetParam("pool_type", "max") /*avg,max,sum */ .SetParam("global_pool", false) .SetParam("stride", Shape(2, 2)) .SetParam("pad", Shape(0, 0)) .SetInput("data", relu1) .CreateSymbol("pool1"); auto lrn1 = Operator("LRN") .SetParam("nsize", 5) .SetParam("alpha", 0.0001) .SetParam("beta", 0.75) .SetParam("knorm", 1) .SetInput("data", pool1) .CreateSymbol("lrn1"); /*stage 2*/ auto conv2 = Operator("Convolution") .SetParam("kernel", Shape(5, 5)) .SetParam("num_filter", 256) .SetParam("stride", Shape(1, 1)) .SetParam("dilate", Shape(1, 1)) .SetParam("pad", Shape(2, 2)) .SetParam("num_group", 1) .SetParam("workspace", 512) .SetParam("no_bias", false) .SetInput("data", lrn1) .CreateSymbol("conv2"); auto relu2 = Operator("Activation") .SetParam("act_type", "relu") /*relu,sigmoid,softrelu,tanh */ .SetInput("data", conv2) .CreateSymbol("relu2"); auto pool2 = Operator("Pooling") .SetParam("kernel", Shape(3, 3)) .SetParam("pool_type", "max") /*avg,max,sum */ .SetParam("global_pool", false) .SetParam("stride", Shape(2, 2)) .SetParam("pad", Shape(0, 0)) .SetInput("data", relu2) .CreateSymbol("pool2"); auto lrn2 = Operator("LRN") .SetParam("nsize", 5) .SetParam("alpha", 0.0001) .SetParam("beta", 0.75) .SetParam("knorm", 1) .SetInput("data", pool2) .CreateSymbol("lrn2"); /*stage 3*/ auto conv3 = Operator("Convolution") .SetParam("kernel", Shape(3, 3)) .SetParam("num_filter", 384) .SetParam("stride", Shape(1, 1)) .SetParam("dilate", Shape(1, 1)) .SetParam("pad", Shape(1, 1)) .SetParam("num_group", 1) .SetParam("workspace", 512) .SetParam("no_bias", false) .SetInput("data", lrn2) .CreateSymbol("conv3"); auto relu3 = Operator("Activation") .SetParam("act_type", "relu") /*relu,sigmoid,softrelu,tanh */ .SetInput("data", conv3) .CreateSymbol("relu3"); auto conv4 = Operator("Convolution") .SetParam("kernel", Shape(3, 3)) .SetParam("num_filter", 384) .SetParam("stride", Shape(1, 1)) .SetParam("dilate", Shape(1, 1)) .SetParam("pad", Shape(1, 1)) .SetParam("num_group", 1) .SetParam("workspace", 512) .SetParam("no_bias", false) .SetInput("data", relu3) .CreateSymbol("conv4"); auto relu4 = Operator("Activation") .SetParam("act_type", "relu") /*relu,sigmoid,softrelu,tanh */ .SetInput("data", conv4) .CreateSymbol("relu4"); auto conv5 = Operator("Convolution") .SetParam("kernel", Shape(3, 3)) .SetParam("num_filter", 256) .SetParam("stride", Shape(1, 1)) .SetParam("dilate", Shape(1, 1)) .SetParam("pad", Shape(1, 1)) .SetParam("num_group", 1) .SetParam("workspace", 512) .SetParam("no_bias", false) .SetInput("data", relu4) .CreateSymbol("conv5"); auto relu5 = Operator("Activation") .SetParam("act_type", "relu") .SetInput("data", conv5) .CreateSymbol("relu5"); auto pool3 = Operator("Pooling") .SetParam("kernel", Shape(3, 3)) .SetParam("pool_type", "max") .SetParam("global_pool", false) .SetParam("stride", Shape(2, 2)) .SetParam("pad", Shape(0, 0)) .SetInput("data", relu5) .CreateSymbol("pool3"); /*stage4*/ auto flatten = Operator("Flatten").SetInput("data", pool3).CreateSymbol("flatten"); auto fc1 = Operator("FullyConnected") .SetParam("num_hidden", 4096) .SetParam("no_bias", false) .SetInput("data", flatten) .CreateSymbol("fc1"); auto relu6 = Operator("Activation") .SetParam("act_type", "relu") .SetInput("data", fc1) .CreateSymbol("relu6"); auto dropout1 = Operator("Dropout") .SetParam("p", 0.5) .SetInput("data", relu6) .CreateSymbol("dropout1"); /*stage5*/ auto fc2 = Operator("FullyConnected") .SetParam("num_hidden", 4096) .SetParam("no_bias", false) .SetInput("data", dropout1) .CreateSymbol("fc2"); auto relu7 = Operator("Activation") .SetParam("act_type", "relu") .SetInput("data", fc2) .CreateSymbol("relu7"); auto dropout2 = Operator("Dropout") .SetParam("p", 0.5) .SetInput("data", relu7) .CreateSymbol("dropout2"); /*stage6*/ auto fc3 = Operator("FullyConnected") .SetParam("num_hidden", num_classes) .SetParam("no_bias", false) .SetInput("data", dropout2) .CreateSymbol("fc3"); auto softmax = Operator("SoftmaxOutput") .SetParam("grad_scale", 1) .SetParam("ignore_label", -1) .SetParam("multi_output", false) .SetParam("use_ignore", false) .SetParam("normalization", "null") /*batch,null,valid */ .SetInput("data", fc3) .SetInput("label", target_label) .CreateSymbol("softmax"); return softmax; } NDArray ResizeInput(NDArray data, const Shape new_shape) { NDArray pic = data.Reshape(Shape(0, 1, 28, 28)); NDArray pic_1channel; Operator("_contrib_BilinearResize2D") .SetParam("height", new_shape[2]) .SetParam("width", new_shape[3]) (pic).Invoke(pic_1channel); NDArray output; Operator("tile") .SetParam("reps", Shape(1, 3, 1, 1)) (pic_1channel).Invoke(output); return output; } int main(int argc, char const *argv[]) { /*basic config*/ int max_epo = argc > 1 ? strtol(argv[1], nullptr, 10) : 100; float learning_rate = 1e-4; float weight_decay = 1e-4; /*context*/ auto ctx = Context::cpu(); int num_gpu; MXGetGPUCount(&num_gpu); int batch_size = 32; #if MXNET_USE_CUDA if (num_gpu > 0) { ctx = Context::gpu(); batch_size = 256; } #endif TRY /*net symbol*/ auto Net = AlexnetSymbol(10); /*args_map and aux_map is used for parameters' saving*/ std::map args_map; std::map aux_map; /*we should tell mxnet the shape of data and label*/ const Shape data_shape = Shape(batch_size, 3, 256, 256), label_shape = Shape(batch_size); args_map["data"] = NDArray(data_shape, ctx); args_map["label"] = NDArray(label_shape, ctx); /*with data and label, executor can be generated automatically*/ auto *exec = Net.SimpleBind(ctx, args_map); auto arg_names = Net.ListArguments(); aux_map = exec->aux_dict(); args_map = exec->arg_dict(); /*if fine tune from some pre-trained model, we should load the parameters*/ // NDArray::Load("./model/alex_params_3", nullptr, &args_map); /*else, we should use initializer Xavier to init the params*/ auto initializer = Uniform(0.07); for (auto &arg : args_map) { /*be careful here, the arg's name must has some specific ends or starts for * initializer to call*/ initializer(arg.first, &arg.second); } /*these binary files should be generated using im2rc tools, which can be found * in mxnet/bin*/ std::vector data_files = { "./data/mnist_data/train-images-idx3-ubyte", "./data/mnist_data/train-labels-idx1-ubyte", "./data/mnist_data/t10k-images-idx3-ubyte", "./data/mnist_data/t10k-labels-idx1-ubyte" }; auto train_iter = MXDataIter("MNISTIter"); if (!setDataIter(&train_iter, "Train", data_files, batch_size)) { return 1; } auto val_iter = MXDataIter("MNISTIter"); if (!setDataIter(&val_iter, "Label", data_files, batch_size)) { return 1; } Optimizer* opt = OptimizerRegistry::Find("sgd"); opt->SetParam("momentum", 0.9) ->SetParam("rescale_grad", 1.0 / batch_size) ->SetParam("clip_gradient", 10) ->SetParam("lr", learning_rate) ->SetParam("wd", weight_decay); Accuracy acu_train, acu_val; LogLoss logloss_train, logloss_val; for (int epoch = 0; epoch < max_epo; ++epoch) { LG << "Train Epoch: " << epoch; /*reset the metric every epoch*/ acu_train.Reset(); /*reset the data iter every epoch*/ train_iter.Reset(); int iter = 0; while (train_iter.Next()) { auto batch = train_iter.GetDataBatch(); /*use copyto to feed new data and label to the executor*/ ResizeInput(batch.data, data_shape).CopyTo(&args_map["data"]); batch.label.CopyTo(&args_map["label"]); exec->Forward(true); exec->Backward(); for (size_t i = 0; i < arg_names.size(); ++i) { if (arg_names[i] == "data" || arg_names[i] == "label") continue; opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]); } NDArray::WaitAll(); acu_train.Update(batch.label, exec->outputs[0]); logloss_train.Reset(); logloss_train.Update(batch.label, exec->outputs[0]); ++iter; LG << "EPOCH: " << epoch << " ITER: " << iter << " Train Accuracy: " << acu_train.Get() << " Train Loss: " << logloss_train.Get(); } LG << "EPOCH: " << epoch << " Train Accuracy: " << acu_train.Get(); LG << "Val Epoch: " << epoch; acu_val.Reset(); val_iter.Reset(); logloss_val.Reset(); iter = 0; while (val_iter.Next()) { auto batch = val_iter.GetDataBatch(); ResizeInput(batch.data, data_shape).CopyTo(&args_map["data"]); batch.label.CopyTo(&args_map["label"]); exec->Forward(false); NDArray::WaitAll(); acu_val.Update(batch.label, exec->outputs[0]); logloss_val.Update(batch.label, exec->outputs[0]); LG << "EPOCH: " << epoch << " ITER: " << iter << " Val Accuracy: " << acu_val.Get(); ++iter; } LG << "EPOCH: " << epoch << " Val Accuracy: " << acu_val.Get(); LG << "EPOCH: " << epoch << " Val LogLoss: " << logloss_val.Get(); /*save the parameters*/ std::stringstream ss; ss << epoch; std::string epoch_str; ss >> epoch_str; std::string save_path_param = "alex_param_" + epoch_str; auto save_args = args_map; /*we do not want to save the data and label*/ save_args.erase(save_args.find("data")); save_args.erase(save_args.find("label")); /*the alexnet does not get any aux array, so we do not need to save * aux_map*/ LG << "EPOCH: " << epoch << " Saving to..." << save_path_param; NDArray::Save(save_path_param, save_args); } /*don't foget to release the executor*/ delete exec; delete opt; MXNotifyShutdown(); CATCH return 0; }