/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \author Xin Li yakumolx@gmail.com
 */
#include <chrono>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "utils.h"
#include "mxnet-cpp/MxNetCpp.h"

using namespace mxnet::cpp;
/*
 * Build a multi-layer perceptron symbol.
 *
 * \param layers number of hidden units per layer; the last entry is the
 *        output width (number of classes).
 * \return the symbolic graph: FullyConnected(+ReLU) per layer, with a
 *         SoftmaxOutput head attached to the final layer.
 */
Symbol mlp(const std::vector<int> &layers) {
  // Placeholders for the input batch and its ground-truth labels.
  Symbol data = Symbol::Variable("X");
  Symbol data_label = Symbol::Variable("label");

  // Thread a single "current" symbol through the stack instead of
  // keeping per-layer output vectors.
  Symbol current = data;
  const size_t num_layers = layers.size();
  for (size_t idx = 0; idx < num_layers; ++idx) {
    // One learnable weight/bias pair per layer, named w<idx>/b<idx>.
    Symbol w = Symbol::Variable("w" + std::to_string(idx));
    Symbol b = Symbol::Variable("b" + std::to_string(idx));
    Symbol fc = FullyConnected(current, w, b, layers[idx]);
    // ReLU after every hidden layer; the final layer feeds softmax directly.
    current = (idx + 1 == num_layers) ? fc : Activation(fc, ActivationActType::kRelu);
  }
  return SoftmaxOutput(current, data_label);
}
int main ( int argc , char * * argv ) {
const int image_size = 28 ;
2018-05-09 17:33:49 -07:00
const std : : vector < int > layers { 128 , 64 , 10 } ;
2017-04-26 00:57:01 +08:00
const int batch_size = 100 ;
const int max_epoch = 10 ;
const float learning_rate = 0.1 ;
const float weight_decay = 1e-2 ;
2018-05-11 11:51:53 -07:00
std : : vector < std : : string > data_files = { " ./data/mnist_data/train-images-idx3-ubyte " ,
" ./data/mnist_data/train-labels-idx1-ubyte " ,
" ./data/mnist_data/t10k-images-idx3-ubyte " ,
" ./data/mnist_data/t10k-labels-idx1-ubyte "
} ;
auto train_iter = MXDataIter ( " MNISTIter " ) ;
2019-01-17 01:40:23 -08:00
if ( ! setDataIter ( & train_iter , " Train " , data_files , batch_size ) ) {
return 1 ;
}
2018-05-11 11:51:53 -07:00
auto val_iter = MXDataIter ( " MNISTIter " ) ;
2019-01-17 01:40:23 -08:00
if ( ! setDataIter ( & val_iter , " Label " , data_files , batch_size ) ) {
return 1 ;
}
2017-04-26 00:57:01 +08:00
2019-04-02 16:23:54 -07:00
TRY
2017-04-26 00:57:01 +08:00
auto net = mlp ( layers ) ;
Context ctx = Context : : cpu ( ) ; // Use CPU for training
2018-05-09 17:33:49 -07:00
std : : map < std : : string , NDArray > args ;
2017-04-26 00:57:01 +08:00
args [ " X " ] = NDArray ( Shape ( batch_size , image_size * image_size ) , ctx ) ;
args [ " label " ] = NDArray ( Shape ( batch_size ) , ctx ) ;
// Let MXNet infer shapes other parameters such as weights
net . InferArgsMap ( ctx , & args , args ) ;
// Initialize all parameters with uniform distribution U(-0.01, 0.01)
auto initializer = Uniform ( 0.01 ) ;
for ( auto & arg : args ) {
// arg.first is parameter name, and arg.second is the value
initializer ( arg . first , & arg . second ) ;
}
// Create sgd optimizer
Optimizer * opt = OptimizerRegistry : : Find ( " sgd " ) ;
2017-08-04 04:18:07 +08:00
opt - > SetParam ( " rescale_grad " , 1.0 / batch_size )
- > SetParam ( " lr " , learning_rate )
- > SetParam ( " wd " , weight_decay ) ;
// Create executor by binding parameters to the model
auto * exec = net . SimpleBind ( ctx , args ) ;
auto arg_names = net . ListArguments ( ) ;
2017-04-26 00:57:01 +08:00
// Start training
for ( int iter = 0 ; iter < max_epoch ; + + iter ) {
int samples = 0 ;
train_iter . Reset ( ) ;
2018-05-09 17:33:49 -07:00
auto tic = std : : chrono : : system_clock : : now ( ) ;
2017-04-26 00:57:01 +08:00
while ( train_iter . Next ( ) ) {
samples + = batch_size ;
auto data_batch = train_iter . GetDataBatch ( ) ;
// Set data and label
2017-11-07 02:13:07 +08:00
data_batch . data . CopyTo ( & args [ " X " ] ) ;
data_batch . label . CopyTo ( & args [ " label " ] ) ;
2017-04-26 00:57:01 +08:00
// Compute gradients
exec - > Forward ( true ) ;
exec - > Backward ( ) ;
// Update parameters
2017-08-04 04:18:07 +08:00
for ( size_t i = 0 ; i < arg_names . size ( ) ; + + i ) {
if ( arg_names [ i ] = = " X " | | arg_names [ i ] = = " label " ) continue ;
opt - > Update ( i , exec - > arg_arrays [ i ] , exec - > grad_arrays [ i ] ) ;
}
2017-04-26 00:57:01 +08:00
}
2018-05-09 17:33:49 -07:00
auto toc = std : : chrono : : system_clock : : now ( ) ;
2017-04-26 00:57:01 +08:00
Accuracy acc ;
val_iter . Reset ( ) ;
while ( val_iter . Next ( ) ) {
auto data_batch = val_iter . GetDataBatch ( ) ;
2017-11-07 02:13:07 +08:00
data_batch . data . CopyTo ( & args [ " X " ] ) ;
data_batch . label . CopyTo ( & args [ " label " ] ) ;
2017-04-26 00:57:01 +08:00
// Forward pass is enough as no gradient is needed when evaluating
exec - > Forward ( false ) ;
acc . Update ( data_batch . label , exec - > outputs [ 0 ] ) ;
}
2018-05-09 17:33:49 -07:00
float duration = std : : chrono : : duration_cast < std : : chrono : : milliseconds >
( toc - tic ) . count ( ) / 1000.0 ;
2017-04-26 00:57:01 +08:00
LG < < " Epoch: " < < iter < < " " < < samples / duration < < " samples/sec Accuracy: " < < acc . Get ( ) ;
}
2017-08-04 04:18:07 +08:00
delete exec ;
2019-03-03 12:37:11 +08:00
delete opt ;
2017-04-26 00:57:01 +08:00
MXNotifyShutdown ( ) ;
2019-04-02 16:23:54 -07:00
CATCH
2017-04-26 00:57:01 +08:00
return 0 ;
}