/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file initializer.h * \brief random initializer * \author Zhang Chen */ #ifndef MXNET_CPP_INITIALIZER_H_ #define MXNET_CPP_INITIALIZER_H_ #include #include #include #include #include "mxnet-cpp/ndarray.h" namespace mxnet { namespace cpp { class Initializer { public: static bool StringStartWith(const std::string& name, const std::string& check_str) { return (name.size() >= check_str.size() && name.substr(0, check_str.size()) == check_str); } static bool StringEndWith(const std::string& name, const std::string& check_str) { return (name.size() >= check_str.size() && name.substr(name.size() - check_str.size(), check_str.size()) == check_str); } virtual void operator()(const std::string& name, NDArray* arr) { if (StringStartWith(name, "upsampling")) { InitBilinear(arr); } else if (StringEndWith(name, "bias")) { InitBias(arr); } else if (StringEndWith(name, "gamma")) { InitGamma(arr); } else if (StringEndWith(name, "beta")) { InitBeta(arr); } else if (StringEndWith(name, "weight")) { InitWeight(arr); } else if (StringEndWith(name, "moving_mean")) { InitZero(arr); } else if (StringEndWith(name, "moving_var")) { InitOne(arr); } else if (StringEndWith(name, "moving_inv_var")) { InitZero(arr); } else if (StringEndWith(name, "moving_avg")) { InitZero(arr); } else if (StringEndWith(name, "min")) { InitZero(arr); } else if (StringEndWith(name, "max")) { InitOne(arr); } else if (StringEndWith(name, "weight_quantize")) { InitQuantizedWeight(arr); } else if (StringEndWith(name, "bias_quantize")) { InitQuantizedBias(arr); } else { InitDefault(arr); } } protected: virtual void InitBilinear(NDArray* arr) { Shape shape(arr->GetShape()); std::vector weight(shape.Size(), 0); int f = std::ceil(shape[3] / 2.0); float c = (2 * f - 1 - f % 2) / (2. * f); for (size_t i = 0; i < shape.Size(); ++i) { int x = i % shape[3]; int y = (i / shape[3]) % shape[2]; weight[i] = (1 - std::abs(x / f - c)) * (1 - std::abs(y / f - c)); } (*arr).SyncCopyFromCPU(weight); } virtual void InitZero(NDArray* arr) { (*arr) = 0.0f; } virtual void InitOne(NDArray* arr) { (*arr) = 1.0f; } virtual void InitBias(NDArray* arr) { (*arr) = 0.0f; } virtual void InitGamma(NDArray* arr) { (*arr) = 1.0f; } virtual void InitBeta(NDArray* arr) { (*arr) = 0.0f; } virtual void InitWeight(NDArray* arr) {} virtual void InitQuantizedWeight(NDArray* arr) { std::default_random_engine generator; std::uniform_int_distribution _val(-127, 127); (*arr) = _val(generator); } virtual void InitQuantizedBias(NDArray* arr) { (*arr) = 0; } virtual void InitDefault(NDArray* arr) {} }; class Constant : public Initializer { public: explicit Constant(float value) : value(value) {} void operator()(const std::string& name, NDArray* arr) override { (*arr) = value; } protected: float value; }; class Zero : public Constant { public: Zero() : Constant(0.0f) {} }; class One : public Constant { public: One() : Constant(1.0f) {} }; class Uniform : public Initializer { public: explicit Uniform(float scale) : Uniform(-scale, scale) {} Uniform(float begin, float end) : begin(begin), end(end) {} void operator()(const std::string& name, NDArray* arr) override { if (StringEndWith(name, "weight_quantize")) { InitQuantizedWeight(arr); return; } if (StringEndWith(name, "bias_quantize")) { InitQuantizedBias(arr); return; } NDArray::SampleUniform(begin, end, arr); } protected: float begin, end; }; class Normal : public Initializer { public: Normal(float mu, float sigma) : mu(mu), sigma(sigma) {} void operator()(const std::string& name, NDArray* arr) override { if (StringEndWith(name, "weight_quantize")) { InitQuantizedWeight(arr); return; } if (StringEndWith(name, "bias_quantize")) { InitQuantizedBias(arr); return; } NDArray::SampleGaussian(mu, sigma, arr); } protected: float mu, sigma; }; class Bilinear : public Initializer { public: Bilinear() {} void operator()(const std::string& name, NDArray* arr) override { if (StringEndWith(name, "weight_quantize")) { InitQuantizedWeight(arr); return; } if (StringEndWith(name, "bias_quantize")) { InitQuantizedBias(arr); return; } InitBilinear(arr); } }; class Xavier : public Initializer { public: enum RandType { gaussian, uniform } rand_type; enum FactorType { avg, in, out } factor_type; float magnitude; Xavier(RandType rand_type = gaussian, // NOLINT FactorType factor_type = avg, // NOLINT float magnitude = 3) // NOLINT : rand_type(rand_type), factor_type(factor_type), magnitude(magnitude) {} void operator()(const std::string& name, NDArray* arr) override { if (StringEndWith(name, "weight_quantize")) { InitQuantizedWeight(arr); return; } if (StringEndWith(name, "bias_quantize")) { InitQuantizedBias(arr); return; } Shape shape(arr->GetShape()); float hw_scale = 1.0f; if (shape.ndim() > 2) { for (size_t i = 2; i < shape.ndim(); ++i) { hw_scale *= shape[i]; } } float fan_in = shape[1] * hw_scale, fan_out = shape[0] * hw_scale; float factor = 1.0f; switch (factor_type) { case avg: factor = (fan_in + fan_out) / 2.0; break; case in: factor = fan_in; break; case out: factor = fan_out; } float scale = std::sqrt(magnitude / factor); switch (rand_type) { case uniform: NDArray::SampleUniform(-scale, scale, arr); break; case gaussian: NDArray::SampleGaussian(0, scale, arr); break; } } }; class MSRAPrelu : public Xavier { public: explicit MSRAPrelu(FactorType factor_type = avg, float slope = 0.25f) : Xavier(gaussian, factor_type, 2. / (1 + slope * slope)) {} }; } // namespace cpp } // namespace mxnet #endif // MXNET_CPP_INITIALIZER_H_