/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
 * \file shape.h
 * \brief definition of shape
 * \author Chuntao Hong, Zhang Chen
 */
# ifndef MXNET_CPP_SHAPE_H_
# define MXNET_CPP_SHAPE_H_
#include <algorithm>
#include <cctype>
#include <istream>
#include <ostream>
#include <vector>
#include "mxnet-cpp/base.h"
namespace mxnet {
namespace cpp {
/*!
 * \brief dynamic shape class that can hold a shape
 *  of arbitrary dimension
 */
struct Shape {
public :
/*! \brief constructor */
2021-11-19 09:27:00 +01:00
Shape ( ) : ndim_ ( 0 ) , num_heap_allocated_ ( 0 ) , data_heap_ ( nullptr ) { }
2021-05-24 13:44:39 -07:00
/*!
2021-11-19 09:27:00 +01:00
* \brief constructor from a vector of index_t
* \param v the vector
*/
explicit Shape ( const std : : vector < index_t > & v ) : ndim_ ( v . size ( ) ) {
2021-05-24 13:44:39 -07:00
if ( ndim_ < = kStackCache ) {
2021-11-19 09:27:00 +01:00
data_heap_ = nullptr ;
2021-05-24 13:44:39 -07:00
num_heap_allocated_ = 0 ;
std : : copy ( v . begin ( ) , v . end ( ) , data_stack_ ) ;
} else {
2021-11-19 09:27:00 +01:00
data_heap_ = new index_t [ ndim_ ] ;
2021-05-24 13:44:39 -07:00
num_heap_allocated_ = ndim_ ;
std : : copy ( v . begin ( ) , v . end ( ) , data_heap_ ) ;
}
}
/*!
2021-11-19 09:27:00 +01:00
* \brief constructor one dimmension shape
* \param s1 size of the first dimmension
*/
explicit Shape ( index_t s1 ) : ndim_ ( 1 ) {
2021-05-24 13:44:39 -07:00
if ( ndim_ < = kStackCache ) {
2021-11-19 09:27:00 +01:00
data_heap_ = nullptr ;
2021-05-24 13:44:39 -07:00
num_heap_allocated_ = 0 ;
2021-11-19 09:27:00 +01:00
data_stack_ [ 0 ] = s1 ;
2021-05-24 13:44:39 -07:00
} else {
2021-11-19 09:27:00 +01:00
data_heap_ = new index_t [ ndim_ ] ;
2021-05-24 13:44:39 -07:00
num_heap_allocated_ = ndim_ ;
2021-11-19 09:27:00 +01:00
data_heap_ [ 0 ] = s1 ;
2021-05-24 13:44:39 -07:00
}
}
/*!
2021-11-19 09:27:00 +01:00
* \brief constructor two dimmension shape
* \param s1 size of the first dimmension
* \param s2 size of the second dimmension
*/
Shape ( index_t s1 , index_t s2 ) : ndim_ ( 2 ) {
2021-05-24 13:44:39 -07:00
if ( ndim_ < = kStackCache ) {
2021-11-19 09:27:00 +01:00
data_heap_ = nullptr ;
2021-05-24 13:44:39 -07:00
num_heap_allocated_ = 0 ;
2021-11-19 09:27:00 +01:00
data_stack_ [ 0 ] = s1 ;
data_stack_ [ 1 ] = s2 ;
2021-05-24 13:44:39 -07:00
} else {
2021-11-19 09:27:00 +01:00
data_heap_ = new index_t [ ndim_ ] ;
2021-05-24 13:44:39 -07:00
num_heap_allocated_ = ndim_ ;
2021-11-19 09:27:00 +01:00
data_heap_ [ 0 ] = s1 ;
data_heap_ [ 1 ] = s2 ;
2021-05-24 13:44:39 -07:00
}
}
/*!
2021-11-19 09:27:00 +01:00
* \brief constructor three dimmension shape
* \param s1 size of the first dimmension
* \param s2 size of the second dimmension
* \param s3 size of the third dimmension
*/
Shape ( index_t s1 , index_t s2 , index_t s3 ) : ndim_ ( 3 ) {
2021-05-24 13:44:39 -07:00
if ( ndim_ < = kStackCache ) {
2021-11-19 09:27:00 +01:00
data_heap_ = nullptr ;
2021-05-24 13:44:39 -07:00
num_heap_allocated_ = 0 ;
2021-11-19 09:27:00 +01:00
data_stack_ [ 0 ] = s1 ;
data_stack_ [ 1 ] = s2 ;
data_stack_ [ 2 ] = s3 ;
2021-05-24 13:44:39 -07:00
} else {
2021-11-19 09:27:00 +01:00
data_heap_ = new index_t [ ndim_ ] ;
2021-05-24 13:44:39 -07:00
num_heap_allocated_ = ndim_ ;
2021-11-19 09:27:00 +01:00
data_heap_ [ 0 ] = s1 ;
data_heap_ [ 1 ] = s2 ;
data_heap_ [ 2 ] = s3 ;
2021-05-24 13:44:39 -07:00
}
}
/*!
2021-11-19 09:27:00 +01:00
* \brief constructor four dimmension shape
* \param s1 size of the first dimmension
* \param s2 size of the second dimmension
* \param s3 size of the third dimmension
* \param s4 size of the fourth dimmension
*/
Shape ( index_t s1 , index_t s2 , index_t s3 , index_t s4 ) : ndim_ ( 4 ) {
2021-05-24 13:44:39 -07:00
if ( ndim_ < = kStackCache ) {
2021-11-19 09:27:00 +01:00
data_heap_ = nullptr ;
2021-05-24 13:44:39 -07:00
num_heap_allocated_ = 0 ;
2021-11-19 09:27:00 +01:00
data_stack_ [ 0 ] = s1 ;
data_stack_ [ 1 ] = s2 ;
data_stack_ [ 2 ] = s3 ;
data_stack_ [ 3 ] = s4 ;
2021-05-24 13:44:39 -07:00
} else {
2021-11-19 09:27:00 +01:00
data_heap_ = new index_t [ ndim_ ] ;
2021-05-24 13:44:39 -07:00
num_heap_allocated_ = ndim_ ;
2021-11-19 09:27:00 +01:00
data_heap_ [ 0 ] = s1 ;
data_heap_ [ 1 ] = s2 ;
data_heap_ [ 2 ] = s3 ;
data_heap_ [ 3 ] = s4 ;
2021-05-24 13:44:39 -07:00
}
}
/*!
2021-11-19 09:27:00 +01:00
* \brief constructor five dimmension shape
* \param s1 size of the first dimmension
* \param s2 size of the second dimmension
* \param s3 size of the third dimmension
* \param s4 size of the fourth dimmension
* \param s5 size of the fifth dimmension
*/
Shape ( index_t s1 , index_t s2 , index_t s3 , index_t s4 , index_t s5 ) : ndim_ ( 5 ) {
2021-05-24 13:44:39 -07:00
if ( ndim_ < = kStackCache ) {
2021-11-19 09:27:00 +01:00
data_heap_ = nullptr ;
2021-05-24 13:44:39 -07:00
num_heap_allocated_ = 0 ;
2021-11-19 09:27:00 +01:00
data_stack_ [ 0 ] = s1 ;
data_stack_ [ 1 ] = s2 ;
data_stack_ [ 2 ] = s3 ;
data_stack_ [ 3 ] = s4 ;
data_stack_ [ 4 ] = s5 ;
2021-05-24 13:44:39 -07:00
} else {
2021-11-19 09:27:00 +01:00
data_heap_ = new index_t [ ndim_ ] ;
2021-05-24 13:44:39 -07:00
num_heap_allocated_ = ndim_ ;
2021-11-19 09:27:00 +01:00
data_heap_ [ 0 ] = s1 ;
data_heap_ [ 1 ] = s2 ;
data_heap_ [ 2 ] = s3 ;
data_heap_ [ 3 ] = s4 ;
data_heap_ [ 4 ] = s5 ;
2021-05-24 13:44:39 -07:00
}
}
/*!
2021-11-19 09:27:00 +01:00
* \brief constructor from Shape
* \param s the source shape
*/
Shape ( const Shape & s ) : ndim_ ( s . ndim_ ) {
2021-05-24 13:44:39 -07:00
if ( ndim_ < = kStackCache ) {
2021-11-19 09:27:00 +01:00
data_heap_ = nullptr ;
2021-05-24 13:44:39 -07:00
num_heap_allocated_ = 0 ;
std : : copy ( s . data_stack_ , s . data_stack_ + ndim_ , data_stack_ ) ;
} else {
2021-11-19 09:27:00 +01:00
data_heap_ = new index_t [ ndim_ ] ;
2021-05-24 13:44:39 -07:00
num_heap_allocated_ = ndim_ ;
std : : copy ( s . data_heap_ , s . data_heap_ + ndim_ , data_heap_ ) ;
}
}
# if MSHADOW_IN_CXX11
/*!
2021-11-19 09:27:00 +01:00
* \brief move constructor from Shape
* \param s the source shape
*/
Shape ( Shape & & s )
: ndim_ ( s . ndim_ ) , num_heap_allocated_ ( s . num_heap_allocated_ ) , data_heap_ ( s . data_heap_ ) {
2021-05-24 13:44:39 -07:00
if ( ndim_ < = kStackCache ) {
std : : copy ( s . data_stack_ , s . data_stack_ + ndim_ , data_stack_ ) ;
}
// remove data heap space from s
s . data_heap_ = nullptr ;
}
# endif
/*! \brief destructor */
~ Shape ( ) {
// data_heap_ can be nullptr
delete [ ] data_heap_ ;
}
/*!
2021-11-19 09:27:00 +01:00
* \brief copy shape from content betwen two iterators
* \param begin the beginning of iterator
* \param end the end of the iterator
* \tparam RandomAccessIterator iterator type
*/
template < typename RandomAccessIterator >
inline void CopyFrom ( RandomAccessIterator begin , RandomAccessIterator end ) {
2021-05-24 13:44:39 -07:00
this - > SetDim ( end - begin ) ;
std : : copy ( begin , end , data ( ) ) ;
}
/*!
2021-11-19 09:27:00 +01:00
* \brief assignment from shape
* \param shape source shape
* \return reference of self
*/
inline Shape & operator = ( const Shape & shape ) {
2021-05-24 13:44:39 -07:00
this - > SetDim ( shape . ndim_ ) ;
2021-11-19 09:27:00 +01:00
const index_t * src = shape . data ( ) ;
2021-05-24 13:44:39 -07:00
std : : copy ( src , src + ndim_ , data ( ) ) ;
return * this ;
}
/*!
2021-11-19 09:27:00 +01:00
* \brief assignment from vector
* \param shape source shape
* \return reference of self
*/
inline Shape & operator = ( const std : : vector < index_t > & shape ) {
2021-05-24 13:44:39 -07:00
this - > CopyFrom ( shape . begin ( ) , shape . end ( ) ) ;
return * this ;
}
/*! \return the data content of the shape */
2021-11-19 09:27:00 +01:00
inline const index_t * data ( ) const {
2021-05-24 13:44:39 -07:00
return ndim_ < = kStackCache ? data_stack_ : data_heap_ ;
}
/*! \return the data content of the shape */
2021-11-19 09:27:00 +01:00
inline index_t * data ( ) {
2021-05-24 13:44:39 -07:00
return ndim_ < = kStackCache ? data_stack_ : data_heap_ ;
}
/*! \brief return number of dimension of the tensor inside */
inline index_t ndim ( void ) const {
return ndim_ ;
}
/*!
2021-11-19 09:27:00 +01:00
* \brief get corresponding index
* \param i dimension index
* \return the corresponding dimension size
*/
inline index_t & operator [ ] ( index_t i ) {
2021-05-24 13:44:39 -07:00
return data ( ) [ i ] ;
}
/*!
2021-11-19 09:27:00 +01:00
* \brief get corresponding index
* \param i dimension index
* \return the corresponding dimension size
*/
inline const index_t & operator [ ] ( index_t i ) const {
2021-05-24 13:44:39 -07:00
return data ( ) [ i ] ;
}
/*! \brief total number of elements in the tensor */
inline size_t Size ( void ) const {
2021-11-19 09:27:00 +01:00
size_t size = 1 ;
const index_t * d = this - > data ( ) ;
2021-05-24 13:44:39 -07:00
for ( index_t i = 0 ; i < ndim_ ; + + i ) {
size * = d [ i ] ;
}
return size ;
}
/*!
2021-11-19 09:27:00 +01:00
* \return whether two shape equals
* \param s the shape to compare against
*/
inline bool operator = = ( const Shape & s ) const {
if ( ndim_ ! = s . ndim_ )
return false ;
2021-05-24 13:44:39 -07:00
if ( ndim_ < = kStackCache ) {
for ( index_t i = 0 ; i < ndim_ ; + + i ) {
2021-11-19 09:27:00 +01:00
if ( data_stack_ [ i ] ! = s . data_stack_ [ i ] )
return false ;
2021-05-24 13:44:39 -07:00
}
} else {
for ( index_t i = 0 ; i < ndim_ ; + + i ) {
2021-11-19 09:27:00 +01:00
if ( data_heap_ [ i ] ! = s . data_heap_ [ i ] )
return false ;
2021-05-24 13:44:39 -07:00
}
}
return true ;
}
/*!
2021-11-19 09:27:00 +01:00
* \return whether two shape not equals
* \param s the shape to compare against
*/
inline bool operator ! = ( const Shape & s ) const {
2021-05-24 13:44:39 -07:00
return ! ( * this = = s ) ;
}
2021-11-19 09:27:00 +01:00
friend std : : ostream & operator < < ( std : : ostream & os , const Shape & shape ) ;
friend std : : istream & operator > > ( std : : istream & is , Shape & shape ) ;
2021-05-24 13:44:39 -07:00
private :
// the shape will be stored in data_stack_
// when dimension is smaller than kStackCache
// when it is bigger, it will be stored in data_heap_;
/*! \brief size of in stack space */
static const index_t kStackCache = 5 ;
/*! \brief number of dimnsion of the shape */
index_t ndim_ ;
/*! \brief number of cells allocated in data_heap_ */
index_t num_heap_allocated_ ;
/*! \brief in stack space used to store shape when it is small */
index_t data_stack_ [ kStackCache ] ;
/*! \brief space to store shape when dimension is big*/
2021-11-19 09:27:00 +01:00
index_t * data_heap_ ;
2021-05-24 13:44:39 -07:00
/*!
2021-11-19 09:27:00 +01:00
* \brief internal function to set the dimension
* \param dim the dimension of the shape
*/
2021-05-24 13:44:39 -07:00
inline void SetDim ( index_t dim ) {
2021-11-19 09:27:00 +01:00
if ( dim > kStackCache & & dim > num_heap_allocated_ ) {
2021-05-24 13:44:39 -07:00
// data_heap_ can be nullptr
delete [ ] data_heap_ ;
2021-11-19 09:27:00 +01:00
data_heap_ = new index_t [ dim ] ;
2021-05-24 13:44:39 -07:00
num_heap_allocated_ = dim ;
}
ndim_ = dim ;
}
} ;
/*!
 * \brief allow string printing of the shape
 * \param os the output stream
 * \param shape the shape
 * \return the ostream
 */
inline std : : ostream & operator < < ( std : : ostream & os , const Shape & shape ) {
2021-05-24 13:44:39 -07:00
os < < ' ( ' ;
for ( index_t i = 0 ; i < shape . ndim ( ) ; + + i ) {
2021-11-19 09:27:00 +01:00
if ( i ! = 0 )
os < < ' , ' ;
2021-05-24 13:44:39 -07:00
os < < static_cast < int > ( shape [ i ] ) ; // Supports negative Shape 'special codes' for inferring
}
// python style tuple
2021-11-19 09:27:00 +01:00
if ( shape . ndim ( ) = = 1 )
os < < ' , ' ;
2021-05-24 13:44:39 -07:00
os < < ' ) ' ;
return os ;
}
/*!
 * \brief read shape from the istream
 * \param is the input stream
 * \param shape the shape
 * \return the istream
 */
inline std : : istream & operator > > ( std : : istream & is , Shape & shape ) {
2021-05-24 13:44:39 -07:00
// get (
while ( true ) {
char ch = is . get ( ) ;
2021-11-19 09:27:00 +01:00
if ( ch = = ' ( ' )
break ;
2021-05-24 13:44:39 -07:00
if ( ! isspace ( ch ) ) {
is . setstate ( std : : ios : : failbit ) ;
return is ;
}
}
index_t idx ;
std : : vector < index_t > tmp ;
while ( is > > idx ) {
tmp . push_back ( idx ) ;
char ch ;
do {
ch = is . get ( ) ;
} while ( isspace ( ch ) ) ;
if ( ch = = ' , ' ) {
while ( true ) {
ch = is . peek ( ) ;
if ( isspace ( ch ) ) {
2021-11-19 09:27:00 +01:00
is . get ( ) ;
continue ;
2021-05-24 13:44:39 -07:00
}
if ( ch = = ' ) ' ) {
2021-11-19 09:27:00 +01:00
is . get ( ) ;
break ;
2021-05-24 13:44:39 -07:00
}
break ;
}
2021-11-19 09:27:00 +01:00
if ( ch = = ' ) ' )
break ;
2021-05-24 13:44:39 -07:00
} else if ( ch = = ' ) ' ) {
break ;
} else {
is . setstate ( std : : ios : : failbit ) ;
return is ;
}
}
shape . CopyFrom ( tmp . begin ( ) , tmp . end ( ) ) ;
return is ;
}
} // namespace cpp
} // namespace mxnet
# endif // MXNET_CPP_SHAPE_H_