2017-08-11 14:12:47 -07:00
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
2017-03-22 11:55:51 +08:00
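/*!
 * \brief Trains an AlexNet-style network on MNIST with the MXNet C++ API.
 *
 * The 28x28 single-channel digits are bilinearly resized and tiled to three
 * channels so they match the (batch, 3, 256, 256) input the network is bound
 * with.
 */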
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "utils.h"
#include "mxnet-cpp/MxNetCpp.h"
2017-11-07 02:13:07 +08:00
2017-03-22 11:55:51 +08:00
using namespace mxnet : : cpp ;
/*!
 * \brief Build the AlexNet symbol graph used by this example.
 *
 * Five convolution layers (ReLU + 3x3/2 max pooling + LRN after the first
 * two, ReLU after the rest, max pooling after the fifth) feed three fully
 * connected layers (ReLU + dropout after the first two), ending in a
 * SoftmaxOutput that consumes the "label" variable.
 *
 * \param num_classes number of hidden units of the final fully connected layer.
 * \return the SoftmaxOutput symbol named "softmax".
 */
Symbol AlexnetSymbol(int num_classes) {
  const Symbol data = Symbol::Variable("data");
  const Symbol label = Symbol::Variable("label");

  // Convolution with 1x1 dilation, one group, workspace 512 and a bias term.
  auto convolution = [](const std::string &name, const Symbol &in,
                        Shape kernel, int num_filter, Shape stride,
                        Shape pad) -> Symbol {
    return Operator("Convolution")
        .SetParam("kernel", kernel)
        .SetParam("num_filter", num_filter)
        .SetParam("stride", stride)
        .SetParam("dilate", Shape(1, 1))
        .SetParam("pad", pad)
        .SetParam("num_group", 1)
        .SetParam("workspace", 512)
        .SetParam("no_bias", false)
        .SetInput("data", in)
        .CreateSymbol(name);
  };
  auto relu = [](const std::string &name, const Symbol &in) -> Symbol {
    return Operator("Activation")
        .SetParam("act_type", "relu") /*relu,sigmoid,softrelu,tanh */
        .SetInput("data", in)
        .CreateSymbol(name);
  };
  // 3x3 max pooling with stride 2 and no padding.
  auto max_pool = [](const std::string &name, const Symbol &in) -> Symbol {
    return Operator("Pooling")
        .SetParam("kernel", Shape(3, 3))
        .SetParam("pool_type", "max") /*avg,max,sum */
        .SetParam("global_pool", false)
        .SetParam("stride", Shape(2, 2))
        .SetParam("pad", Shape(0, 0))
        .SetInput("data", in)
        .CreateSymbol(name);
  };
  auto lrn = [](const std::string &name, const Symbol &in) -> Symbol {
    return Operator("LRN")
        .SetParam("nsize", 5)
        .SetParam("alpha", 0.0001)
        .SetParam("beta", 0.75)
        .SetParam("knorm", 1)
        .SetInput("data", in)
        .CreateSymbol(name);
  };
  auto fully_connected = [](const std::string &name, const Symbol &in,
                            int num_hidden) -> Symbol {
    return Operator("FullyConnected")
        .SetParam("num_hidden", num_hidden)
        .SetParam("no_bias", false)
        .SetInput("data", in)
        .CreateSymbol(name);
  };
  auto dropout = [](const std::string &name, const Symbol &in) -> Symbol {
    return Operator("Dropout")
        .SetParam("p", 0.5)
        .SetInput("data", in)
        .CreateSymbol(name);
  };

  /*stage 1*/
  const Symbol conv1 = convolution("conv1", data, Shape(11, 11), 96,
                                   Shape(4, 4), Shape(0, 0));
  const Symbol relu1 = relu("relu1", conv1);
  const Symbol pool1 = max_pool("pool1", relu1);
  const Symbol lrn1 = lrn("lrn1", pool1);

  /*stage 2*/
  const Symbol conv2 = convolution("conv2", lrn1, Shape(5, 5), 256,
                                   Shape(1, 1), Shape(2, 2));
  const Symbol relu2 = relu("relu2", conv2);
  const Symbol pool2 = max_pool("pool2", relu2);
  const Symbol lrn2 = lrn("lrn2", pool2);

  /*stage 3*/
  const Symbol conv3 = convolution("conv3", lrn2, Shape(3, 3), 384,
                                   Shape(1, 1), Shape(1, 1));
  const Symbol relu3 = relu("relu3", conv3);
  const Symbol conv4 = convolution("conv4", relu3, Shape(3, 3), 384,
                                   Shape(1, 1), Shape(1, 1));
  const Symbol relu4 = relu("relu4", conv4);
  const Symbol conv5 = convolution("conv5", relu4, Shape(3, 3), 256,
                                   Shape(1, 1), Shape(1, 1));
  const Symbol relu5 = relu("relu5", conv5);
  const Symbol pool3 = max_pool("pool3", relu5);

  /*stage4*/
  const Symbol flatten =
      Operator("Flatten").SetInput("data", pool3).CreateSymbol("flatten");
  const Symbol fc1 = fully_connected("fc1", flatten, 4096);
  const Symbol relu6 = relu("relu6", fc1);
  const Symbol dropout1 = dropout("dropout1", relu6);

  /*stage5*/
  const Symbol fc2 = fully_connected("fc2", dropout1, 4096);
  const Symbol relu7 = relu("relu7", fc2);
  const Symbol dropout2 = dropout("dropout2", relu7);

  /*stage6*/
  const Symbol fc3 = fully_connected("fc3", dropout2, num_classes);
  return Operator("SoftmaxOutput")
      .SetParam("grad_scale", 1)
      .SetParam("ignore_label", -1)
      .SetParam("multi_output", false)
      .SetParam("use_ignore", false)
      .SetParam("normalization", "null") /*batch,null,valid */
      .SetInput("data", fc3)
      .SetInput("label", label)
      .CreateSymbol("softmax");
}
2019-03-07 12:53:27 +08:00
/*!
 * \brief Convert a batch of MNIST digits into the network's input layout.
 *
 * Each sample is viewed as a 1x28x28 image, bilinearly resized to
 * new_shape[2] x new_shape[3], then tiled along the channel axis so the
 * result has new_shape[1] channels.
 *
 * \param data batch of MNIST images. NOTE(review): assumed to hold 28*28
 *        values per sample (e.g. a flat (batch, 784) layout), since it is
 *        reshaped to (0, 1, 28, 28) below — confirm against the iterator.
 * \param new_shape target (batch, channels, height, width). Only dimensions
 *        1-3 are read; the batch dimension is taken from `data`.
 * \return the resized, channel-tiled batch.
 */
NDArray ResizeInput(NDArray data, const Shape new_shape) {
  // The leading 0 presumably keeps the input's batch dimension (MXNet
  // reshape semantics); MNIST digits are 28x28 and single-channel.
  NDArray pic = data.Reshape(Shape(0, 1, 28, 28));

  NDArray pic_1channel;
  Operator("_contrib_BilinearResize2D")
      .SetParam("height", new_shape[2])
      .SetParam("width", new_shape[3])
      (pic).Invoke(pic_1channel);

  // Replicate the single channel to the requested channel count. This was
  // previously hard-coded to 3, which silently ignored new_shape[1].
  NDArray output;
  Operator("tile")
      .SetParam("reps", Shape(1, new_shape[1], 1, 1))
      (pic_1channel).Invoke(output);
  return output;
}
2017-03-22 11:55:51 +08:00
int main ( int argc , char const * argv [ ] ) {
/*basic config*/
2020-02-09 02:50:49 +01:00
int max_epo = argc > 1 ? strtol ( argv [ 1 ] , nullptr , 10 ) : 100 ;
2017-03-22 11:55:51 +08:00
float learning_rate = 1e-4 ;
float weight_decay = 1e-4 ;
2019-03-07 12:53:27 +08:00
/*context*/
auto ctx = Context : : cpu ( ) ;
int num_gpu ;
MXGetGPUCount ( & num_gpu ) ;
int batch_size = 32 ;
# if !MXNET_USE_CPU
if ( num_gpu > 0 ) {
ctx = Context : : gpu ( ) ;
batch_size = 256 ;
}
2018-05-09 17:33:49 -07:00
# endif
2019-04-02 16:23:54 -07:00
TRY
2019-03-07 12:53:27 +08:00
/*net symbol*/
2017-03-22 11:55:51 +08:00
auto Net = AlexnetSymbol ( 10 ) ;
/*args_map and aux_map is used for parameters' saving*/
2018-05-09 17:33:49 -07:00
std : : map < std : : string , NDArray > args_map ;
std : : map < std : : string , NDArray > aux_map ;
2017-03-22 11:55:51 +08:00
/*we should tell mxnet the shape of data and label*/
2019-03-07 12:53:27 +08:00
const Shape data_shape = Shape ( batch_size , 3 , 256 , 256 ) ,
label_shape = Shape ( batch_size ) ;
args_map [ " data " ] = NDArray ( data_shape , ctx ) ;
args_map [ " label " ] = NDArray ( label_shape , ctx ) ;
2017-03-22 11:55:51 +08:00
/*with data and label, executor can be generated automatically*/
auto * exec = Net . SimpleBind ( ctx , args_map ) ;
2017-08-04 04:18:07 +08:00
auto arg_names = Net . ListArguments ( ) ;
2017-03-22 11:55:51 +08:00
aux_map = exec - > aux_dict ( ) ;
args_map = exec - > arg_dict ( ) ;
/*if fine tune from some pre-trained model, we should load the parameters*/
// NDArray::Load("./model/alex_params_3", nullptr, &args_map);
/*else, we should use initializer Xavier to init the params*/
Xavier xavier = Xavier ( Xavier : : gaussian , Xavier : : in , 2.34 ) ;
for ( auto & arg : args_map ) {
/*be careful here, the arg's name must has some specific ends or starts for
* initializer to call*/
xavier ( arg . first , & arg . second ) ;
}
/*these binary files should be generated using im2rc tools, which can be found
* in mxnet/bin*/
2018-05-09 17:33:49 -07:00
std : : vector < std : : string > data_files = { " ./data/mnist_data/train-images-idx3-ubyte " ,
" ./data/mnist_data/train-labels-idx1-ubyte " ,
" ./data/mnist_data/t10k-images-idx3-ubyte " ,
" ./data/mnist_data/t10k-labels-idx1-ubyte "
} ;
auto train_iter = MXDataIter ( " MNISTIter " ) ;
2019-01-17 01:40:23 -08:00
if ( ! setDataIter ( & train_iter , " Train " , data_files , batch_size ) ) {
return 1 ;
}
2018-05-09 17:33:49 -07:00
auto val_iter = MXDataIter ( " MNISTIter " ) ;
2019-01-17 01:40:23 -08:00
if ( ! setDataIter ( & val_iter , " Label " , data_files , batch_size ) ) {
return 1 ;
}
2017-03-22 11:55:51 +08:00
2018-11-27 09:48:25 -08:00
Optimizer * opt = OptimizerRegistry : : Find ( " sgd " ) ;
2017-03-22 11:55:51 +08:00
opt - > SetParam ( " momentum " , 0.9 )
- > SetParam ( " rescale_grad " , 1.0 / batch_size )
2017-08-04 04:18:07 +08:00
- > SetParam ( " clip_gradient " , 10 )
- > SetParam ( " lr " , learning_rate )
- > SetParam ( " wd " , weight_decay ) ;
2017-03-22 11:55:51 +08:00
Accuracy acu_train , acu_val ;
2019-03-07 12:53:27 +08:00
LogLoss logloss_train , logloss_val ;
for ( int epoch = 0 ; epoch < max_epo ; + + epoch ) {
LG < < " Train Epoch: " < < epoch ;
2017-03-22 11:55:51 +08:00
/*reset the metric every epoch*/
acu_train . Reset ( ) ;
/*reset the data iter every epoch*/
train_iter . Reset ( ) ;
2019-03-07 12:53:27 +08:00
int iter = 0 ;
2017-03-22 11:55:51 +08:00
while ( train_iter . Next ( ) ) {
auto batch = train_iter . GetDataBatch ( ) ;
/*use copyto to feed new data and label to the executor*/
2019-03-07 12:53:27 +08:00
ResizeInput ( batch . data , data_shape ) . CopyTo ( & args_map [ " data " ] ) ;
2017-03-22 11:55:51 +08:00
batch . label . CopyTo ( & args_map [ " label " ] ) ;
exec - > Forward ( true ) ;
exec - > Backward ( ) ;
2017-08-04 04:18:07 +08:00
for ( size_t i = 0 ; i < arg_names . size ( ) ; + + i ) {
if ( arg_names [ i ] = = " data " | | arg_names [ i ] = = " label " ) continue ;
opt - > Update ( i , exec - > arg_arrays [ i ] , exec - > grad_arrays [ i ] ) ;
}
2017-03-22 11:55:51 +08:00
NDArray : : WaitAll ( ) ;
acu_train . Update ( batch . label , exec - > outputs [ 0 ] ) ;
2019-03-07 12:53:27 +08:00
logloss_train . Reset ( ) ;
logloss_train . Update ( batch . label , exec - > outputs [ 0 ] ) ;
+ + iter ;
LG < < " EPOCH: " < < epoch < < " ITER: " < < iter
< < " Train Accuracy: " < < acu_train . Get ( )
< < " Train Loss: " < < logloss_train . Get ( ) ;
2017-03-22 11:55:51 +08:00
}
2019-03-07 12:53:27 +08:00
LG < < " EPOCH: " < < epoch < < " Train Accuracy: " < < acu_train . Get ( ) ;
2017-03-22 11:55:51 +08:00
2019-03-07 12:53:27 +08:00
LG < < " Val Epoch: " < < epoch ;
2017-03-22 11:55:51 +08:00
acu_val . Reset ( ) ;
val_iter . Reset ( ) ;
logloss_val . Reset ( ) ;
2019-03-07 12:53:27 +08:00
iter = 0 ;
2017-03-22 11:55:51 +08:00
while ( val_iter . Next ( ) ) {
auto batch = val_iter . GetDataBatch ( ) ;
2019-03-07 12:53:27 +08:00
ResizeInput ( batch . data , data_shape ) . CopyTo ( & args_map [ " data " ] ) ;
2017-03-22 11:55:51 +08:00
batch . label . CopyTo ( & args_map [ " label " ] ) ;
exec - > Forward ( false ) ;
NDArray : : WaitAll ( ) ;
acu_val . Update ( batch . label , exec - > outputs [ 0 ] ) ;
logloss_val . Update ( batch . label , exec - > outputs [ 0 ] ) ;
2019-03-07 12:53:27 +08:00
LG < < " EPOCH: " < < epoch < < " ITER: " < < iter < < " Val Accuracy: " < < acu_val . Get ( ) ;
+ + iter ;
2017-03-22 11:55:51 +08:00
}
2019-03-07 12:53:27 +08:00
LG < < " EPOCH: " < < epoch < < " Val Accuracy: " < < acu_val . Get ( ) ;
LG < < " EPOCH: " < < epoch < < " Val LogLoss: " < < logloss_val . Get ( ) ;
2017-03-22 11:55:51 +08:00
/*save the parameters*/
2018-05-09 17:33:49 -07:00
std : : stringstream ss ;
2019-03-07 12:53:27 +08:00
ss < < epoch ;
std : : string epoch_str ;
ss > > epoch_str ;
std : : string save_path_param = " alex_param_ " + epoch_str ;
2017-03-22 11:55:51 +08:00
auto save_args = args_map ;
/*we do not want to save the data and label*/
save_args . erase ( save_args . find ( " data " ) ) ;
save_args . erase ( save_args . find ( " label " ) ) ;
/*the alexnet does not get any aux array, so we do not need to save
* aux_map*/
2019-03-07 12:53:27 +08:00
LG < < " EPOCH: " < < epoch < < " Saving to... " < < save_path_param ;
2017-03-22 11:55:51 +08:00
NDArray : : Save ( save_path_param , save_args ) ;
}
/*don't foget to release the executor*/
delete exec ;
2019-03-03 12:37:11 +08:00
delete opt ;
2017-03-22 11:55:51 +08:00
MXNotifyShutdown ( ) ;
2019-04-02 16:23:54 -07:00
CATCH
2017-03-22 11:55:51 +08:00
return 0 ;
}