# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. # coding: utf-8 # pylint: disable=invalid-name, too-many-arguments """Lightweight API for mxnet prediction. This is for prediction only, use mxnet python package instead for most tasks. """ from __future__ import absolute_import import os import sys from array import array import ctypes import logging import numpy as np # pylint: disable= no-member _DTYPE_NP_TO_MX = { None: -1, np.float32: 0, np.float64: 1, np.float16: 2, np.uint8: 3, np.int32: 4, np.int8: 5, np.int64: 6, } _DTYPE_MX_TO_NP = { -1: None, 0: np.float32, 1: np.float64, 2: np.float16, 3: np.uint8, 4: np.int32, 5: np.int8, 6: np.int64, } __all__ = ["Predictor", "load_ndarray_file"] if sys.version_info[0] == 3: py_str = lambda x: x.decode('utf-8') def c_str_array(strings): """Create ctypes const char ** from a list of Python strings. Parameters ---------- strings : list of string Python strings. Returns ------- (ctypes.c_char_p * len(strings)) A const char ** pointer that can be passed to C API. """ arr = (ctypes.c_char_p * len(strings))() arr[:] = [s.encode('utf-8') for s in strings] return arr else: py_str = lambda x: x def c_str_array(strings): """Create ctypes const char ** from a list of Python strings. Parameters ---------- strings : list of strings Python strings. Returns ------- (ctypes.c_char_p * len(strings)) A const char ** pointer that can be passed to C API. """ arr = (ctypes.c_char_p * len(strings))() arr[:] = strings return arr def c_str(string): """"Convert a python string to C string.""" if not isinstance(string, str): string = string.decode('ascii') return ctypes.c_char_p(string.encode('utf-8')) def c_array(ctype, values): """Create ctypes array from a python array.""" return (ctype * len(values))(*values) def c_array_buf(ctype, buf): """Create ctypes array from a Python buffer.""" return (ctype * len(buf)).from_buffer(buf) def _find_lib_path(): """Find mxnet library.""" curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) amalgamation_lib_path = os.path.join(curr_path, '../../lib/libmxnet_predict.so') if os.path.exists(amalgamation_lib_path) and os.path.isfile(amalgamation_lib_path): lib_path = [amalgamation_lib_path] return lib_path else: logging.info('Cannot find libmxnet_predict.so. Will search for MXNet library using libinfo.py then.') try: from mxnet.libinfo import find_lib_path lib_path = find_lib_path() return lib_path except ImportError: libinfo_path = os.path.join(curr_path, '../../python/mxnet/libinfo.py') if os.path.exists(libinfo_path) and os.path.isfile(libinfo_path): libinfo = {'__file__': libinfo_path} exec(compile(open(libinfo_path, "rb").read(), libinfo_path, 'exec'), libinfo, libinfo) lib_path = libinfo['find_lib_path']() return lib_path else: raise RuntimeError('Cannot find libinfo.py at %s.' % libinfo_path) def _load_lib(): """Load libary by searching possible path.""" lib_path = _find_lib_path() lib = ctypes.cdll.LoadLibrary(lib_path[0]) # DMatrix functions lib.MXGetLastError.restype = ctypes.c_char_p return lib def _check_call(ret): """Check the return value of API.""" if ret != 0: raise RuntimeError(py_str(_LIB.MXGetLastError())) def _monitor_callback_wrapper(callback): """A wrapper for the user-defined handle.""" def callback_handle(name, array, _): """ ctypes function """ callback(name, array) return callback_handle _LIB = _load_lib() # type definitions mx_uint = ctypes.c_uint mx_int = ctypes.c_int mx_float = ctypes.c_float mx_float_p = ctypes.POINTER(mx_float) PredictorHandle = ctypes.c_void_p NDListHandle = ctypes.c_void_p devstr2type = {'cpu': 1, 'gpu': 2, 'cpu_pinned': 3} class Predictor(object): """A predictor class that runs prediction. Parameters ---------- symbol_json_str : str Path to the symbol file. param_raw_bytes : str, bytes The raw parameter bytes. input_shapes : dict of str to tuple The shape of input data dev_type : str, optional The device type of the predictor. dev_id : int, optional The device id of the predictor. type_dict : Dict of str->numpy.dtype Input type dictionary, name->dtype """ def __init__(self, symbol_file, param_raw_bytes, input_shapes, dev_type="cpu", dev_id=0, type_dict=None): dev_type = devstr2type[dev_type] indptr = [0] sdata = [] keys = [] for k, v in input_shapes.items(): if not isinstance(v, tuple): raise ValueError("Expect input_shapes to be dict str->tuple") keys.append(c_str(k)) sdata.extend(v) indptr.append(len(sdata)) handle = PredictorHandle() param_raw_bytes = bytearray(param_raw_bytes) ptr = (ctypes.c_char * len(param_raw_bytes)).from_buffer(param_raw_bytes) # data types num_provided_arg_types = 0 # provided type argument names provided_arg_type_names = ctypes.POINTER(ctypes.c_char_p)() # provided types provided_arg_type_data = ctypes.POINTER(mx_uint)() if type_dict is not None: provided_arg_type_names = [] provided_arg_type_data = [] for k, v in type_dict.items(): v = np.dtype(v).type if v in _DTYPE_NP_TO_MX: provided_arg_type_names.append(k) provided_arg_type_data.append(_DTYPE_NP_TO_MX[v]) num_provided_arg_types = mx_uint(len(provided_arg_type_names)) provided_arg_type_names = c_str_array(provided_arg_type_names) provided_arg_type_data = c_array_buf(ctypes.c_int, array('i', provided_arg_type_data)) _check_call(_LIB.MXPredCreateEx( c_str(symbol_file), ptr, len(param_raw_bytes), ctypes.c_int(dev_type), ctypes.c_int(dev_id), mx_uint(len(indptr) - 1), c_array(ctypes.c_char_p, keys), c_array(mx_uint, indptr), c_array(mx_uint, sdata), num_provided_arg_types, provided_arg_type_names, provided_arg_type_data, ctypes.byref(handle))) self.type_dict = type_dict self.handle = handle def __del__(self): _check_call(_LIB.MXPredFree(self.handle)) def forward(self, **kwargs): """Perform forward to get the output. Parameters ---------- **kwargs Keyword arguments of input variable name to data. Examples -------- >>> predictor.forward(data=mydata) >>> out = predictor.get_output(0) """ if self.type_dict and len(self.type_dict) != len(kwargs.items()): raise ValueError("number of kwargs should be same as len of type_dict" \ "Please check your forward pass inputs" \ "or type_dict passed to Predictor instantiation") for k, v in kwargs.items(): if not isinstance(v, np.ndarray): raise ValueError("Expect numpy ndarray as input") if self.type_dict and k in self.type_dict: v = np.asarray(v, dtype=self.type_dict[k], order='C') else: v = np.asarray(v, dtype=np.float32, order='C') _check_call(_LIB.MXPredSetInput( self.handle, c_str(k), v.ctypes.data_as(mx_float_p), mx_uint(v.size))) _check_call(_LIB.MXPredForward(self.handle)) def reshape(self, input_shapes): """Change the input shape of the predictor. Parameters ---------- input_shapes : dict of str to tuple The new shape of input data. Examples -------- >>> predictor.reshape({'data':data_shape_tuple}) """ indptr = [0] sdata = [] keys = [] for k, v in input_shapes.items(): if not isinstance(v, tuple): raise ValueError("Expect input_shapes to be dict str->tuple") keys.append(c_str(k)) sdata.extend(v) indptr.append(len(sdata)) new_handle = PredictorHandle() _check_call(_LIB.MXPredReshape( mx_uint(len(indptr) - 1), c_array(ctypes.c_char_p, keys), c_array(mx_uint, indptr), c_array(mx_uint, sdata), self.handle, ctypes.byref(new_handle))) _check_call(_LIB.MXPredFree(self.handle)) self.handle = new_handle def get_output(self, index): """Get the index-th output. Parameters ---------- index : int The index of output. Returns ------- out : numpy array. The output array. """ pdata = ctypes.POINTER(mx_uint)() ndim = mx_uint() out_type = mx_int() _check_call(_LIB.MXPredGetOutputShape( self.handle, index, ctypes.byref(pdata), ctypes.byref(ndim))) _check_call(_LIB.MXPredGetOutputType( self.handle, index, ctypes.byref(out_type))) shape = tuple(pdata[:ndim.value]) data = np.empty(shape, dtype=_DTYPE_MX_TO_NP[out_type.value]) _check_call(_LIB.MXPredGetOutput( self.handle, mx_uint(index), data.ctypes.data_as(mx_float_p), mx_uint(data.size))) return data def set_monitor_callback(self, callback, monitor_all=False): cb_type = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_void_p, ctypes.c_void_p) self._monitor_callback = cb_type(_monitor_callback_wrapper(callback)) _check_call(_LIB.MXPredSetMonitorCallback(self.handle, self._monitor_callback, None, ctypes.c_int(monitor_all))) def load_ndarray_file(nd_bytes): """Load ndarray file and return as list of numpy array. Parameters ---------- nd_bytes : str or bytes The internal ndarray bytes Returns ------- out : dict of str to numpy array or list of numpy array The output list or dict, depending on whether the saved type is list or dict. """ handle = NDListHandle() olen = mx_uint() nd_bytes = bytearray(nd_bytes) ptr = (ctypes.c_char * len(nd_bytes)).from_buffer(nd_bytes) _check_call(_LIB.MXNDListCreate( ptr, len(nd_bytes), ctypes.byref(handle), ctypes.byref(olen))) keys = [] arrs = [] for i in range(olen.value): key = ctypes.c_char_p() cptr = mx_float_p() pdata = ctypes.POINTER(mx_uint)() ndim = mx_uint() _check_call(_LIB.MXNDListGet( handle, mx_uint(i), ctypes.byref(key), ctypes.byref(cptr), ctypes.byref(pdata), ctypes.byref(ndim))) shape = tuple(pdata[:ndim.value]) dbuffer = (mx_float * np.prod(shape)).from_address(ctypes.addressof(cptr.contents)) ret = np.frombuffer(dbuffer, dtype=np.float32).reshape(shape) ret = np.array(ret, dtype=np.float32) keys.append(py_str(key.value)) arrs.append(ret) _check_call(_LIB.MXNDListFree(handle)) if len(keys) == 0 or len(keys[0]) == 0: return arrs else: return {keys[i] : arrs[i] for i in range(len(keys)) }