# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Classify one image with a pre-trained ResNet-18 via the MXNet
amalgamation ``Predictor`` and print its top-1 / top-5 labels."""

import sys, os
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
# Resolve the amalgamation bindings relative to this script rather than the
# current working directory, so the import works from any launch directory.
sys.path.append(os.path.join(curr_path, "../../../amalgamation/python/"))
from mxnet_predict import Predictor, load_ndarray_file
import logging
import numpy as np
from skimage import io, transform

# Load the pre-trained model
prefix = "resnet/resnet-18"
num_round = 0
# Spatial resolution of the network input; shared by the predictor's declared
# input shape and by PreprocessImage so the two cannot drift apart.
input_size = 224
symbol_file = "%s-symbol.json" % prefix
# Derive the checkpoint suffix from num_round instead of hard-coding "0000".
param_file = "%s-%04d.params" % (prefix, num_round)
with open(symbol_file, "r") as symbol_fh:
    symbol_json = symbol_fh.read()
with open(param_file, "rb") as param_fh:
    param_bytes = param_fh.read()
predictor = Predictor(symbol_json, param_bytes,
                      {'data': (1, 3, input_size, input_size)})

# One label per line; line i is the label of output class i.
with open('resnet/synset.txt') as synset_fh:
    synset = [line.strip() for line in synset_fh]


def PreprocessImage(path, show_img=False):
    """Load an image and convert it into the network's CHW input layout.

    The image is forced to 3 channels, center-cropped to a square, resized
    to ``input_size`` x ``input_size``, scaled by 255 and transposed from
    (H, W, C) to (C, H, W).

    Args:
        path: Image file to read with ``skimage.io.imread``.
        show_img: Currently ignored; kept so existing callers keep working.

    Returns:
        numpy.ndarray of shape (3, input_size, input_size) with no batch axis.
    """
    # load image
    img = io.imread(path)
    print("Original Image Shape: ", img.shape)
    # Normalise to exactly 3 channels. Grayscale images have no channel axis
    # and PNGs may carry alpha; either would otherwise break the transpose
    # below or mismatch the (1, 3, H, W) input the predictor was built with.
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    if img.shape[2] in (1, 2):  # gray, or gray + alpha
        img = np.repeat(img[:, :, :1], 3, axis=2)
    elif img.shape[2] == 4:  # RGBA -> RGB
        img = img[:, :, :3]
    # we crop image from center
    short_edge = min(img.shape[:2])
    yy = (img.shape[0] - short_edge) // 2
    xx = (img.shape[1] - short_edge) // 2
    crop_img = img[yy:yy + short_edge, xx:xx + short_edge]
    # resize to input_size x input_size; the trailing channel axis is kept
    resized_img = transform.resize(crop_img, (input_size, input_size))
    # resize() returns floats (presumably in [0, 1] for integer input), so
    # scale back up to the 0-255 range.
    sample = np.asarray(resized_img) * 255
    # (H, W, C) -> (C, H, W)
    sample = np.transpose(sample, (2, 0, 1))
    # NOTE(review): no mean subtraction is performed here, despite the
    # original "sub mean" comment; presumably the model normalises its input
    # internally -- confirm before reusing this with other checkpoints.
    return sample


# Get preprocessed batch (single image batch): add the leading batch axis so
# the array matches the (1, 3, input_size, input_size) shape declared above.
batch = PreprocessImage('./download.jpg', True)[np.newaxis]
predictor.forward(data=batch)
prob = predictor.get_output(0)[0]
# Class indices ordered by descending probability.
pred = np.argsort(prob)[::-1]
# Get top1 label
top1 = synset[pred[0]]
print("Top1: ", top1)
# Get top5 label
top5 = [synset[idx] for idx in pred[:5]]
print("Top5: ", top5)