# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: skip-file
import numpy as np
import mxnet as mx
import numba
import logging
# We use numba.jit to implement the loss gradient.
@numba.jit
def mc_hinge_grad(scores, labels):
    """Gradient of the multi-class hinge loss, one row per sample.

    For each sample the margin ``1 + scores - scores[true_label]`` is
    computed (with the true-label entry zeroed out), and the gradient is
    -1 at the true label and +1 at the highest-margin competing class.

    Parameters
    ----------
    scores : NDArray, shape (n, num_classes)
        Raw class scores; converted to numpy via ``asnumpy()``.
    labels : NDArray, shape (n,)
        Ground-truth class indices; cast to int after conversion.

    Returns
    -------
    numpy.ndarray
        Gradient with the same shape as ``scores``.
    """
    scores_np = scores.asnumpy()
    labels_np = labels.asnumpy().astype(int)

    n_samples = scores_np.shape[0]
    grad = np.zeros_like(scores_np)
    for row in range(n_samples):
        true_cls = labels_np[row]
        # Margin of every class against the true class; the true class
        # itself must not be selected as the violator, so zero it.
        margins = 1 + scores_np[row] - scores_np[row, true_cls]
        margins[true_cls] = 0
        violator = margins.argmax()
        grad[row, true_cls] -= 1
        grad[row, violator] += 1
    return grad
if __name__ == '__main__':
    # Training hyper-parameters for the MNIST example.
    n_epoch = 10
    batch_size = 100

    # Run on the CPU when no GPUs are requested, otherwise use one
    # context per GPU.
    num_gpu = 2
    if num_gpu < 1:
        contexts = mx.context.cpu()
    else:
        contexts = [mx.context.gpu(i) for i in range(num_gpu)]

    # Build an MLP: 784 -> 128 -> 64 -> 10 with ReLU activations.
    data = mx.symbol.Variable('data')
    fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
    act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
    fc2 = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=64)
    act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu")
    fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10)
    mlp = mx.mod.Module(fc3, context=contexts)

    # Custom loss: the gradient is supplied by the Python function above.
    loss = mx.mod.PythonLossModule(grad_func=mc_hinge_grad)

    # Chain network and loss; the loss consumes the labels and is wired
    # to the MLP's output automatically.
    mod = mx.mod.SequentialModule() \
        .add(mlp) \
        .add(loss, take_labels=True, auto_wiring=True)

    train_dataiter = mx.io.MNISTIter(
        image="data/train-images-idx3-ubyte",
        label="data/train-labels-idx1-ubyte",
        data_shape=(784,),
        batch_size=batch_size, shuffle=True, flat=True, silent=False, seed=10)
    val_dataiter = mx.io.MNISTIter(
        image="data/t10k-images-idx3-ubyte",
        label="data/t10k-labels-idx1-ubyte",
        data_shape=(784,),
        batch_size=batch_size, shuffle=True, flat=True, silent=False)

    logging.basicConfig(level=logging.DEBUG)
    mod.fit(train_dataiter, eval_data=val_dataiter,
            optimizer_params={'learning_rate': 0.01, 'momentum': 0.9},
            num_epoch=n_epoch)