2021-05-24 13:44:39 -07:00
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file initializer.h
* \brief random initializer
* \author Zhang Chen
*/
# ifndef MXNET_CPP_INITIALIZER_H_
# define MXNET_CPP_INITIALIZER_H_
# include <cmath>
# include <string>
# include <vector>
# include <random>
# include "mxnet-cpp/ndarray.h"
namespace mxnet {
namespace cpp {
/*!
 * \brief Base class for parameter initializers.
 *
 * Invoking the object with a parameter name and an array dispatches on the
 * name (the prefix "upsampling" or a known suffix such as "bias", "gamma" or
 * "weight") to one of the protected Init* hooks. Derived classes customise
 * behaviour by overriding operator() or the individual hooks.
 */
class Initializer {
 public:
  /*! \brief Virtual so derived initializers can be destroyed through a base pointer. */
  virtual ~Initializer() = default;

  /*!
   * \brief Check whether a string begins with a given prefix.
   * \param name string to inspect
   * \param check_str prefix to look for
   * \return true if the first check_str.size() characters of name equal check_str
   */
  static bool StringStartWith(const std::string& name, const std::string& check_str) {
    // compare() works in place, so no temporary substring is allocated.
    return name.size() >= check_str.size() &&
           name.compare(0, check_str.size(), check_str) == 0;
  }

  /*!
   * \brief Check whether a string ends with a given suffix.
   * \param name string to inspect
   * \param check_str suffix to look for
   * \return true if the last check_str.size() characters of name equal check_str
   */
  static bool StringEndWith(const std::string& name, const std::string& check_str) {
    // The size test must come first: it keeps the start offset from underflowing.
    return name.size() >= check_str.size() &&
           name.compare(name.size() - check_str.size(), check_str.size(), check_str) == 0;
  }

  /*!
   * \brief Initialize an array according to the naming pattern of its parameter.
   * \param name parameter name; only its prefix/suffix is examined
   * \param arr array handed to the selected Init* hook
   *
   * The tests run in the order written and the first match wins; names that
   * match nothing go to InitDefault.
   */
  virtual void operator()(const std::string& name, NDArray* arr) {
    if (StringStartWith(name, "upsampling")) {
      InitBilinear(arr);
    } else if (StringEndWith(name, "bias")) {
      InitBias(arr);
    } else if (StringEndWith(name, "gamma")) {
      InitGamma(arr);
    } else if (StringEndWith(name, "beta")) {
      InitBeta(arr);
    } else if (StringEndWith(name, "weight")) {
      InitWeight(arr);
    } else if (StringEndWith(name, "moving_mean")) {
      InitZero(arr);
    } else if (StringEndWith(name, "moving_var")) {
      InitOne(arr);
    } else if (StringEndWith(name, "moving_inv_var")) {
      InitZero(arr);
    } else if (StringEndWith(name, "moving_avg")) {
      InitZero(arr);
    } else if (StringEndWith(name, "min")) {
      InitZero(arr);
    } else if (StringEndWith(name, "max")) {
      InitOne(arr);
    } else if (StringEndWith(name, "weight_quantize")) {
      InitQuantizedWeight(arr);
    } else if (StringEndWith(name, "bias_quantize")) {
      InitQuantizedBias(arr);
    } else {
      InitDefault(arr);
    }
  }

 protected:
  /*!
   * \brief Fill the array with a bilinear interpolation kernel.
   * \param arr array to fill; its shape is indexed at positions 2 and 3, so it
   *        is assumed to have at least four dimensions (not checked here)
   *
   * For flat index i, with x = i % shape[3] and y = (i / shape[3]) % shape[2],
   * the value written is (1 - |x / f - c|) * (1 - |y / f - c|), where
   * f = ceil(shape[3] / 2) and c = (2f - 1 - f % 2) / (2f).
   */
  virtual void InitBilinear(NDArray* arr) {
    Shape shape(arr->GetShape());
    std::vector<float> weight(shape.Size(), 0.0f);
    const size_t width = shape[3];
    const size_t height = shape[2];
    const int f = static_cast<int>(std::ceil(width / 2.0));
    const float c = (2 * f - 1 - f % 2) / (2. * f);
    for (size_t i = 0; i < weight.size(); ++i) {
      const int x = static_cast<int>(i % width);
      const int y = static_cast<int>((i / width) % height);
      // x / f and y / f must be real-valued; integer division would truncate
      // the normalised coordinate and distort the kernel.
      const float fx = static_cast<float>(x) / f;
      const float fy = static_cast<float>(y) / f;
      weight[i] = (1 - std::abs(fx - c)) * (1 - std::abs(fy - c));
    }
    arr->SyncCopyFromCPU(weight);
  }

  /*! \brief Assign 0.0f to the array. */
  virtual void InitZero(NDArray* arr) {
    (*arr) = 0.0f;
  }

  /*! \brief Assign 1.0f to the array. */
  virtual void InitOne(NDArray* arr) {
    (*arr) = 1.0f;
  }

  /*! \brief Hook for "...bias" parameters: assigns 0.0f. */
  virtual void InitBias(NDArray* arr) {
    (*arr) = 0.0f;
  }

  /*! \brief Hook for "...gamma" parameters: assigns 1.0f. */
  virtual void InitGamma(NDArray* arr) {
    (*arr) = 1.0f;
  }

  /*! \brief Hook for "...beta" parameters: assigns 0.0f. */
  virtual void InitBeta(NDArray* arr) {
    (*arr) = 0.0f;
  }

  /*! \brief Hook for "...weight" parameters; does nothing in the base class. */
  virtual void InitWeight(NDArray* arr) {}

  /*!
   * \brief Hook for "...weight_quantize" parameters.
   *
   * Draws one integer from [-127, 127] and assigns that single value to the
   * array. The engine is default-constructed on every call, so the value drawn
   * is the same each time for a given standard library.
   */
  virtual void InitQuantizedWeight(NDArray* arr) {
    std::default_random_engine generator;
    std::uniform_int_distribution<int32_t> _val(-127, 127);
    (*arr) = _val(generator);
  }

  /*! \brief Hook for "...bias_quantize" parameters: assigns 0. */
  virtual void InitQuantizedBias(NDArray* arr) {
    (*arr) = 0;
  }

  /*! \brief Hook for names matching no pattern; does nothing in the base class. */
  virtual void InitDefault(NDArray* arr) {}
};
class Constant : public Initializer {
public :
2021-11-19 09:27:00 +01:00
explicit Constant ( float value ) : value ( value ) { }
void operator ( ) ( const std : : string & name , NDArray * arr ) override {
2021-05-24 13:44:39 -07:00
( * arr ) = value ;
}
2021-11-19 09:27:00 +01:00
2021-05-24 13:44:39 -07:00
protected :
float value ;
} ;
class Zero : public Constant {
public :
2021-11-19 09:27:00 +01:00
Zero ( ) : Constant ( 0.0f ) { }
2021-05-24 13:44:39 -07:00
} ;
class One : public Constant {
public :
2021-11-19 09:27:00 +01:00
One ( ) : Constant ( 1.0f ) { }
2021-05-24 13:44:39 -07:00
} ;
class Uniform : public Initializer {
public :
2021-11-19 09:27:00 +01:00
explicit Uniform ( float scale ) : Uniform ( - scale , scale ) { }
Uniform ( float begin , float end ) : begin ( begin ) , end ( end ) { }
void operator ( ) ( const std : : string & name , NDArray * arr ) override {
2021-05-24 13:44:39 -07:00
if ( StringEndWith ( name , " weight_quantize " ) ) {
InitQuantizedWeight ( arr ) ;
return ;
}
if ( StringEndWith ( name , " bias_quantize " ) ) {
InitQuantizedBias ( arr ) ;
return ;
}
NDArray : : SampleUniform ( begin , end , arr ) ;
}
2021-11-19 09:27:00 +01:00
2021-05-24 13:44:39 -07:00
protected :
float begin , end ;
} ;
class Normal : public Initializer {
public :
2021-11-19 09:27:00 +01:00
Normal ( float mu , float sigma ) : mu ( mu ) , sigma ( sigma ) { }
void operator ( ) ( const std : : string & name , NDArray * arr ) override {
2021-05-24 13:44:39 -07:00
if ( StringEndWith ( name , " weight_quantize " ) ) {
InitQuantizedWeight ( arr ) ;
return ;
}
if ( StringEndWith ( name , " bias_quantize " ) ) {
InitQuantizedBias ( arr ) ;
return ;
}
NDArray : : SampleGaussian ( mu , sigma , arr ) ;
}
2021-11-19 09:27:00 +01:00
2021-05-24 13:44:39 -07:00
protected :
float mu , sigma ;
} ;
class Bilinear : public Initializer {
public :
Bilinear ( ) { }
2021-11-19 09:27:00 +01:00
void operator ( ) ( const std : : string & name , NDArray * arr ) override {
2021-05-24 13:44:39 -07:00
if ( StringEndWith ( name , " weight_quantize " ) ) {
InitQuantizedWeight ( arr ) ;
return ;
}
if ( StringEndWith ( name , " bias_quantize " ) ) {
InitQuantizedBias ( arr ) ;
return ;
}
InitBilinear ( arr ) ;
}
} ;
class Xavier : public Initializer {
public :
2021-11-19 09:27:00 +01:00
enum RandType { gaussian , uniform } rand_type ;
enum FactorType { avg , in , out } factor_type ;
2021-05-24 13:44:39 -07:00
float magnitude ;
2021-11-19 09:27:00 +01:00
Xavier ( RandType rand_type = gaussian , // NOLINT
FactorType factor_type = avg , // NOLINT
float magnitude = 3 ) // NOLINT
2021-05-24 13:44:39 -07:00
: rand_type ( rand_type ) , factor_type ( factor_type ) , magnitude ( magnitude ) { }
2021-11-19 09:27:00 +01:00
void operator ( ) ( const std : : string & name , NDArray * arr ) override {
2021-05-24 13:44:39 -07:00
if ( StringEndWith ( name , " weight_quantize " ) ) {
InitQuantizedWeight ( arr ) ;
return ;
}
if ( StringEndWith ( name , " bias_quantize " ) ) {
InitQuantizedBias ( arr ) ;
return ;
}
Shape shape ( arr - > GetShape ( ) ) ;
float hw_scale = 1.0f ;
if ( shape . ndim ( ) > 2 ) {
for ( size_t i = 2 ; i < shape . ndim ( ) ; + + i ) {
hw_scale * = shape [ i ] ;
}
}
float fan_in = shape [ 1 ] * hw_scale , fan_out = shape [ 0 ] * hw_scale ;
float factor = 1.0f ;
switch ( factor_type ) {
case avg :
factor = ( fan_in + fan_out ) / 2.0 ;
break ;
case in :
factor = fan_in ;
break ;
case out :
factor = fan_out ;
}
float scale = std : : sqrt ( magnitude / factor ) ;
switch ( rand_type ) {
case uniform :
NDArray : : SampleUniform ( - scale , scale , arr ) ;
break ;
case gaussian :
NDArray : : SampleGaussian ( 0 , scale , arr ) ;
break ;
}
}
} ;
/*!
 * \brief Xavier initializer fixed to gaussian sampling, with the magnitude
 *        derived from a slope value as 2 / (1 + slope^2).
 */
class MSRAPrelu : public Xavier {
 public:
  /*!
   * \param factor_type forwarded unchanged to Xavier
   * \param slope value from which the magnitude is computed
   */
  explicit MSRAPrelu(FactorType factor_type = avg, float slope = 0.25f)
      : Xavier(gaussian, factor_type, 2. / (slope * slope + 1)) {}
};
} // namespace cpp
} // namespace mxnet
# endif // MXNET_CPP_INITIALIZER_H_