/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file ndarray.hpp
* \brief implementation of the ndarray
* \author Zhang Chen, Chuntao Hong
*/
# ifndef MXNET_CPP_NDARRAY_HPP_
# define MXNET_CPP_NDARRAY_HPP_
# include <algorithm>
# include <map>
# include <string>
# include <vector>
# include <iterator>
# include "dmlc/logging.h"
# include "mxnet-cpp/ndarray.h"
# include "mxnet-cpp/operator.h"
namespace mxnet {
namespace cpp {
2017-03-28 23:34:46 -05:00
inline NDArray : : NDArray ( ) {
2017-03-22 11:55:51 +08:00
NDArrayHandle handle ;
CHECK_EQ ( MXNDArrayCreateNone ( & handle ) , 0 ) ;
blob_ptr_ = std : : make_shared < NDBlob > ( handle ) ;
}
2017-03-28 23:34:46 -05:00
inline NDArray : : NDArray ( const NDArrayHandle & handle ) {
2017-03-22 11:55:51 +08:00
blob_ptr_ = std : : make_shared < NDBlob > ( handle ) ;
}
2017-03-28 23:34:46 -05:00
inline NDArray : : NDArray ( const std : : vector < mx_uint > & shape , const Context & context ,
2019-05-24 16:44:13 +08:00
bool delay_alloc , int dtype ) {
2017-03-22 11:55:51 +08:00
NDArrayHandle handle ;
2019-05-24 16:44:13 +08:00
CHECK_EQ ( MXNDArrayCreateEx ( shape . data ( ) , shape . size ( ) , context . GetDeviceType ( ) ,
context . GetDeviceId ( ) , delay_alloc , dtype , & handle ) ,
2017-03-22 11:55:51 +08:00
0 ) ;
blob_ptr_ = std : : make_shared < NDBlob > ( handle ) ;
}
2019-05-24 16:44:13 +08:00
inline NDArray : : NDArray ( const Shape & shape , const Context & context ,
bool delay_alloc , int dtype ) {
2017-03-22 11:55:51 +08:00
NDArrayHandle handle ;
2019-05-24 16:44:13 +08:00
CHECK_EQ ( MXNDArrayCreateEx ( shape . data ( ) , shape . ndim ( ) , context . GetDeviceType ( ) ,
context . GetDeviceId ( ) , delay_alloc , dtype , & handle ) ,
2017-03-22 11:55:51 +08:00
0 ) ;
blob_ptr_ = std : : make_shared < NDBlob > ( handle ) ;
}
2017-03-28 23:34:46 -05:00
inline NDArray : : NDArray ( const mx_float * data , size_t size ) {
2017-03-22 11:55:51 +08:00
NDArrayHandle handle ;
CHECK_EQ ( MXNDArrayCreateNone ( & handle ) , 0 ) ;
MXNDArraySyncCopyFromCPU ( handle , data , size ) ;
blob_ptr_ = std : : make_shared < NDBlob > ( handle ) ;
}
2017-03-28 23:34:46 -05:00
inline NDArray : : NDArray ( const mx_float * data , const Shape & shape ,
const Context & context ) {
2017-03-22 11:55:51 +08:00
NDArrayHandle handle ;
CHECK_EQ ( MXNDArrayCreate ( shape . data ( ) , shape . ndim ( ) , context . GetDeviceType ( ) ,
context . GetDeviceId ( ) , false , & handle ) ,
0 ) ;
Multithreaded Inference Support (#16654)
* Add cached op threadsafe version with corresponding C APIs, CPP Package changes, CI changes and tests
* Fix download cmd in runtime_functions
* Add CI changes
* Add stage
Fix indentation
* Fix lint
* Change to DEFAULT for C API
* Fix mxnet_unit_tests path
* export correct LD_LIBRARY_PATH
* Add cpp include dirs
* Build test with USE_CPP_PACKAGE
* Add cached op threadsafe version with corresponding C APIs, CPP Package changes, CI changes and tests
* Fix download cmd in runtime_functions
* Merge
* change mkldnn lib name
* Add static_alloc, static_Shape support
* Address review comments
* Make GetCachedOpThreadSafeState similar to cached_op
* Address review comments: comments for locking strategy
* multithreaded inference tutorial
* [Estimator] handle composite metrics in estimator (#16676)
* handle composite metrics in estimator
* fix composite metric case in handlers
* remove unused import
* [Estimator] refactor estimator to allow overriding evaluate/fit of a batch (#16678)
* refactor estimator to allow overriding evaluate/fit of a batch
* add doc to explain call structure and how to override
* fix and doc
* Pointwise fusion for GPU (#15167)
* Beginning of RTC of pointwise ops
* Code generation from the given JSON
* add initial simple_partition_pass and use it for pointwise fusion
* fix the fusion, use a symbol.Copy() at the beginning of binding function, use the name of input nodes in the cuda code
* Fixes
* Adding support for attribute inference for backward nodes when fusing
* keep proper input ordering for fused Op
* instantiate the indexed_graph before starting the subgraph replacement, return a new graph to reset the indexed_graph
* Fuse backward
* fix ordering of subgraph node inputs using subgraph topological ordering instead of main graph topological ordering, add tvm.patch
* excluse forward node fusion during the fusion of the nodes in the backward graph
* Dealing with fused backward nodes inferattr
* use subgraph.indexed_graph() instead of main for _FusedOpHelper nodes node_id, invert control_deps loop to modify topology of subgraph before calling its indexed_graph(), check that all node of the first DFSVisit are actually in the subgraph
* Adding support for other reqs in codegen
* Fix
* Cleaning
* Change the TVM submodule
* More cleaning
* Making linter happy
* Do fusion only if default context is GPU
* Fixes for tests
Add powerscalar and rpowerscalar, fix return type of zero and one
Cleaning, fixing lint
Go back to proper TVM submodule
* Fix the TVM commit
* Fix lint
* Guard fusion with MXNET_USE_CUDA
* Fix
* Fix clang-tidy
* Add erf and erfinv backward
* Gluon support for fusion
* Cleaning
* Cleaning and allow shape/type change in FusedOp
* Fixing Gluon bugs
* Fixing after rebase
* Fixing race condition and guarding against races when using NVRTC
* Cleaning and renaming FusedOp to _FusedOp
* Going easy on Windows compiler
* Disable fusion on Windows for now
* Refactor InferAttr and InferShapeAttr
* Added slice and half2 support to FusedOp
* Fix lint errors
* Added multiple types support for vector loading/storing
* add slice fusion when it's at the beginning of subgraphs
* Removed constant ndim assumption in fused op
* Fix memory alignment issue in slice for FusedOp
* Fixes
* Fix lint errors
* Do not include cuda_fp16.h
* Refactor fused op op lists
* Make linter happy
* Changes from review
* Fixes after rebase
* Expand FusedOp support for slice
* Fix for fp16 _zeros and _ones
* Fix
* Moving aux functions to unnamed namespace and detail namespace -> fusion
namespace
* Disabling fusion if it alters topological order of inputs
* Print code only when env variable is set
* Fix
* Fix lint and 2 tests that specify the same names for multiple inputs
* Fixes from review and disabling fusion of slice with non-default step
* Add amp_cast to fusion, fixes
* Add amp_multicast and its backward to the list of support ops
* Apply wording suggestions from code review
Co-Authored-By: Aaron Markham <markhama@amazon.com>
* Apply wording suggestions from code review
Co-Authored-By: Aaron Markham <markhama@amazon.com>
* Make clearer comment
* Adding punctuation and capitalization to \brief descriptions
* Fix
* Fix
* Add backward_cast to fusion
* Adding unittests for fusion. Fix for erfinv_grad
* Adding slice ops and add_n to tests
* Fixes from review
* Setting inplace option
* Fix lint
* Storing double in half
* Retrigger CI
* Slight relaxing of the relative tolerance in the test
* Move the env variable check to the end
* Fix a race condition between InferShape and scheduled Forward
* Fix flakey test_fusion test involving fp32 erfinv op.
* Fix from review
* Added broadcast_like and slice_like to fused op
* Minor fix and cleanup
* Added negative axis support in slice_axis, temporarily disabled fusion of slice_like and broadcast_like
* Added axes support to slice_like
* Added axis support to broadcast_like
* Add fast_load_slice function to fused op code
* Added runtime switch for choosing fast and slow slice kernel
* Fix lint and warning
* Going easy on Windows compiler (again)
* Fix slice_like
* Debug broadcast_like fusion
* Fix lint
* Fix lint
* Trigger CI
* Get rid of the initializer list
* Fix backward calls with different gradient type
* avoid cycle when adding node specific for inputs of subgraph for pointwise fusion
* Fix lint
* Add namespace to the fusion implementations
* Set launch bounds on the fused kernel
* Fix NumPy tests
* Test showcasing an issue fixed in PR #16553
* Cast scalarts to FP32 and perform (a*1.0/b) instead of (a/b)
Fix lint errors
Fix lint
* Fix a bug in cycle detection for inputs only op in pointwise fusion
* Add comments to simple_partition_pass.h file
* fix install dir (#16690)
* [numpy] add numpy operator : append (#16564)
* add operator : append ; fix op concatenate when axis = None
* pylint disable
remove mistake
disable pylint
* Initializer.__eq__ (#16680)
* fix binary dependencies in CD and nightly (#16693)
* [MKL-DNN] Add mxnet mkldnn cmake tutorial (#16688)
* add mxnet mkldnn cmake instruction
* imporve doc
* OMP->OpenMP
* Revert "[MKLDNN]Fix reorder2default (#16602)" (#16697)
This reverts commit dd4eaf5c23046d07a4578a219e2dd3622e5620fa.
* [Estimator] refactor estimator and clarify docs (#16694)
* refactor estimator and clarify docs
* fix info message and test
* clean up after releasing logging handler
* Eliminate common expressions (#15657)
* Eliminate common expressions from a graph
* Guarding against optimizing out stateful ops and ops that require
resource
* Fix lint
* Added THasDeterministicOutput to multiple ops
* DDebug eliminate common expr
* Added test
* Expose get_optimized_symbol
* Fix
* Fix 2
* Add doc to the Python call
* Add env var MXNET_ELIMINATE_COMMON_EXPR, default true
* Add comments, improve readability of eliminate_common_expr_pass.cc
* Expand testing
* Lower priority of THasDeterministicOutput attr for equal Node test
* Change mx.gpu() to mx.cpu() in tests
* Skip CSE test on Windows (as env variable setting during test does not work there)
* Add missing import sys
* Add missing import logging
* Backport of #16711, #16737, #16408 to 1.6 branch (#16763)
* support mixed-precision true_divide (#16711)
* [MKLDNN] use dim_t instead of int in slice/transpose operators (#16737)
* use dim_t instead of int
* fix same issue in pooling
* rebase code
* trigger CI
* Add MXNet Ops for fast multihead attention (#16408)
* add MXNet Ops for fast multihead attention
* add cutlass as 3rdparty dependency
* add cutlass to compilation flags
* remove all cutlass stuff
* add better error message and description and remove cutlass from compilation flags
* change credit for the approach since the code have changed
* fix typos
* correct another typo
* Add all the cuda/cublas helper functions
* remove tests using kAddTo
* only use cublasStridedBatchedGemm if CUDA >= 9.1
* add equivalent mxnet code in description of mha ops
* remove a wrong copy-paste
* add _contrib for namespace and add GPU only on description
* add warning in bwd_ignore_zero_init description, also test with fp32
* add error return if bwd_ignore_zero_init is used without MXNET_EXEC_ENABLE_ADDTO
* remove std::move for clang
* remove bwd_ignore_zero_init flag
* remove bwd_ignore_zero_init in test_operator_gpu.py
* fix typo
* fix another typo
* Removed unrelated test
* Add example and documentation for multi threaded inference
* Add LICENSE
* Add get_model.py
* Add license for README
* Refactor cached op and cached op threadsafe
* Add limitation
* Add tests for naive engine
* Add latest test changes
* Thread Safety tests in NaiveEngine mode
* Thread Safety tests update
* Update thread safety tests, add unsupported use cases
* Changes to doc and refactor
* Fix todo owner, indentation and mx_float->float
* Refactor cached op code, remove num_threads arg from example
* Fix lint
* Fix warning
* Add back cython, required for unix-gpu build
* Fix for windows
* Add bulking support for thread safe cached op version
* Add support for subgraph testing
* import mxnet before calling get_backend_symbol
* Fix symbol json name
* Refactor DynamicForward
* Add comments
* Add DMLC_ATTRIBUTE_UNUSED
* Fix use_naive_run issue
* Fix lint
* Revert unittest_cpp to old test since it doesnt test thread safety
* Fix doc
Co-authored-by: Sheng Zha <szha@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Co-authored-by: Tao Lv <tao.a.lv@intel.com>
Co-authored-by: JiangZhaoh <54654391+JiangZhaoh@users.noreply.github.com>
Co-authored-by: Leonard Lausen <leonard@lausen.nl>
Co-authored-by: Xinyu Chen <xinyu1.chen@intel.com>
Co-authored-by: Zhennan Qin <zhennan.qin@intel.com>
2020-02-01 09:36:59 -08:00
CHECK_EQ ( MXNDArraySyncCopyFromCPU ( handle , data , shape . Size ( ) ) , 0 ) ;
2017-03-22 11:55:51 +08:00
blob_ptr_ = std : : make_shared < NDBlob > ( handle ) ;
}
2017-03-28 23:34:46 -05:00
inline NDArray : : NDArray ( const std : : vector < mx_float > & data , const Shape & shape ,
const Context & context ) {
2017-03-22 11:55:51 +08:00
NDArrayHandle handle ;
CHECK_EQ ( MXNDArrayCreate ( shape . data ( ) , shape . ndim ( ) , context . GetDeviceType ( ) ,
context . GetDeviceId ( ) , false , & handle ) ,
0 ) ;
MXNDArraySyncCopyFromCPU ( handle , data . data ( ) , shape . Size ( ) ) ;
blob_ptr_ = std : : make_shared < NDBlob > ( handle ) ;
}
2017-03-28 23:34:46 -05:00
inline NDArray : : NDArray ( const std : : vector < mx_float > & data ) {
2017-03-22 11:55:51 +08:00
NDArrayHandle handle ;
CHECK_EQ ( MXNDArrayCreateNone ( & handle ) , 0 ) ;
MXNDArraySyncCopyFromCPU ( handle , data . data ( ) , data . size ( ) ) ;
blob_ptr_ = std : : make_shared < NDBlob > ( handle ) ;
}
2017-03-28 23:34:46 -05:00
inline NDArray NDArray : : operator + ( mx_float scalar ) {
2017-03-22 11:55:51 +08:00
NDArray ret ;
Operator ( " _plus_scalar " ) ( * this , scalar ) . Invoke ( ret ) ;
return ret ;
}
2017-03-28 23:34:46 -05:00
inline NDArray NDArray : : operator - ( mx_float scalar ) {
2017-03-22 11:55:51 +08:00
NDArray ret ;
Operator ( " _minus_scalar " ) ( * this , scalar ) . Invoke ( ret ) ;
return ret ;
}
2017-03-28 23:34:46 -05:00
inline NDArray NDArray : : operator * ( mx_float scalar ) {
2017-03-22 11:55:51 +08:00
NDArray ret ;
Operator ( " _mul_scalar " ) ( * this , scalar ) . Invoke ( ret ) ;
return ret ;
}
2017-03-28 23:34:46 -05:00
inline NDArray NDArray : : operator / ( mx_float scalar ) {
2017-03-22 11:55:51 +08:00
NDArray ret ;
Operator ( " _div_scalar " ) ( * this , scalar ) . Invoke ( ret ) ;
return ret ;
}
2017-06-19 23:59:40 -07:00
inline NDArray NDArray : : operator % ( mx_float scalar ) {
NDArray ret ;
Operator ( " _mod_scalar " ) ( * this , scalar ) . Invoke ( ret ) ;
return ret ;
}
2017-03-28 23:34:46 -05:00
inline NDArray NDArray : : operator + ( const NDArray & rhs ) {
2017-03-22 11:55:51 +08:00
NDArray ret ;
Operator ( " _plus " ) ( * this , rhs ) . Invoke ( ret ) ;
return ret ;
}
2017-03-28 23:34:46 -05:00
inline NDArray NDArray : : operator - ( const NDArray & rhs ) {
2017-03-22 11:55:51 +08:00
NDArray ret ;
Operator ( " _minus " ) ( * this , rhs ) . Invoke ( ret ) ;
return ret ;
}
2017-03-28 23:34:46 -05:00
inline NDArray NDArray : : operator * ( const NDArray & rhs ) {
2017-03-22 11:55:51 +08:00
NDArray ret ;
Operator ( " _mul " ) ( * this , rhs ) . Invoke ( ret ) ;
return ret ;
}
2017-03-28 23:34:46 -05:00
inline NDArray NDArray : : operator / ( const NDArray & rhs ) {
2017-03-22 11:55:51 +08:00
NDArray ret ;
Operator ( " _div " ) ( * this , rhs ) . Invoke ( ret ) ;
return ret ;
}
2017-06-19 23:59:40 -07:00
inline NDArray NDArray : : operator % ( const NDArray & rhs ) {
NDArray ret ;
Operator ( " _mod " ) ( * this , rhs ) . Invoke ( ret ) ;
return ret ;
}
2017-03-28 23:34:46 -05:00
inline NDArray & NDArray : : operator = ( mx_float scalar ) {
2017-03-22 11:55:51 +08:00
Operator ( " _set_value " ) ( scalar ) . Invoke ( * this ) ;
return * this ;
}
2017-03-28 23:34:46 -05:00
inline NDArray & NDArray : : operator + = ( mx_float scalar ) {
2017-03-22 11:55:51 +08:00
Operator ( " _plus_scalar " ) ( * this , scalar ) . Invoke ( * this ) ;
return * this ;
}
2017-03-28 23:34:46 -05:00
inline NDArray & NDArray : : operator - = ( mx_float scalar ) {
2017-03-22 11:55:51 +08:00
Operator ( " _minus_scalar " ) ( * this , scalar ) . Invoke ( * this ) ;
return * this ;
}
2017-03-28 23:34:46 -05:00
inline NDArray & NDArray : : operator * = ( mx_float scalar ) {
2017-03-22 11:55:51 +08:00
Operator ( " _mul_scalar " ) ( * this , scalar ) . Invoke ( * this ) ;
return * this ;
}
2017-03-28 23:34:46 -05:00
inline NDArray & NDArray : : operator / = ( mx_float scalar ) {
2017-03-22 11:55:51 +08:00
Operator ( " _div_scalar " ) ( * this , scalar ) . Invoke ( * this ) ;
return * this ;
}
2017-06-19 23:59:40 -07:00
inline NDArray & NDArray : : operator % = ( mx_float scalar ) {
Operator ( " _mod_scalar " ) ( * this , scalar ) . Invoke ( * this ) ;
return * this ;
}
2017-03-28 23:34:46 -05:00
inline NDArray & NDArray : : operator + = ( const NDArray & rhs ) {
2017-03-22 11:55:51 +08:00
Operator ( " _plus " ) ( * this , rhs ) . Invoke ( * this ) ;
return * this ;
}
2017-03-28 23:34:46 -05:00
inline NDArray & NDArray : : operator - = ( const NDArray & rhs ) {
2017-03-22 11:55:51 +08:00
Operator ( " _minus " ) ( * this , rhs ) . Invoke ( * this ) ;
return * this ;
}
2017-03-28 23:34:46 -05:00
inline NDArray & NDArray : : operator * = ( const NDArray & rhs ) {
2017-03-22 11:55:51 +08:00
Operator ( " _mul " ) ( * this , rhs ) . Invoke ( * this ) ;
return * this ;
}
2017-03-28 23:34:46 -05:00
inline NDArray & NDArray : : operator / = ( const NDArray & rhs ) {
2017-03-22 11:55:51 +08:00
Operator ( " _div " ) ( * this , rhs ) . Invoke ( * this ) ;
return * this ;
}
2017-06-19 23:59:40 -07:00
inline NDArray & NDArray : : operator % = ( const NDArray & rhs ) {
Operator ( " _mod " ) ( * this , rhs ) . Invoke ( * this ) ;
return * this ;
}
2017-03-22 11:55:51 +08:00
2017-03-28 23:34:46 -05:00
inline NDArray NDArray : : ArgmaxChannel ( ) {
2017-03-22 11:55:51 +08:00
NDArray ret ;
Operator ( " argmax_channel " ) ( * this ) . Invoke ( ret ) ;
return ret ;
}
2017-03-28 23:34:46 -05:00
inline void NDArray : : SyncCopyFromCPU ( const mx_float * data , size_t size ) {
2017-03-22 11:55:51 +08:00
MXNDArraySyncCopyFromCPU ( blob_ptr_ - > handle_ , data , size ) ;
}
2017-03-28 23:34:46 -05:00
inline void NDArray : : SyncCopyFromCPU ( const std : : vector < mx_float > & data ) {
2017-03-22 11:55:51 +08:00
MXNDArraySyncCopyFromCPU ( blob_ptr_ - > handle_ , data . data ( ) , data . size ( ) ) ;
}
2017-03-28 23:34:46 -05:00
inline void NDArray : : SyncCopyToCPU ( mx_float * data , size_t size ) {
2017-03-22 11:55:51 +08:00
MXNDArraySyncCopyToCPU ( blob_ptr_ - > handle_ , data , size > 0 ? size : Size ( ) ) ;
}
2017-03-28 23:34:46 -05:00
inline void NDArray : : SyncCopyToCPU ( std : : vector < mx_float > * data , size_t size ) {
2017-03-22 11:55:51 +08:00
size = size > 0 ? size : Size ( ) ;
data - > resize ( size ) ;
MXNDArraySyncCopyToCPU ( blob_ptr_ - > handle_ , data - > data ( ) , size ) ;
}
2017-03-28 23:34:46 -05:00
inline NDArray NDArray : : Copy ( const Context & ctx ) const {
2019-05-24 16:44:13 +08:00
NDArray ret ( GetShape ( ) , ctx , true , this - > GetDType ( ) ) ;
2017-03-22 11:55:51 +08:00
Operator ( " _copyto " ) ( * this ) . Invoke ( ret ) ;
return ret ;
}
2017-03-28 23:34:46 -05:00
inline NDArray NDArray : : CopyTo ( NDArray * other ) const {
2017-03-22 11:55:51 +08:00
Operator ( " _copyto " ) ( * this ) . Invoke ( * other ) ;
return * other ;
}
2017-03-28 23:34:46 -05:00
inline NDArray NDArray : : Slice ( mx_uint begin , mx_uint end ) const {
2017-03-22 11:55:51 +08:00
NDArrayHandle handle ;
CHECK_EQ ( MXNDArraySlice ( GetHandle ( ) , begin , end , & handle ) , 0 ) ;
return NDArray ( handle ) ;
}
2017-03-28 23:34:46 -05:00
inline NDArray NDArray : : Reshape ( const Shape & new_shape ) const {
2017-03-22 11:55:51 +08:00
NDArrayHandle handle ;
std : : vector < int > dims ( new_shape . ndim ( ) ) ;
for ( index_t i = 0 ; i < new_shape . ndim ( ) ; + + i ) {
dims [ i ] = new_shape [ i ] ;
}
new_shape . data ( ) ;
CHECK_EQ (
MXNDArrayReshape ( GetHandle ( ) , new_shape . ndim ( ) , dims . data ( ) , & handle ) , 0 ) ;
return NDArray ( handle ) ;
}
2017-03-28 23:34:46 -05:00
inline void NDArray : : WaitToRead ( ) const {
2019-04-08 00:21:36 -07:00
CHECK_EQ ( MXNDArrayWaitToRead ( blob_ptr_ - > handle_ ) , 0 ) < < MXGetLastError ( ) ;
2017-03-22 11:55:51 +08:00
}
2017-03-28 23:34:46 -05:00
inline void NDArray : : WaitToWrite ( ) {
2019-04-08 00:21:36 -07:00
CHECK_EQ ( MXNDArrayWaitToWrite ( blob_ptr_ - > handle_ ) , 0 ) < < MXGetLastError ( ) ;
2017-03-22 11:55:51 +08:00
}
2019-04-08 00:21:36 -07:00
inline void NDArray : : WaitAll ( ) { CHECK_EQ ( MXNDArrayWaitAll ( ) , 0 ) < < MXGetLastError ( ) ; }
2017-03-28 23:34:46 -05:00
inline void NDArray : : SampleGaussian ( mx_float mu , mx_float sigma , NDArray * out ) {
2017-09-26 13:39:24 -07:00
Operator ( " _random_normal " ) ( mu , sigma ) . Invoke ( * out ) ;
2017-03-22 11:55:51 +08:00
}
2017-03-28 23:34:46 -05:00
inline void NDArray : : SampleUniform ( mx_float begin , mx_float end , NDArray * out ) {
2017-09-26 13:39:24 -07:00
Operator ( " _random_uniform " ) ( begin , end ) . Invoke ( * out ) ;
2017-03-22 11:55:51 +08:00
}
2017-03-28 23:34:46 -05:00
inline void NDArray : : Load ( const std : : string & file_name ,
std : : vector < NDArray > * array_list ,
std : : map < std : : string , NDArray > * array_map ) {
2017-03-22 11:55:51 +08:00
mx_uint out_size , out_name_size ;
NDArrayHandle * out_arr ;
const char * * out_names ;
CHECK_EQ ( MXNDArrayLoad ( file_name . c_str ( ) , & out_size , & out_arr , & out_name_size ,
& out_names ) ,
0 ) ;
if ( array_list ! = nullptr ) {
2018-04-03 20:55:35 +01:00
array_list - > reserve ( out_size ) ;
2017-03-22 11:55:51 +08:00
for ( mx_uint i = 0 ; i < out_size ; + + i ) {
array_list - > push_back ( NDArray ( out_arr [ i ] ) ) ;
}
}
if ( array_map ! = nullptr & & out_name_size > 0 ) {
CHECK_EQ ( out_name_size , out_size ) ;
for ( mx_uint i = 0 ; i < out_size ; + + i ) {
( * array_map ) [ out_names [ i ] ] = NDArray ( out_arr [ i ] ) ;
}
}
}
2017-03-28 23:34:46 -05:00
inline std : : map < std : : string , NDArray > NDArray : : LoadToMap (
2017-03-22 11:55:51 +08:00
const std : : string & file_name ) {
std : : map < std : : string , NDArray > array_map ;
mx_uint out_size , out_name_size ;
NDArrayHandle * out_arr ;
const char * * out_names ;
CHECK_EQ ( MXNDArrayLoad ( file_name . c_str ( ) , & out_size , & out_arr , & out_name_size ,
& out_names ) ,
0 ) ;
if ( out_name_size > 0 ) {
CHECK_EQ ( out_name_size , out_size ) ;
for ( mx_uint i = 0 ; i < out_size ; + + i ) {
array_map [ out_names [ i ] ] = NDArray ( out_arr [ i ] ) ;
}
}
return array_map ;
}
2017-03-28 23:34:46 -05:00
inline std : : vector < NDArray > NDArray : : LoadToList ( const std : : string & file_name ) {
2017-03-22 11:55:51 +08:00
std : : vector < NDArray > array_list ;
mx_uint out_size , out_name_size ;
NDArrayHandle * out_arr ;
const char * * out_names ;
CHECK_EQ ( MXNDArrayLoad ( file_name . c_str ( ) , & out_size , & out_arr , & out_name_size ,
& out_names ) ,
0 ) ;
2018-04-03 20:55:35 +01:00
array_list . reserve ( out_size ) ;
for ( mx_uint i = 0 ; i < out_size ; + + i ) {
array_list . push_back ( NDArray ( out_arr [ i ] ) ) ;
}
return array_list ;
}
inline void NDArray : : LoadFromBuffer ( const void * buffer , size_t size ,
std : : vector < NDArray > * array_list ,
std : : map < std : : string , NDArray > * array_map ) {
mx_uint out_size , out_name_size ;
NDArrayHandle * out_arr ;
const char * * out_names ;
CHECK_EQ ( MXNDArrayLoadFromBuffer ( buffer , size , & out_size , & out_arr , & out_name_size ,
& out_names ) ,
0 ) ;
if ( array_list ! = nullptr ) {
array_list - > reserve ( out_size ) ;
for ( mx_uint i = 0 ; i < out_size ; + + i ) {
array_list - > push_back ( NDArray ( out_arr [ i ] ) ) ;
}
}
if ( array_map ! = nullptr & & out_name_size > 0 ) {
CHECK_EQ ( out_name_size , out_size ) ;
for ( mx_uint i = 0 ; i < out_size ; + + i ) {
( * array_map ) [ out_names [ i ] ] = NDArray ( out_arr [ i ] ) ;
}
}
}
inline std : : map < std : : string , NDArray > NDArray : : LoadFromBufferToMap (
const void * buffer , size_t size ) {
std : : map < std : : string , NDArray > array_map ;
mx_uint out_size , out_name_size ;
NDArrayHandle * out_arr ;
const char * * out_names ;
CHECK_EQ ( MXNDArrayLoadFromBuffer ( buffer , size , & out_size , & out_arr , & out_name_size ,
& out_names ) ,
0 ) ;
if ( out_name_size > 0 ) {
CHECK_EQ ( out_name_size , out_size ) ;
for ( mx_uint i = 0 ; i < out_size ; + + i ) {
array_map [ out_names [ i ] ] = NDArray ( out_arr [ i ] ) ;
}
}
return array_map ;
}
inline std : : vector < NDArray > NDArray : : LoadFromBufferToList ( const void * buffer , size_t size ) {
std : : vector < NDArray > array_list ;
mx_uint out_size , out_name_size ;
NDArrayHandle * out_arr ;
const char * * out_names ;
CHECK_EQ ( MXNDArrayLoadFromBuffer ( buffer , size , & out_size , & out_arr , & out_name_size ,
& out_names ) ,
0 ) ;
array_list . reserve ( out_size ) ;
2017-03-22 11:55:51 +08:00
for ( mx_uint i = 0 ; i < out_size ; + + i ) {
array_list . push_back ( NDArray ( out_arr [ i ] ) ) ;
}
return array_list ;
}
2017-03-28 23:34:46 -05:00
inline void NDArray : : Save ( const std : : string & file_name ,
const std : : map < std : : string , NDArray > & array_map ) {
2017-03-22 11:55:51 +08:00
std : : vector < NDArrayHandle > args ;
std : : vector < const char * > keys ;
for ( const auto & t : array_map ) {
args . push_back ( t . second . GetHandle ( ) ) ;
keys . push_back ( t . first . c_str ( ) ) ;
}
CHECK_EQ (
MXNDArraySave ( file_name . c_str ( ) , args . size ( ) , args . data ( ) , keys . data ( ) ) ,
0 ) ;
}
2017-03-28 23:34:46 -05:00
inline void NDArray : : Save ( const std : : string & file_name ,
const std : : vector < NDArray > & array_list ) {
2017-03-22 11:55:51 +08:00
std : : vector < NDArrayHandle > args ;
for ( const auto & t : array_list ) {
args . push_back ( t . GetHandle ( ) ) ;
}
CHECK_EQ ( MXNDArraySave ( file_name . c_str ( ) , args . size ( ) , args . data ( ) , nullptr ) ,
0 ) ;
}
2017-03-28 23:34:46 -05:00
inline size_t NDArray : : Offset ( size_t h , size_t w ) const {
2019-06-07 15:04:37 -07:00
auto const shape = GetShape ( ) ;
CHECK_EQ ( shape . size ( ) , 2 ) < < " The NDArray needs to be 2 dimensional. " ;
return ( h * shape [ 1 ] ) + w ;
2017-03-22 11:55:51 +08:00
}
2017-03-28 23:34:46 -05:00
inline size_t NDArray : : Offset ( size_t c , size_t h , size_t w ) const {
2017-03-22 11:55:51 +08:00
auto const shape = GetShape ( ) ;
2019-06-07 15:04:37 -07:00
CHECK_EQ ( shape . size ( ) , 3 ) < < " The NDArray needs to be 3 dimensional. " ;
2017-03-22 11:55:51 +08:00
return h * shape [ 0 ] * shape [ 2 ] + w * shape [ 0 ] + c ;
}
2017-03-28 23:34:46 -05:00
inline mx_float NDArray : : At ( size_t h , size_t w ) const {
2017-03-22 11:55:51 +08:00
return GetData ( ) [ Offset ( h , w ) ] ;
}
2017-03-28 23:34:46 -05:00
inline mx_float NDArray : : At ( size_t c , size_t h , size_t w ) const {
2017-03-22 11:55:51 +08:00
return GetData ( ) [ Offset ( c , h , w ) ] ;
}
2019-06-07 15:04:37 -07:00
inline mx_float NDArray : : At ( size_t index ) const {
auto shape = GetShape ( ) ;
CHECK_EQ ( shape . size ( ) , 1 ) < < " The NDArray needs to be 1 dimensional. " ;
CHECK_LT ( index , shape [ 0 ] ) < < " Specified index is out of range. " ;
return GetData ( ) [ index ] ;
}
2017-03-28 23:34:46 -05:00
inline size_t NDArray : : Size ( ) const {
2017-03-22 11:55:51 +08:00
size_t ret = 1 ;
for ( auto & i : GetShape ( ) ) ret * = i ;
return ret ;
}
2017-03-28 23:34:46 -05:00
inline std : : vector < mx_uint > NDArray : : GetShape ( ) const {
2019-04-16 10:00:54 -07:00
const int * out_pdata ;
int out_dim ;
MXNDArrayGetShapeEx ( blob_ptr_ - > handle_ , & out_dim , & out_pdata ) ;
2017-03-22 11:55:51 +08:00
std : : vector < mx_uint > ret ;
2019-04-16 10:00:54 -07:00
for ( int i = 0 ; i < out_dim ; + + i ) {
2017-03-22 11:55:51 +08:00
ret . push_back ( out_pdata [ i ] ) ;
}
return ret ;
}
2017-04-01 07:13:44 +09:00
inline int NDArray : : GetDType ( ) const {
int ret ;
2017-04-04 13:51:57 -04:00
MXNDArrayGetDType ( blob_ptr_ - > handle_ , & ret ) ;
2017-04-01 07:13:44 +09:00
return ret ;
}
2017-03-28 23:34:46 -05:00
inline const mx_float * NDArray : : GetData ( ) const {
2017-04-01 07:13:44 +09:00
void * ret ;
2017-03-22 11:55:51 +08:00
MXNDArrayGetData ( blob_ptr_ - > handle_ , & ret ) ;
2017-04-01 07:13:44 +09:00
if ( GetDType ( ) ! = 0 ) {
2020-02-09 02:50:49 +01:00
return nullptr ;
2017-04-01 07:13:44 +09:00
}
return static_cast < mx_float * > ( ret ) ;
2017-03-22 11:55:51 +08:00
}
2017-04-01 07:13:44 +09:00
2017-03-28 23:34:46 -05:00
inline Context NDArray : : GetContext ( ) const {
2017-03-22 11:55:51 +08:00
int out_dev_type ;
int out_dev_id ;
MXNDArrayGetContext ( blob_ptr_ - > handle_ , & out_dev_type , & out_dev_id ) ;
return Context ( ( DeviceType ) out_dev_type , out_dev_id ) ;
}
2017-05-18 00:57:37 +08:00
inline std : : ostream & operator < < ( std : : ostream & out , const NDArray & ndarray ) {
// TODO(lx75249): Consider DType / beautify like numpy
auto shape = ndarray . GetShape ( ) ;
NDArray cpu_array ( ndarray . GetShape ( ) , Context : : cpu ( ) ) ;
if ( ndarray . GetContext ( ) . GetDeviceType ( ) ! = DeviceType : : kGPU ) {
cpu_array = ndarray ;
} else {
ndarray . WaitToRead ( ) ;
ndarray . CopyTo ( & cpu_array ) ;
}
out < < ' [ ' ;
cpu_array . WaitToRead ( ) ;
std : : copy ( cpu_array . GetData ( ) , cpu_array . GetData ( ) + ndarray . Size ( ) ,
std : : ostream_iterator < float > ( out , " , " ) ) ;
out < < ' ] ' ;
return out ;
}
2017-03-22 11:55:51 +08:00
} // namespace cpp
} // namespace mxnet
# endif // MXNET_CPP_NDARRAY_HPP_