2017-08-08 16:36:23 -07:00
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
2015-10-23 23:55:35 -07:00
import sys
import os

# Absolute directory containing this script (computed here but not used below).
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))

# Make the amalgamation's ``mxnet_predict`` module importable.
# NOTE(review): this relative path is resolved against the current working
# directory, not ``curr_path``, so it only helps when the script is launched
# from a matching directory — confirm the intended launch location.
sys.path.append("../../../amalgamation/python/")

from mxnet_predict import Predictor, load_ndarray_file
import logging
import numpy as np
from skimage import io, transform
# Load the pre-trained model: the network definition (JSON) and the weights
# checkpoint selected by ``num_round``.
prefix = "resnet/resnet-18"
num_round = 0
symbol_file = "%s-symbol.json" % prefix
# Build the checkpoint name from num_round (zero-padded to 4 digits, giving
# "-0000.params" for 0) instead of hard-coding it, so num_round is honoured.
param_file = "%s-%04d.params" % (prefix, num_round)

# Read both files through context managers so the handles are closed promptly.
with open(symbol_file, "r") as f:
    symbol_json = f.read()
with open(param_file, "rb") as f:
    param_bytes = f.read()
# Input shape for the 'data' node: a batch of one 3 x 224 x 224 image.
predictor = Predictor(symbol_json, param_bytes, {'data': (1, 3, 224, 224)})

# Class labels, one per line; label i names output index i of the predictor.
with open('resnet/synset.txt') as f:
    synset = [line.strip() for line in f]
2015-10-23 23:55:35 -07:00
def PreprocessImage(path, show_img=False):
    """Load an image and convert it into a (3, 224, 224) float array.

    The image is forced to 3 channels, center-cropped to a square on its
    shorter side, resized to 224 x 224, scaled to the 0-255 range and
    reordered from height-width-channel to channel-height-width layout.
    No mean subtraction is performed.

    Args:
        path: Image location, passed straight to ``skimage.io.imread``.
        show_img: Accepted for backward compatibility; currently ignored.

    Returns:
        numpy.ndarray of shape (3, 224, 224).
    """
    img = io.imread(path)
    print("Original Image Shape: ", img.shape)

    # The model takes exactly 3 channels. Without this, a grayscale (2-D)
    # image breaks the axis reordering below and an RGBA image produces a
    # sample with 4 * 224 * 224 values instead of 3 * 224 * 224.
    if img.ndim == 2:
        img = np.stack((img,) * 3, axis=-1)
    elif img.shape[2] < 3:
        # Single channel (optionally plus alpha) stored with a channel axis.
        img = np.repeat(img[:, :, :1], 3, axis=2)
    elif img.shape[2] > 3:
        # Drop alpha / any extra channels.
        img = img[:, :, :3]

    # Crop the largest centered square.
    short_edge = min(img.shape[:2])
    yy = (img.shape[0] - short_edge) // 2
    xx = (img.shape[1] - short_edge) // 2
    crop_img = img[yy:yy + short_edge, xx:xx + short_edge]

    # skimage's resize rescales integer input to floats in [0, 1] by default
    # (preserve_range=False); the * 255 restores the 0-255 range.
    resized_img = transform.resize(crop_img, (224, 224))
    sample = np.asarray(resized_img) * 255

    # (224, 224, 3) -> (3, 224, 224): channel axis first.
    return np.transpose(sample, (2, 0, 1))
2015-10-23 23:55:35 -07:00
# Preprocess the sample image into the network's input layout.
batch = PreprocessImage('./download.jpg', show_img=True)

# Single forward pass; the first entry of output 0 presumably holds the
# per-class scores for our one image.
predictor.forward(data=batch)
prob = predictor.get_output(0)[0]
# Class indices ordered from highest to lowest score.
pred = np.argsort(prob)[::-1]

# Best-scoring label.
top1 = synset[pred[0]]
print("Top1: ", top1)

# Five best-scoring labels, highest first.
top5 = [synset[pred[rank]] for rank in range(5)]
print("Top5: ", top5)