# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: skip-file
# Standard library imports, then third-party, then project-local helpers.
import logging
import os
import sys

import mxnet as mx
import numpy as np

from utils import get_data
# Whether to demo model parallelism combined with data parallelism.
# When True, each of the two modules below gets its own *pair* of GPUs
# (data parallelism within a module, model parallelism across modules).
demo_data_model_parallelism = True

if demo_data_model_parallelism:
    # Module 1 runs on GPUs 0-1, module 2 on GPUs 2-3.
    contexts = [[mx.context.gpu(0), mx.context.gpu(1)],
                [mx.context.gpu(2), mx.context.gpu(3)]]
else:
    # Fallback: one CPU context per module (no GPUs required).
    contexts = [mx.context.cpu(), mx.context.cpu()]
#--------------------------------------------------------------------------------
# module 1
#--------------------------------------------------------------------------------
# First half of the network: flat 784-pixel input -> 128-unit FC -> ReLU.
# NOTE: string literals must not contain padding spaces — 'data', not ' data ' —
# otherwise symbol/argument names will not match at bind time.
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")

# label_names=[] because the label is consumed only by the loss in module 2.
mod1 = mx.mod.Module(act1, label_names=[], context=contexts[0])
#--------------------------------------------------------------------------------
# module 2
#--------------------------------------------------------------------------------
# Second half of the network: takes module 1's output (re-declared here as a
# fresh 'data' variable; SequentialModule's auto_wiring connects them),
# then 64-unit FC -> ReLU -> 10-way FC -> softmax loss.
data = mx.symbol.Variable('data')
fc2 = mx.symbol.FullyConnected(data, name='fc2', num_hidden=64)
act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu")
fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10)
softmax = mx.symbol.SoftmaxOutput(fc3, name='softmax')

mod2 = mx.mod.Module(softmax, context=contexts[1])
#--------------------------------------------------------------------------------
# Container module
#--------------------------------------------------------------------------------
# Chain the two modules: take_labels routes the data iterator's labels to
# module 2 (which owns the loss); auto_wiring feeds module 1's outputs in as
# module 2's 'data' input.
mod_seq = mx.mod.SequentialModule()
mod_seq.add(mod1).add(mod2, take_labels=True, auto_wiring=True)
#--------------------------------------------------------------------------------
# Training
#--------------------------------------------------------------------------------
n_epoch = 2
batch_size = 100

# Resolve the MNIST data directory relative to this script, not the CWD.
basedir = os.path.dirname(__file__)
get_data.get_mnist(os.path.join(basedir, "data"))

# NOTE: file names and dict keys must be exact — no padding spaces inside
# the string literals, or the iterator will fail to open the files.
train_dataiter = mx.io.MNISTIter(
    image=os.path.join(basedir, "data", "train-images-idx3-ubyte"),
    label=os.path.join(basedir, "data", "train-labels-idx1-ubyte"),
    data_shape=(784,),
    batch_size=batch_size, shuffle=True, flat=True, silent=False, seed=10)
val_dataiter = mx.io.MNISTIter(
    image=os.path.join(basedir, "data", "t10k-images-idx3-ubyte"),
    label=os.path.join(basedir, "data", "t10k-labels-idx1-ubyte"),
    data_shape=(784,),
    batch_size=batch_size, shuffle=True, flat=True, silent=False)

logging.basicConfig(level=logging.DEBUG)

# Train the stitched pipeline end-to-end with SGD + momentum.
mod_seq.fit(train_dataiter, eval_data=val_dataiter,
            optimizer_params={'learning_rate': 0.01, 'momentum': 0.9},
            num_epoch=n_epoch)