2017-08-11 14:12:47 -07:00
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
 * \brief Example: trains a small ResNet-style convolutional network on the
 *        MNIST dataset using the MXNet C++ API (cpp-package).
 */
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <map>
#include <string>
#include <vector>

#include "utils.h"
#include "mxnet-cpp/MxNetCpp.h"
2017-03-22 11:55:51 +08:00
using namespace mxnet : : cpp ;
Symbol ConvolutionNoBias ( const std : : string & symbol_name ,
Symbol data ,
Symbol weight ,
Shape kernel ,
int num_filter ,
Shape stride = Shape ( 1 , 1 ) ,
Shape dilate = Shape ( 1 , 1 ) ,
Shape pad = Shape ( 0 , 0 ) ,
int num_group = 1 ,
int64_t workspace = 512 ) {
return Operator ( " Convolution " )
. SetParam ( " kernel " , kernel )
. SetParam ( " num_filter " , num_filter )
. SetParam ( " stride " , stride )
. SetParam ( " dilate " , dilate )
. SetParam ( " pad " , pad )
. SetParam ( " num_group " , num_group )
. SetParam ( " workspace " , workspace )
. SetParam ( " no_bias " , true )
. SetInput ( " data " , data )
. SetInput ( " weight " , weight )
. CreateSymbol ( symbol_name ) ;
}
/*!
 * \brief Conv -> BatchNorm [-> ReLU] building unit.
 *
 * \param name        prefix for all parameter/symbol names created here
 * \param data        input feature map
 * \param num_filter  number of convolution output channels
 * \param kernel      convolution kernel size
 * \param stride      convolution stride
 * \param pad         convolution padding
 * \param with_relu   if true, append a ReLU activation after batch norm
 * \param bn_momentum momentum for the BatchNorm running statistics
 * \return the output symbol of the unit
 */
Symbol getConv(const std::string &name, Symbol data,
               int num_filter,
               Shape kernel, Shape stride, Shape pad,
               bool with_relu,
               mx_float bn_momentum) {
  Symbol conv_w(name + "_w");
  // dilation fixed at 1x1, single group, 512 MB workspace
  Symbol conv = ConvolutionNoBias(name, data, conv_w,
                                  kernel, num_filter, stride, Shape(1, 1),
                                  pad, 1, 512);

  Symbol gamma(name + "_gamma");
  Symbol beta(name + "_beta");
  Symbol mmean(name + "_mmean");
  Symbol mvar(name + "_mvar");
  // eps = 2e-5; fix_gamma = false so gamma is learned
  Symbol bn = BatchNorm(name + "_bn", conv, gamma,
                        beta, mmean, mvar, 2e-5, bn_momentum, false);

  if (with_relu) {
    return Activation(name + "_relu", bn, "relu");
  } else {
    return bn;
  }
}
/*!
 * \brief One residual block: two 3x3 conv units plus a shortcut connection.
 *
 * When `dim_match` is true the input and output shapes agree, so the
 * identity is used as the shortcut; otherwise the block downsamples with
 * stride 2 and the shortcut is a strided 2x2 projection convolution.
 *
 * \param name        prefix for parameter/symbol names
 * \param data        block input
 * \param num_filter  channels for both convolutions in the block
 * \param dim_match   true if input/output dimensions match (identity shortcut)
 * \param bn_momentum momentum for BatchNorm running statistics
 * \return the ReLU-activated residual sum
 */
Symbol makeBlock(const std::string &name, Symbol data, int num_filter,
                 bool dim_match, mx_float bn_momentum) {
  Shape stride;
  if (dim_match) {
    stride = Shape(1, 1);
  } else {
    stride = Shape(2, 2);  // downsample at the start of a new stage
  }
  Symbol conv1 = getConv(name + "_conv1", data, num_filter,
                         Shape(3, 3), stride, Shape(1, 1),
                         true, bn_momentum);
  Symbol conv2 = getConv(name + "_conv2", conv1, num_filter,
                         Shape(3, 3), Shape(1, 1), Shape(1, 1),
                         false, bn_momentum);
  Symbol shortcut;
  if (dim_match) {
    shortcut = data;
  } else {
    // projection shortcut to match the downsampled shape / channel count
    Symbol shortcut_w(name + "_proj_w");
    shortcut = ConvolutionNoBias(name + "_proj", data, shortcut_w,
                                 Shape(2, 2), num_filter,
                                 Shape(2, 2), Shape(1, 1), Shape(0, 0),
                                 1, 512);
  }
  Symbol fused = shortcut + conv2;
  return Activation(name + "_relu", fused, "relu");
}
/*!
 * \brief Stack the residual trunk: `num_level` stages of `num_block` blocks.
 *
 * Channel count doubles at each stage (num_filter * 2^level).  The first
 * block of every stage after the first downsamples (dim_match == false).
 *
 * \param num_level   number of resolution stages
 * \param num_block   residual blocks per stage
 * \param num_filter  channel count of the first stage
 * \param bn_momentum momentum for BatchNorm running statistics
 * \return the trunk output symbol
 */
Symbol getBody(Symbol data, int num_level, int num_block, int num_filter, mx_float bn_momentum) {
  for (int level = 0; level < num_level; level++) {
    for (int block = 0; block < num_block; block++) {
      data = makeBlock("level" + std::to_string(level + 1) + "_block" + std::to_string(block + 1),
                       data, num_filter * (std::pow(2, level)),
                       (level == 0 || block > 0), bn_momentum);
    }
  }
  return data;
}
/*!
 * \brief Assemble the full ResNet classification symbol.
 *
 * Pipeline: input BatchNorm ("zscore" normalization) -> 3x3 stem conv ->
 * residual trunk -> average pooling -> flatten -> fully connected ->
 * softmax output.
 *
 * \param num_class   number of output classes
 * \param num_level   number of residual stages (default 3)
 * \param num_block   residual blocks per stage (default 9)
 * \param num_filter  channels of the first stage (default 16)
 * \param bn_momentum momentum for BatchNorm running statistics
 * \param pool_kernel kernel of the final average pooling
 * \return the softmax output symbol
 */
Symbol ResNetSymbol(int num_class, int num_level = 3, int num_block = 9,
                    int num_filter = 16, mx_float bn_momentum = 0.9,
                    mxnet::cpp::Shape pool_kernel = mxnet::cpp::Shape(8, 8)) {
  // data and label
  Symbol data = Symbol::Variable("data");
  Symbol data_label = Symbol::Variable("data_label");

  // normalize the raw input with a BatchNorm layer
  Symbol gamma("gamma");
  Symbol beta("beta");
  Symbol mmean("mmean");
  Symbol mvar("mvar");
  Symbol zscore = BatchNorm("zscore", data, gamma,
                            beta, mmean, mvar, 0.001, bn_momentum);

  // stem convolution before the residual trunk
  Symbol conv = getConv("conv0", zscore, num_filter,
                        Shape(3, 3), Shape(1, 1), Shape(1, 1),
                        true, bn_momentum);

  Symbol body = getBody(conv, num_level, num_block, num_filter, bn_momentum);

  Symbol pool = Pooling("pool", body, pool_kernel, PoolingPoolType::kAvg);

  Symbol flat = Flatten("flatten", pool);
  Symbol fc_w("fc_w"), fc_b("fc_b");
  Symbol fc = FullyConnected("fc", flat, fc_w, fc_b, num_class);
  return SoftmaxOutput("softmax", fc, data_label);
}
2019-03-07 12:53:27 +08:00
NDArray ResizeInput ( NDArray data , const Shape new_shape ) {
NDArray pic = data . Reshape ( Shape ( 0 , 1 , 28 , 28 ) ) ;
NDArray pic_1channel ;
Operator ( " _contrib_BilinearResize2D " )
. SetParam ( " height " , new_shape [ 2 ] )
. SetParam ( " width " , new_shape [ 3 ] )
( pic ) . Invoke ( pic_1channel ) ;
NDArray output ;
Operator ( " tile " )
. SetParam ( " reps " , Shape ( 1 , 3 , 1 , 1 ) )
( pic_1channel ) . Invoke ( output ) ;
return output ;
}
2017-03-22 11:55:51 +08:00
int main ( int argc , char const * argv [ ] ) {
2020-02-09 02:50:49 +01:00
int max_epoch = argc > 1 ? strtol ( argv [ 1 ] , nullptr , 10 ) : 100 ;
2017-03-22 11:55:51 +08:00
float learning_rate = 1e-4 ;
float weight_decay = 1e-4 ;
2019-04-02 16:23:54 -07:00
TRY
2017-03-22 11:55:51 +08:00
auto resnet = ResNetSymbol ( 10 ) ;
std : : map < std : : string , NDArray > args_map ;
std : : map < std : : string , NDArray > aux_map ;
2019-03-07 12:53:27 +08:00
/*context*/
auto ctx = Context : : cpu ( ) ;
int num_gpu ;
MXGetGPUCount ( & num_gpu ) ;
int batch_size = 8 ;
# if !MXNET_USE_CPU
if ( num_gpu > 0 ) {
ctx = Context : : gpu ( ) ;
2019-04-08 00:21:36 -07:00
batch_size = 32 ;
2019-03-07 12:53:27 +08:00
}
2018-05-09 17:33:49 -07:00
# endif
2019-03-07 12:53:27 +08:00
const Shape data_shape = Shape ( batch_size , 3 , 224 , 224 ) ,
label_shape = Shape ( batch_size ) ;
args_map [ " data " ] = NDArray ( data_shape , ctx ) ;
args_map [ " data_label " ] = NDArray ( label_shape , ctx ) ;
2018-05-09 17:33:49 -07:00
resnet . InferArgsMap ( ctx , & args_map , args_map ) ;
std : : vector < std : : string > data_files = { " ./data/mnist_data/train-images-idx3-ubyte " ,
" ./data/mnist_data/train-labels-idx1-ubyte " ,
" ./data/mnist_data/t10k-images-idx3-ubyte " ,
" ./data/mnist_data/t10k-labels-idx1-ubyte "
} ;
auto train_iter = MXDataIter ( " MNISTIter " ) ;
2019-01-17 01:40:23 -08:00
if ( ! setDataIter ( & train_iter , " Train " , data_files , batch_size ) ) {
return 1 ;
}
2018-05-09 17:33:49 -07:00
auto val_iter = MXDataIter ( " MNISTIter " ) ;
2019-01-17 01:40:23 -08:00
if ( ! setDataIter ( & val_iter , " Label " , data_files , batch_size ) ) {
return 1 ;
}
2017-03-22 11:55:51 +08:00
2018-11-27 09:48:25 -08:00
// initialize parameters
Xavier xavier = Xavier ( Xavier : : gaussian , Xavier : : in , 2 ) ;
for ( auto & arg : args_map ) {
xavier ( arg . first , & arg . second ) ;
}
Optimizer * opt = OptimizerRegistry : : Find ( " sgd " ) ;
2017-08-04 04:18:07 +08:00
opt - > SetParam ( " lr " , learning_rate )
- > SetParam ( " wd " , weight_decay )
- > SetParam ( " momentum " , 0.9 )
2017-03-22 11:55:51 +08:00
- > SetParam ( " rescale_grad " , 1.0 / batch_size )
- > SetParam ( " clip_gradient " , 10 ) ;
2018-05-09 17:33:49 -07:00
auto * exec = resnet . SimpleBind ( ctx , args_map ) ;
2017-08-04 04:18:07 +08:00
auto arg_names = resnet . ListArguments ( ) ;
2017-03-22 11:55:51 +08:00
2018-11-27 09:48:25 -08:00
// Create metrics
Accuracy train_acc , val_acc ;
2019-03-07 12:53:27 +08:00
LogLoss logloss_train , logloss_val ;
for ( int epoch = 0 ; epoch < max_epoch ; + + epoch ) {
LG < < " Epoch: " < < epoch ;
2017-03-22 11:55:51 +08:00
train_iter . Reset ( ) ;
2018-11-27 09:48:25 -08:00
train_acc . Reset ( ) ;
2019-03-07 12:53:27 +08:00
int iter = 0 ;
2017-03-22 11:55:51 +08:00
while ( train_iter . Next ( ) ) {
auto data_batch = train_iter . GetDataBatch ( ) ;
2019-03-07 12:53:27 +08:00
ResizeInput ( data_batch . data , data_shape ) . CopyTo ( & args_map [ " data " ] ) ;
2017-03-22 11:55:51 +08:00
data_batch . label . CopyTo ( & args_map [ " data_label " ] ) ;
NDArray : : WaitAll ( ) ;
exec - > Forward ( true ) ;
exec - > Backward ( ) ;
2017-08-04 04:18:07 +08:00
for ( size_t i = 0 ; i < arg_names . size ( ) ; + + i ) {
if ( arg_names [ i ] = = " data " | | arg_names [ i ] = = " data_label " ) continue ;
opt - > Update ( i , exec - > arg_arrays [ i ] , exec - > grad_arrays [ i ] ) ;
}
2017-03-22 11:55:51 +08:00
NDArray : : WaitAll ( ) ;
2018-11-27 09:48:25 -08:00
train_acc . Update ( data_batch . label , exec - > outputs [ 0 ] ) ;
2019-03-07 12:53:27 +08:00
logloss_train . Reset ( ) ;
logloss_train . Update ( data_batch . label , exec - > outputs [ 0 ] ) ;
+ + iter ;
LG < < " EPOCH: " < < epoch < < " ITER: " < < iter
< < " Train Accuracy: " < < train_acc . Get ( )
< < " Train Loss: " < < logloss_train . Get ( ) ;
2017-03-22 11:55:51 +08:00
}
2019-03-07 12:53:27 +08:00
LG < < " EPOCH: " < < epoch < < " Train Accuracy: " < < train_acc . Get ( ) ;
2017-03-22 11:55:51 +08:00
val_iter . Reset ( ) ;
2018-11-27 09:48:25 -08:00
val_acc . Reset ( ) ;
2019-03-07 12:53:27 +08:00
iter = 0 ;
2017-03-22 11:55:51 +08:00
while ( val_iter . Next ( ) ) {
auto data_batch = val_iter . GetDataBatch ( ) ;
2019-03-07 12:53:27 +08:00
ResizeInput ( data_batch . data , data_shape ) . CopyTo ( & args_map [ " data " ] ) ;
2017-03-22 11:55:51 +08:00
data_batch . label . CopyTo ( & args_map [ " data_label " ] ) ;
NDArray : : WaitAll ( ) ;
exec - > Forward ( false ) ;
NDArray : : WaitAll ( ) ;
2018-11-27 09:48:25 -08:00
val_acc . Update ( data_batch . label , exec - > outputs [ 0 ] ) ;
2019-03-07 12:53:27 +08:00
LG < < " EPOCH: " < < epoch < < " ITER: " < < iter < < " Val Accuracy: " < < val_acc . Get ( ) ;
+ + iter ;
2017-03-22 11:55:51 +08:00
}
2018-11-27 09:48:25 -08:00
LG < < " Validation Accuracy: " < < val_acc . Get ( ) ;
2017-03-22 11:55:51 +08:00
}
delete exec ;
2019-03-03 12:37:11 +08:00
delete opt ;
2017-03-22 11:55:51 +08:00
MXNotifyShutdown ( ) ;
2019-04-02 16:23:54 -07:00
CATCH
2017-03-22 11:55:51 +08:00
return 0 ;
}