/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! */ #include #include #include #include #include "utils.h" #include "mxnet-cpp/MxNetCpp.h" using namespace mxnet::cpp; Symbol ConvFactoryBN(Symbol data, int num_filter, Shape kernel, Shape stride, Shape pad, const std::string & name, const std::string & suffix = "") { Symbol conv_w("conv_" + name + suffix + "_w"), conv_b("conv_" + name + suffix + "_b"); Symbol conv = Convolution("conv_" + name + suffix, data, conv_w, conv_b, kernel, num_filter, stride, Shape(1, 1), pad); std::string name_suffix = name + suffix; Symbol gamma(name_suffix + "_gamma"); Symbol beta(name_suffix + "_beta"); Symbol mmean(name_suffix + "_mmean"); Symbol mvar(name_suffix + "_mvar"); Symbol bn = BatchNorm("bn_" + name + suffix, conv, gamma, beta, mmean, mvar); return Activation("relu_" + name + suffix, bn, "relu"); } Symbol InceptionFactoryA(Symbol data, int num_1x1, int num_3x3red, int num_3x3, int num_d3x3red, int num_d3x3, PoolingPoolType pool, int proj, const std::string & name) { Symbol c1x1 = ConvFactoryBN(data, num_1x1, Shape(1, 1), Shape(1, 1), Shape(0, 0), name + "1x1"); Symbol c3x3r = ConvFactoryBN(data, num_3x3red, Shape(1, 1), Shape(1, 1), Shape(0, 0), name + "_3x3r"); Symbol c3x3 = ConvFactoryBN(c3x3r, num_3x3, Shape(3, 3), Shape(1, 1), Shape(1, 1), name + "_3x3"); Symbol cd3x3r = ConvFactoryBN(data, num_d3x3red, Shape(1, 1), Shape(1, 1), Shape(0, 0), name + "_double_3x3", "_reduce"); Symbol cd3x3 = ConvFactoryBN(cd3x3r, num_d3x3, Shape(3, 3), Shape(1, 1), Shape(1, 1), name + "_double_3x3_0"); cd3x3 = ConvFactoryBN(data = cd3x3, num_d3x3, Shape(3, 3), Shape(1, 1), Shape(1, 1), name + "_double_3x3_1"); Symbol pooling = Pooling(name + "_pool", data, Shape(3, 3), pool, false, false, PoolingPoolingConvention::kValid, Shape(1, 1), Shape(1, 1)); Symbol cproj = ConvFactoryBN(pooling, proj, Shape(1, 1), Shape(1, 1), Shape(0, 0), name + "_proj"); std::vector lst; lst.push_back(c1x1); lst.push_back(c3x3); lst.push_back(cd3x3); lst.push_back(cproj); return Concat("ch_concat_" + name + "_chconcat", lst, lst.size()); } Symbol InceptionFactoryB(Symbol data, int num_3x3red, int num_3x3, int num_d3x3red, int num_d3x3, const std::string & name) { Symbol c3x3r = ConvFactoryBN(data, num_3x3red, Shape(1, 1), Shape(1, 1), Shape(0, 0), name + "_3x3", "_reduce"); Symbol c3x3 = ConvFactoryBN(c3x3r, num_3x3, Shape(3, 3), Shape(2, 2), Shape(1, 1), name + "_3x3"); Symbol cd3x3r = ConvFactoryBN(data, num_d3x3red, Shape(1, 1), Shape(1, 1), Shape(0, 0), name + "_double_3x3", "_reduce"); Symbol cd3x3 = ConvFactoryBN(cd3x3r, num_d3x3, Shape(3, 3), Shape(1, 1), Shape(1, 1), name + "_double_3x3_0"); cd3x3 = ConvFactoryBN(cd3x3, num_d3x3, Shape(3, 3), Shape(2, 2), Shape(1, 1), name + "_double_3x3_1"); Symbol pooling = Pooling("max_pool_" + name + "_pool", data, Shape(3, 3), PoolingPoolType::kMax, false, false, PoolingPoolingConvention::kValid, Shape(2, 2), Shape(1, 1)); std::vector lst; lst.push_back(c3x3); lst.push_back(cd3x3); lst.push_back(pooling); return Concat("ch_concat_" + name + "_chconcat", lst, lst.size()); } Symbol InceptionSymbol(int num_classes) { // data and label Symbol data = Symbol::Variable("data"); Symbol data_label = Symbol::Variable("data_label"); // stage 1 Symbol conv1 = ConvFactoryBN(data, 64, Shape(7, 7), Shape(2, 2), Shape(3, 3), "conv1"); Symbol pool1 = Pooling("pool1", conv1, Shape(3, 3), PoolingPoolType::kMax, false, false, PoolingPoolingConvention::kValid, Shape(2, 2)); // stage 2 Symbol conv2red = ConvFactoryBN(pool1, 64, Shape(1, 1), Shape(1, 1), Shape(0, 0), "conv2red"); Symbol conv2 = ConvFactoryBN(conv2red, 192, Shape(3, 3), Shape(1, 1), Shape(1, 1), "conv2"); Symbol pool2 = Pooling("pool2", conv2, Shape(3, 3), PoolingPoolType::kMax, false, false, PoolingPoolingConvention::kValid, Shape(2, 2)); // stage 3 Symbol in3a = InceptionFactoryA(pool2, 64, 64, 64, 64, 96, PoolingPoolType::kAvg, 32, "3a"); Symbol in3b = InceptionFactoryA(in3a, 64, 64, 96, 64, 96, PoolingPoolType::kAvg, 64, "3b"); Symbol in3c = InceptionFactoryB(in3b, 128, 160, 64, 96, "3c"); // stage 4 Symbol in4a = InceptionFactoryA(in3c, 224, 64, 96, 96, 128, PoolingPoolType::kAvg, 128, "4a"); Symbol in4b = InceptionFactoryA(in4a, 192, 96, 128, 96, 128, PoolingPoolType::kAvg, 128, "4b"); Symbol in4c = InceptionFactoryA(in4b, 160, 128, 160, 128, 160, PoolingPoolType::kAvg, 128, "4c"); Symbol in4d = InceptionFactoryA(in4c, 96, 128, 192, 160, 192, PoolingPoolType::kAvg, 128, "4d"); Symbol in4e = InceptionFactoryB(in4d, 128, 192, 192, 256, "4e"); // stage 5 Symbol in5a = InceptionFactoryA(in4e, 352, 192, 320, 160, 224, PoolingPoolType::kAvg, 128, "5a"); Symbol in5b = InceptionFactoryA(in5a, 352, 192, 320, 192, 224, PoolingPoolType::kMax, 128, "5b"); // average pooling Symbol avg = Pooling("global_pool", in5b, Shape(7, 7), PoolingPoolType::kAvg); // classifier Symbol flatten = Flatten("flatten", avg); Symbol conv1_w("conv1_w"), conv1_b("conv1_b"); Symbol fc1 = FullyConnected("fc1", flatten, conv1_w, conv1_b, num_classes); return SoftmaxOutput("softmax", fc1, data_label); } NDArray ResizeInput(NDArray data, const Shape new_shape) { NDArray pic = data.Reshape(Shape(0, 1, 28, 28)); NDArray pic_1channel; Operator("_contrib_BilinearResize2D") .SetParam("height", new_shape[2]) .SetParam("width", new_shape[3]) (pic).Invoke(pic_1channel); NDArray output; Operator("tile") .SetParam("reps", Shape(1, 3, 1, 1)) (pic_1channel).Invoke(output); return output; } int main(int argc, char const *argv[]) { int batch_size = 40; int max_epoch = argc > 1 ? strtol(argv[1], nullptr, 10) : 100; float learning_rate = 1e-2; float weight_decay = 1e-4; /*context*/ auto ctx = Context::cpu(); int num_gpu; MXGetGPUCount(&num_gpu); #if MXNET_USE_CUDA if (num_gpu > 0) { ctx = Context::gpu(); } #endif TRY auto inception_bn_net = InceptionSymbol(10); std::map args_map; std::map aux_map; const Shape data_shape = Shape(batch_size, 3, 224, 224), label_shape = Shape(batch_size); args_map["data"] = NDArray(data_shape, ctx); args_map["data_label"] = NDArray(label_shape, ctx); inception_bn_net.InferArgsMap(ctx, &args_map, args_map); std::vector data_files = { "./data/mnist_data/train-images-idx3-ubyte", "./data/mnist_data/train-labels-idx1-ubyte", "./data/mnist_data/t10k-images-idx3-ubyte", "./data/mnist_data/t10k-labels-idx1-ubyte" }; auto train_iter = MXDataIter("MNISTIter"); if (!setDataIter(&train_iter, "Train", data_files, batch_size)) { return 1; } auto val_iter = MXDataIter("MNISTIter"); if (!setDataIter(&val_iter, "Label", data_files, batch_size)) { return 1; } // initialize parameters auto initializer = Uniform(0.07); for (auto& arg : args_map) { initializer(arg.first, &arg.second); } Optimizer* opt = OptimizerRegistry::Find("sgd"); opt->SetParam("momentum", 0.9) ->SetParam("rescale_grad", 1.0 / batch_size) ->SetParam("clip_gradient", 10) ->SetParam("lr", learning_rate) ->SetParam("wd", weight_decay); auto *exec = inception_bn_net.SimpleBind(ctx, args_map); auto arg_names = inception_bn_net.ListArguments(); // Create metrics Accuracy train_acc, val_acc; for (int iter = 0; iter < max_epoch; ++iter) { LG << "Epoch: " << iter; train_iter.Reset(); train_acc.Reset(); while (train_iter.Next()) { auto data_batch = train_iter.GetDataBatch(); ResizeInput(data_batch.data, data_shape).CopyTo(&args_map["data"]); data_batch.label.CopyTo(&args_map["data_label"]); NDArray::WaitAll(); exec->Forward(true); exec->Backward(); // Update parameters for (size_t i = 0; i < arg_names.size(); ++i) { if (arg_names[i] == "data" || arg_names[i] == "data_label") continue; opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]); } NDArray::WaitAll(); train_acc.Update(data_batch.label, exec->outputs[0]); } val_iter.Reset(); val_acc.Reset(); while (val_iter.Next()) { auto data_batch = val_iter.GetDataBatch(); ResizeInput(data_batch.data, data_shape).CopyTo(&args_map["data"]); data_batch.label.CopyTo(&args_map["data_label"]); NDArray::WaitAll(); exec->Forward(false); NDArray::WaitAll(); val_acc.Update(data_batch.label, exec->outputs[0]); } LG << "Train Accuracy: " << train_acc.Get(); LG << "Validation Accuracy: " << val_acc.Get(); } delete exec; delete opt; MXNotifyShutdown(); CATCH return 0; }