2017-08-11 14:12:47 -07:00
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
2017-03-22 11:55:51 +08:00
/*!
* \file ndarray.hpp
* \brief implementation of the ndarray
* \author Zhang Chen, Chuntao Hong
*/
2017-07-12 10:04:40 -07:00
# ifndef MXNET_CPP_NDARRAY_HPP_
# define MXNET_CPP_NDARRAY_HPP_
2017-03-22 11:55:51 +08:00
2017-05-18 00:57:37 +08:00
# include <algorithm>
2017-03-22 11:55:51 +08:00
# include <map>
# include <string>
# include <vector>
2017-05-18 00:57:37 +08:00
# include <iterator>
2017-03-22 11:55:51 +08:00
# include "dmlc/logging.h"
# include "mxnet-cpp/ndarray.h"
2017-08-30 13:39:24 -04:00
# include "mxnet-cpp/operator.h"
2017-03-22 11:55:51 +08:00
namespace mxnet {
namespace cpp {
2017-03-28 23:34:46 -05:00
/*!
 * \brief Default constructor: requests a handle from MXNDArrayCreateNone and
 *        stores it in a newly allocated NDBlob held by blob_ptr_.
 */
inline NDArray::NDArray() {
  NDArrayHandle none_handle;
  CHECK_EQ(MXNDArrayCreateNone(&none_handle), 0);
  blob_ptr_ = std::make_shared<NDBlob>(none_handle);
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Construct from an existing handle.
 * \param handle handle passed to the NDBlob that blob_ptr_ will hold.
 */
inline NDArray::NDArray(const NDArrayHandle &handle)
    : blob_ptr_(std::make_shared<NDBlob>(handle)) {}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Create an array through MXNDArrayCreate from a dimension vector.
 * \param shape dimensions; its data pointer and element count are forwarded.
 * \param context supplies the device type and device id.
 * \param delay_alloc forwarded unchanged to MXNDArrayCreate.
 */
inline NDArray::NDArray(const std::vector<mx_uint> &shape, const Context &context,
                        bool delay_alloc) {
  NDArrayHandle created;
  CHECK_EQ(MXNDArrayCreate(shape.data(), shape.size(), context.GetDeviceType(),
                           context.GetDeviceId(), delay_alloc, &created),
           0);
  blob_ptr_ = std::make_shared<NDBlob>(created);
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Create an array through MXNDArrayCreate from a Shape object.
 * \param shape its data() pointer and ndim() are forwarded.
 * \param context supplies the device type and device id.
 * \param delay_alloc forwarded unchanged to MXNDArrayCreate.
 */
inline NDArray::NDArray(const Shape &shape, const Context &context, bool delay_alloc) {
  NDArrayHandle created;
  CHECK_EQ(MXNDArrayCreate(shape.data(), shape.ndim(), context.GetDeviceType(),
                           context.GetDeviceId(), delay_alloc, &created),
           0);
  blob_ptr_ = std::make_shared<NDBlob>(created);
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Construct from a raw float buffer.
 *
 * A handle is obtained with MXNDArrayCreateNone, then MXNDArraySyncCopyFromCPU
 * is called on it with the given buffer.
 * \param data source buffer.
 * \param size element count passed to the copy call.
 */
inline NDArray::NDArray(const mx_float *data, size_t size) {
  NDArrayHandle none_handle;
  CHECK_EQ(MXNDArrayCreateNone(&none_handle), 0);
  // The status returned by the copy call is not inspected.
  static_cast<void>(MXNDArraySyncCopyFromCPU(none_handle, data, size));
  blob_ptr_ = std::make_shared<NDBlob>(none_handle);
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Create an array of the given shape on `context` and fill it from a
 *        raw float buffer.
 * \param data source buffer; shape.Size() elements are copied from it, so the
 *        caller must provide at least that many.
 * \param shape shape of the new array.
 * \param context supplies the device type and device id.
 */
inline NDArray::NDArray(const mx_float *data, const Shape &shape,
                        const Context &context) {
  NDArrayHandle handle;
  CHECK_EQ(MXNDArrayCreate(shape.data(), shape.ndim(), context.GetDeviceType(),
                           context.GetDeviceId(), false, &handle),
           0);
  // Store the handle before copying so it is already attached to blob_ptr_
  // if the status check on the copy fails.
  blob_ptr_ = std::make_shared<NDBlob>(handle);
  // Check the copy status like every other C-API call in this file.
  CHECK_EQ(MXNDArraySyncCopyFromCPU(handle, data, shape.Size()), 0);
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Create an array of the given shape on `context` and fill it from a
 *        float vector.
 * \param data source values; must hold at least shape.Size() elements.
 * \param shape shape of the new array.
 * \param context supplies the device type and device id.
 */
inline NDArray::NDArray(const std::vector<mx_float> &data, const Shape &shape,
                        const Context &context) {
  // The copy below reads shape.Size() elements from data.data(); a shorter
  // vector would be read past its end, so reject it up front.
  CHECK_GE(data.size(), shape.Size());
  NDArrayHandle handle;
  CHECK_EQ(MXNDArrayCreate(shape.data(), shape.ndim(), context.GetDeviceType(),
                           context.GetDeviceId(), false, &handle),
           0);
  // Store the handle before copying so it is already attached to blob_ptr_
  // if the status check on the copy fails.
  blob_ptr_ = std::make_shared<NDBlob>(handle);
  CHECK_EQ(MXNDArraySyncCopyFromCPU(handle, data.data(), shape.Size()), 0);
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Construct from a float vector.
 *
 * A handle is obtained with MXNDArrayCreateNone, then MXNDArraySyncCopyFromCPU
 * is called on it with the vector's buffer and element count.
 * \param data source values.
 */
inline NDArray::NDArray(const std::vector<mx_float> &data) {
  NDArrayHandle none_handle;
  CHECK_EQ(MXNDArrayCreateNone(&none_handle), 0);
  // The status returned by the copy call is not inspected.
  static_cast<void>(MXNDArraySyncCopyFromCPU(none_handle, data.data(), data.size()));
  blob_ptr_ = std::make_shared<NDBlob>(none_handle);
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Invoke the operator named by the literal below with (*this, scalar)
 *        as inputs.
 * \param scalar second operator input.
 * \return a default-constructed NDArray that was passed to Invoke as output.
 */
inline NDArray NDArray::operator+(mx_float scalar) {
  NDArray result;
  Operator plus_scalar_op(" _plus_scalar ");
  plus_scalar_op(*this, scalar).Invoke(result);
  return result;
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Invoke the operator named by the literal below with (*this, scalar)
 *        as inputs.
 * \param scalar second operator input.
 * \return a default-constructed NDArray that was passed to Invoke as output.
 */
inline NDArray NDArray::operator-(mx_float scalar) {
  NDArray result;
  Operator minus_scalar_op(" _minus_scalar ");
  minus_scalar_op(*this, scalar).Invoke(result);
  return result;
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Invoke the operator named by the literal below with (*this, scalar)
 *        as inputs.
 * \param scalar second operator input.
 * \return a default-constructed NDArray that was passed to Invoke as output.
 */
inline NDArray NDArray::operator*(mx_float scalar) {
  NDArray result;
  Operator mul_scalar_op(" _mul_scalar ");
  mul_scalar_op(*this, scalar).Invoke(result);
  return result;
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Invoke the operator named by the literal below with (*this, scalar)
 *        as inputs.
 * \param scalar second operator input.
 * \return a default-constructed NDArray that was passed to Invoke as output.
 */
inline NDArray NDArray::operator/(mx_float scalar) {
  NDArray result;
  Operator div_scalar_op(" _div_scalar ");
  div_scalar_op(*this, scalar).Invoke(result);
  return result;
}
2017-06-19 23:59:40 -07:00
/*!
 * \brief Invoke the operator named by the literal below with (*this, scalar)
 *        as inputs.
 * \param scalar second operator input.
 * \return a default-constructed NDArray that was passed to Invoke as output.
 */
inline NDArray NDArray::operator%(mx_float scalar) {
  NDArray result;
  Operator mod_scalar_op(" _mod_scalar ");
  mod_scalar_op(*this, scalar).Invoke(result);
  return result;
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Invoke the operator named by the literal below with (*this, rhs) as
 *        inputs.
 * \param rhs second operator input.
 * \return a default-constructed NDArray that was passed to Invoke as output.
 */
inline NDArray NDArray::operator+(const NDArray &rhs) {
  NDArray result;
  Operator plus_op(" _plus ");
  plus_op(*this, rhs).Invoke(result);
  return result;
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Invoke the operator named by the literal below with (*this, rhs) as
 *        inputs.
 * \param rhs second operator input.
 * \return a default-constructed NDArray that was passed to Invoke as output.
 */
inline NDArray NDArray::operator-(const NDArray &rhs) {
  NDArray result;
  Operator minus_op(" _minus ");
  minus_op(*this, rhs).Invoke(result);
  return result;
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Invoke the operator named by the literal below with (*this, rhs) as
 *        inputs.
 * \param rhs second operator input.
 * \return a default-constructed NDArray that was passed to Invoke as output.
 */
inline NDArray NDArray::operator*(const NDArray &rhs) {
  NDArray result;
  Operator mul_op(" _mul ");
  mul_op(*this, rhs).Invoke(result);
  return result;
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Invoke the operator named by the literal below with (*this, rhs) as
 *        inputs.
 * \param rhs second operator input.
 * \return a default-constructed NDArray that was passed to Invoke as output.
 */
inline NDArray NDArray::operator/(const NDArray &rhs) {
  NDArray result;
  Operator div_op(" _div ");
  div_op(*this, rhs).Invoke(result);
  return result;
}
2017-06-19 23:59:40 -07:00
/*!
 * \brief Invoke the operator named by the literal below with (*this, rhs) as
 *        inputs.
 * \param rhs second operator input.
 * \return a default-constructed NDArray that was passed to Invoke as output.
 */
inline NDArray NDArray::operator%(const NDArray &rhs) {
  NDArray result;
  Operator mod_op(" _mod ");
  mod_op(*this, rhs).Invoke(result);
  return result;
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Invoke the operator named by the literal below with `scalar` as its
 *        input and *this as the output.
 * \param scalar value handed to the operator.
 * \return reference to this array.
 */
inline NDArray &NDArray::operator=(mx_float scalar) {
  Operator set_value_op(" _set_value ");
  set_value_op(scalar).Invoke(*this);
  return *this;
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Invoke the operator named by the literal below with (*this, scalar)
 *        as inputs and *this as the output.
 * \param scalar second operator input.
 * \return reference to this array.
 */
inline NDArray &NDArray::operator+=(mx_float scalar) {
  Operator plus_scalar_op(" _plus_scalar ");
  plus_scalar_op(*this, scalar).Invoke(*this);
  return *this;
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Invoke the operator named by the literal below with (*this, scalar)
 *        as inputs and *this as the output.
 * \param scalar second operator input.
 * \return reference to this array.
 */
inline NDArray &NDArray::operator-=(mx_float scalar) {
  Operator minus_scalar_op(" _minus_scalar ");
  minus_scalar_op(*this, scalar).Invoke(*this);
  return *this;
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Invoke the operator named by the literal below with (*this, scalar)
 *        as inputs and *this as the output.
 * \param scalar second operator input.
 * \return reference to this array.
 */
inline NDArray &NDArray::operator*=(mx_float scalar) {
  Operator mul_scalar_op(" _mul_scalar ");
  mul_scalar_op(*this, scalar).Invoke(*this);
  return *this;
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Invoke the operator named by the literal below with (*this, scalar)
 *        as inputs and *this as the output.
 * \param scalar second operator input.
 * \return reference to this array.
 */
inline NDArray &NDArray::operator/=(mx_float scalar) {
  Operator div_scalar_op(" _div_scalar ");
  div_scalar_op(*this, scalar).Invoke(*this);
  return *this;
}
2017-06-19 23:59:40 -07:00
/*!
 * \brief Invoke the operator named by the literal below with (*this, scalar)
 *        as inputs and *this as the output.
 * \param scalar second operator input.
 * \return reference to this array.
 */
inline NDArray &NDArray::operator%=(mx_float scalar) {
  Operator mod_scalar_op(" _mod_scalar ");
  mod_scalar_op(*this, scalar).Invoke(*this);
  return *this;
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Invoke the operator named by the literal below with (*this, rhs) as
 *        inputs and *this as the output.
 * \param rhs second operator input.
 * \return reference to this array.
 */
inline NDArray &NDArray::operator+=(const NDArray &rhs) {
  Operator plus_op(" _plus ");
  plus_op(*this, rhs).Invoke(*this);
  return *this;
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Invoke the operator named by the literal below with (*this, rhs) as
 *        inputs and *this as the output.
 * \param rhs second operator input.
 * \return reference to this array.
 */
inline NDArray &NDArray::operator-=(const NDArray &rhs) {
  Operator minus_op(" _minus ");
  minus_op(*this, rhs).Invoke(*this);
  return *this;
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Invoke the operator named by the literal below with (*this, rhs) as
 *        inputs and *this as the output.
 * \param rhs second operator input.
 * \return reference to this array.
 */
inline NDArray &NDArray::operator*=(const NDArray &rhs) {
  Operator mul_op(" _mul ");
  mul_op(*this, rhs).Invoke(*this);
  return *this;
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Invoke the operator named by the literal below with (*this, rhs) as
 *        inputs and *this as the output.
 * \param rhs second operator input.
 * \return reference to this array.
 */
inline NDArray &NDArray::operator/=(const NDArray &rhs) {
  Operator div_op(" _div ");
  div_op(*this, rhs).Invoke(*this);
  return *this;
}
2017-06-19 23:59:40 -07:00
/*!
 * \brief Invoke the operator named by the literal below with (*this, rhs) as
 *        inputs and *this as the output.
 * \param rhs second operator input.
 * \return reference to this array.
 */
inline NDArray &NDArray::operator%=(const NDArray &rhs) {
  Operator mod_op(" _mod ");
  mod_op(*this, rhs).Invoke(*this);
  return *this;
}
2017-03-22 11:55:51 +08:00
2017-03-28 23:34:46 -05:00
/*!
 * \brief Invoke the operator named by the literal below with *this as its
 *        only input.
 * \return a default-constructed NDArray that was passed to Invoke as output.
 */
inline NDArray NDArray::ArgmaxChannel() {
  NDArray result;
  Operator argmax_op(" argmax_channel ");
  argmax_op(*this).Invoke(result);
  return result;
}
2017-03-28 23:34:46 -05:00
inline void NDArray : : SyncCopyFromCPU ( const mx_float * data , size_t size ) {
2017-03-22 11:55:51 +08:00
MXNDArraySyncCopyFromCPU ( blob_ptr_ - > handle_ , data , size ) ;
}
2017-03-28 23:34:46 -05:00
inline void NDArray : : SyncCopyFromCPU ( const std : : vector < mx_float > & data ) {
2017-03-22 11:55:51 +08:00
MXNDArraySyncCopyFromCPU ( blob_ptr_ - > handle_ , data . data ( ) , data . size ( ) ) ;
}
2017-03-28 23:34:46 -05:00
inline void NDArray : : SyncCopyToCPU ( mx_float * data , size_t size ) {
2017-03-22 11:55:51 +08:00
MXNDArraySyncCopyToCPU ( blob_ptr_ - > handle_ , data , size > 0 ? size : Size ( ) ) ;
}
2017-03-28 23:34:46 -05:00
inline void NDArray : : SyncCopyToCPU ( std : : vector < mx_float > * data , size_t size ) {
2017-03-22 11:55:51 +08:00
size = size > 0 ? size : Size ( ) ;
data - > resize ( size ) ;
MXNDArraySyncCopyToCPU ( blob_ptr_ - > handle_ , data - > data ( ) , size ) ;
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Build a new array with this array's shape on `ctx` and invoke the
 *        operator named by the literal below with *this as input and the new
 *        array as output.
 * \param ctx context for the new array.
 * \return the new array.
 */
inline NDArray NDArray::Copy(const Context &ctx) const {
  NDArray target(GetShape(), ctx);
  Operator copy_op(" _copyto ");
  copy_op(*this).Invoke(target);
  return target;
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Invoke the operator named by the literal below with *this as input
 *        and *other as output.
 * \param other destination array; dereferenced without a null check.
 * \return a copy of *other.
 */
inline NDArray NDArray::CopyTo(NDArray *other) const {
  NDArray &dest = *other;
  Operator copy_op(" _copyto ");
  copy_op(*this).Invoke(dest);
  return dest;
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Call MXNDArraySlice on this array's handle and wrap the handle it
 *        returns in a new NDArray.
 * \param begin forwarded to MXNDArraySlice.
 * \param end forwarded to MXNDArraySlice.
 * \return NDArray constructed from the handle produced by the C API.
 */
inline NDArray NDArray::Slice(mx_uint begin, mx_uint end) const {
  NDArrayHandle sliced;
  CHECK_EQ(MXNDArraySlice(GetHandle(), begin, end, &sliced), 0);
  NDArray result(sliced);
  return result;
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Call MXNDArrayReshape on this array's handle with the dimensions of
 *        `new_shape` and wrap the handle it returns in a new NDArray.
 * \param new_shape target shape; each dimension is converted to int because
 *        the C API takes an int array.
 * \return NDArray constructed from the handle produced by the C API.
 */
inline NDArray NDArray::Reshape(const Shape &new_shape) const {
  const auto ndim = new_shape.ndim();
  std::vector<int> dims(ndim);
  for (index_t i = 0; i < ndim; ++i) {
    dims[i] = static_cast<int>(new_shape[i]);
  }
  NDArrayHandle handle;
  CHECK_EQ(MXNDArrayReshape(GetHandle(), ndim, dims.data(), &handle), 0);
  return NDArray(handle);
}
2017-03-28 23:34:46 -05:00
inline void NDArray : : WaitToRead ( ) const {
2017-03-22 11:55:51 +08:00
CHECK_EQ ( MXNDArrayWaitToRead ( blob_ptr_ - > handle_ ) , 0 ) ;
}
2017-03-28 23:34:46 -05:00
inline void NDArray : : WaitToWrite ( ) {
2017-03-22 11:55:51 +08:00
CHECK_EQ ( MXNDArrayWaitToWrite ( blob_ptr_ - > handle_ ) , 0 ) ;
}
2017-03-28 23:34:46 -05:00
inline void NDArray : : WaitAll ( ) { CHECK_EQ ( MXNDArrayWaitAll ( ) , 0 ) ; }
inline void NDArray : : SampleGaussian ( mx_float mu , mx_float sigma , NDArray * out ) {
2017-09-26 13:39:24 -07:00
Operator ( " _random_normal " ) ( mu , sigma ) . Invoke ( * out ) ;
2017-03-22 11:55:51 +08:00
}
2017-03-28 23:34:46 -05:00
inline void NDArray : : SampleUniform ( mx_float begin , mx_float end , NDArray * out ) {
2017-09-26 13:39:24 -07:00
Operator ( " _random_uniform " ) ( begin , end ) . Invoke ( * out ) ;
2017-03-22 11:55:51 +08:00
}
2017-03-28 23:34:46 -05:00
inline void NDArray : : Load ( const std : : string & file_name ,
std : : vector < NDArray > * array_list ,
std : : map < std : : string , NDArray > * array_map ) {
2017-03-22 11:55:51 +08:00
mx_uint out_size , out_name_size ;
NDArrayHandle * out_arr ;
const char * * out_names ;
CHECK_EQ ( MXNDArrayLoad ( file_name . c_str ( ) , & out_size , & out_arr , & out_name_size ,
& out_names ) ,
0 ) ;
if ( array_list ! = nullptr ) {
2018-04-03 20:55:35 +01:00
array_list - > reserve ( out_size ) ;
2017-03-22 11:55:51 +08:00
for ( mx_uint i = 0 ; i < out_size ; + + i ) {
array_list - > push_back ( NDArray ( out_arr [ i ] ) ) ;
}
}
if ( array_map ! = nullptr & & out_name_size > 0 ) {
CHECK_EQ ( out_name_size , out_size ) ;
for ( mx_uint i = 0 ; i < out_size ; + + i ) {
( * array_map ) [ out_names [ i ] ] = NDArray ( out_arr [ i ] ) ;
}
}
}
2017-03-28 23:34:46 -05:00
inline std : : map < std : : string , NDArray > NDArray : : LoadToMap (
2017-03-22 11:55:51 +08:00
const std : : string & file_name ) {
std : : map < std : : string , NDArray > array_map ;
mx_uint out_size , out_name_size ;
NDArrayHandle * out_arr ;
const char * * out_names ;
CHECK_EQ ( MXNDArrayLoad ( file_name . c_str ( ) , & out_size , & out_arr , & out_name_size ,
& out_names ) ,
0 ) ;
if ( out_name_size > 0 ) {
CHECK_EQ ( out_name_size , out_size ) ;
for ( mx_uint i = 0 ; i < out_size ; + + i ) {
array_map [ out_names [ i ] ] = NDArray ( out_arr [ i ] ) ;
}
}
return array_map ;
}
2017-03-28 23:34:46 -05:00
inline std : : vector < NDArray > NDArray : : LoadToList ( const std : : string & file_name ) {
2017-03-22 11:55:51 +08:00
std : : vector < NDArray > array_list ;
mx_uint out_size , out_name_size ;
NDArrayHandle * out_arr ;
const char * * out_names ;
CHECK_EQ ( MXNDArrayLoad ( file_name . c_str ( ) , & out_size , & out_arr , & out_name_size ,
& out_names ) ,
0 ) ;
2018-04-03 20:55:35 +01:00
array_list . reserve ( out_size ) ;
for ( mx_uint i = 0 ; i < out_size ; + + i ) {
array_list . push_back ( NDArray ( out_arr [ i ] ) ) ;
}
return array_list ;
}
inline void NDArray : : LoadFromBuffer ( const void * buffer , size_t size ,
std : : vector < NDArray > * array_list ,
std : : map < std : : string , NDArray > * array_map ) {
mx_uint out_size , out_name_size ;
NDArrayHandle * out_arr ;
const char * * out_names ;
CHECK_EQ ( MXNDArrayLoadFromBuffer ( buffer , size , & out_size , & out_arr , & out_name_size ,
& out_names ) ,
0 ) ;
if ( array_list ! = nullptr ) {
array_list - > reserve ( out_size ) ;
for ( mx_uint i = 0 ; i < out_size ; + + i ) {
array_list - > push_back ( NDArray ( out_arr [ i ] ) ) ;
}
}
if ( array_map ! = nullptr & & out_name_size > 0 ) {
CHECK_EQ ( out_name_size , out_size ) ;
for ( mx_uint i = 0 ; i < out_size ; + + i ) {
( * array_map ) [ out_names [ i ] ] = NDArray ( out_arr [ i ] ) ;
}
}
}
inline std : : map < std : : string , NDArray > NDArray : : LoadFromBufferToMap (
const void * buffer , size_t size ) {
std : : map < std : : string , NDArray > array_map ;
mx_uint out_size , out_name_size ;
NDArrayHandle * out_arr ;
const char * * out_names ;
CHECK_EQ ( MXNDArrayLoadFromBuffer ( buffer , size , & out_size , & out_arr , & out_name_size ,
& out_names ) ,
0 ) ;
if ( out_name_size > 0 ) {
CHECK_EQ ( out_name_size , out_size ) ;
for ( mx_uint i = 0 ; i < out_size ; + + i ) {
array_map [ out_names [ i ] ] = NDArray ( out_arr [ i ] ) ;
}
}
return array_map ;
}
inline std : : vector < NDArray > NDArray : : LoadFromBufferToList ( const void * buffer , size_t size ) {
std : : vector < NDArray > array_list ;
mx_uint out_size , out_name_size ;
NDArrayHandle * out_arr ;
const char * * out_names ;
CHECK_EQ ( MXNDArrayLoadFromBuffer ( buffer , size , & out_size , & out_arr , & out_name_size ,
& out_names ) ,
0 ) ;
array_list . reserve ( out_size ) ;
2017-03-22 11:55:51 +08:00
for ( mx_uint i = 0 ; i < out_size ; + + i ) {
array_list . push_back ( NDArray ( out_arr [ i ] ) ) ;
}
return array_list ;
}
2017-03-28 23:34:46 -05:00
inline void NDArray : : Save ( const std : : string & file_name ,
const std : : map < std : : string , NDArray > & array_map ) {
2017-03-22 11:55:51 +08:00
std : : vector < NDArrayHandle > args ;
std : : vector < const char * > keys ;
for ( const auto & t : array_map ) {
args . push_back ( t . second . GetHandle ( ) ) ;
keys . push_back ( t . first . c_str ( ) ) ;
}
CHECK_EQ (
MXNDArraySave ( file_name . c_str ( ) , args . size ( ) , args . data ( ) , keys . data ( ) ) ,
0 ) ;
}
2017-03-28 23:34:46 -05:00
inline void NDArray : : Save ( const std : : string & file_name ,
const std : : vector < NDArray > & array_list ) {
2017-03-22 11:55:51 +08:00
std : : vector < NDArrayHandle > args ;
for ( const auto & t : array_list ) {
args . push_back ( t . GetHandle ( ) ) ;
}
CHECK_EQ ( MXNDArraySave ( file_name . c_str ( ) , args . size ( ) , args . data ( ) , nullptr ) ,
0 ) ;
}
2017-03-28 23:34:46 -05:00
inline size_t NDArray : : Offset ( size_t h , size_t w ) const {
2017-03-22 11:55:51 +08:00
return ( h * GetShape ( ) [ 1 ] ) + w ;
}
2017-03-28 23:34:46 -05:00
inline size_t NDArray : : Offset ( size_t c , size_t h , size_t w ) const {
2017-03-22 11:55:51 +08:00
auto const shape = GetShape ( ) ;
return h * shape [ 0 ] * shape [ 2 ] + w * shape [ 0 ] + c ;
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Read the element at Offset(h, w) from the pointer returned by
 *        GetData(). No null or bounds checking is performed.
 * \param h first index.
 * \param w second index.
 * \return the value stored at that offset.
 */
inline mx_float NDArray::At(size_t h, size_t w) const {
  return *(GetData() + Offset(h, w));
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Read the element at Offset(c, h, w) from the pointer returned by
 *        GetData(). No null or bounds checking is performed.
 * \param c first index.
 * \param h second index.
 * \param w third index.
 * \return the value stored at that offset.
 */
inline mx_float NDArray::At(size_t c, size_t h, size_t w) const {
  return *(GetData() + Offset(c, h, w));
}
2017-03-28 23:34:46 -05:00
inline size_t NDArray : : Size ( ) const {
2017-03-22 11:55:51 +08:00
size_t ret = 1 ;
for ( auto & i : GetShape ( ) ) ret * = i ;
return ret ;
}
2017-03-28 23:34:46 -05:00
inline std : : vector < mx_uint > NDArray : : GetShape ( ) const {
2017-03-22 11:55:51 +08:00
const mx_uint * out_pdata ;
mx_uint out_dim ;
MXNDArrayGetShape ( blob_ptr_ - > handle_ , & out_dim , & out_pdata ) ;
std : : vector < mx_uint > ret ;
for ( mx_uint i = 0 ; i < out_dim ; + + i ) {
ret . push_back ( out_pdata [ i ] ) ;
}
return ret ;
}
2017-04-01 07:13:44 +09:00
inline int NDArray : : GetDType ( ) const {
int ret ;
2017-04-04 13:51:57 -04:00
MXNDArrayGetDType ( blob_ptr_ - > handle_ , & ret ) ;
2017-04-01 07:13:44 +09:00
return ret ;
}
2017-03-28 23:34:46 -05:00
/*!
 * \brief Obtain the raw data pointer through MXNDArrayGetData.
 * \return the pointer reinterpreted as mx_float data, or nullptr when
 *         GetDType() is not 0.
 *         NOTE(review): dtype code 0 is presumably float32 — confirm against
 *         MXNet's type flags.
 */
inline const mx_float *NDArray::GetData() const {
  // Start from nullptr: the call's status is not checked, so a failing call
  // must not hand back an indeterminate pointer.
  void *ret = nullptr;
  MXNDArrayGetData(blob_ptr_->handle_, &ret);
  if (GetDType() != 0) {
    return nullptr;
  }
  return static_cast<mx_float *>(ret);
}
2017-04-01 07:13:44 +09:00
2017-03-28 23:34:46 -05:00
/*!
 * \brief Query device type and device id through MXNDArrayGetContext and
 *        build a Context from them. The call's status is not inspected.
 * \return Context constructed from the two values written by the C API.
 */
inline Context NDArray::GetContext() const {
  int dev_type_code;
  int dev_index;
  MXNDArrayGetContext(blob_ptr_->handle_, &dev_type_code, &dev_index);
  const DeviceType dev_type = static_cast<DeviceType>(dev_type_code);
  return Context(dev_type, dev_index);
}
2017-05-18 00:57:37 +08:00
inline std : : ostream & operator < < ( std : : ostream & out , const NDArray & ndarray ) {
// TODO(lx75249): Consider DType / beautify like numpy
auto shape = ndarray . GetShape ( ) ;
NDArray cpu_array ( ndarray . GetShape ( ) , Context : : cpu ( ) ) ;
if ( ndarray . GetContext ( ) . GetDeviceType ( ) ! = DeviceType : : kGPU ) {
cpu_array = ndarray ;
} else {
ndarray . WaitToRead ( ) ;
ndarray . CopyTo ( & cpu_array ) ;
}
out < < ' [ ' ;
cpu_array . WaitToRead ( ) ;
std : : copy ( cpu_array . GetData ( ) , cpu_array . GetData ( ) + ndarray . Size ( ) ,
std : : ostream_iterator < float > ( out , " , " ) ) ;
out < < ' ] ' ;
return out ;
}
2017-03-22 11:55:51 +08:00
} // namespace cpp
} // namespace mxnet
2017-07-12 10:04:40 -07:00
# endif // MXNET_CPP_NDARRAY_HPP_