/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file optimizer.hpp
 * \brief implementation of optimizer
 * \author Chuntao Hong, Zhang Chen
 */

#ifndef MXNET_CPP_OPTIMIZER_HPP_
#define MXNET_CPP_OPTIMIZER_HPP_

#include <dmlc/strtonum.h>
#include <algorithm>
#include <cmath>
#include <map>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
#include "mxnet-cpp/optimizer.h"
#include "mxnet-cpp/op.h"
#include "mxnet-cpp/op_map.h"

namespace {

// TODO(lx75249): Add imperative operators to op.h under ndarray
inline void _clip(mxnet::cpp::NDArray &data, float limit) {
  data = mxnet::cpp::Operator("clip")
    .SetParam("a_min", -limit)
    .SetParam("a_max", limit)
    .SetInput("data", data)
    .Invoke()[0];
}

inline mxnet::cpp::NDArray _sqrt(mxnet::cpp::NDArray data) {
  return mxnet::cpp::Operator("sqrt")
    .SetInput("data", data)
    .Invoke()[0];
}

}  // namespace

namespace mxnet {
namespace cpp {

inline Optimizer::Optimizer(unsigned begin_num_update)
  : begin_num_update_(begin_num_update),
    num_update_(begin_num_update_) {
  params_["lr"] = "0.01f";
  params_["wd"] = "0.f";
}

inline std::map<std::string, OptimizerCreator>& OptimizerRegistry::cmap() {
  static std::map<std::string, OptimizerCreator> cmap_;
  return cmap_;
}

inline OpMap*& Optimizer::op_map() {
  static OpMap* op_map_ = new OpMap();
  return op_map_;
}

inline Optimizer::~Optimizer() {}

inline void Optimizer::CreateState_(int index, NDArray weight) {
}

inline std::string Optimizer::Serialize() const {
  using ValueType = std::map<std::string, std::string>::value_type;
  auto params = params_;
  params.emplace("opt_type", GetType());
  return std::accumulate(params.cbegin(), params.cend(), std::string(""),
    [](const std::string& sum, const ValueType& i) {
      return sum + '\n' + i.first + '=' + i.second;
    }).substr(1);
}

inline const std::vector<const char*> Optimizer::GetParamKeys_() const {
  std::vector<const char*> keys;
  for (auto& iter : params_)
    keys.push_back(iter.first.c_str());
  return keys;
}

inline const std::vector<const char*> Optimizer::GetParamValues_() const {
  std::vector<const char*> values;
  for (auto& iter : params_)
    values.push_back(iter.second.c_str());
  return values;
}

inline unsigned Optimizer::UpdateCount_(int index) {
  if (count_.count(index) == 0) {
    count_.emplace(index, begin_num_update_);
  }
  unsigned new_count = ++count_[index];
  num_update_ = std::max(num_update_, new_count);
  return new_count;
}

inline float Optimizer::GetLR_(int index) {
  if (nullptr != lrScheduler_) {
    return lrScheduler_->GetLR(num_update_);
  }
  return dmlc::stof(params_["lr"]);
}

inline float Optimizer::GetWD_(int index) {
  float wd = dmlc::stof(params_["wd"]);
  return wd;
}
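// Note (illustrative, not from the original source): Serialize() above joins
// every entry of params_ with '\n', plus an extra opt_type key. Because
// std::map iterates keys in alphabetical order, a freshly constructed SGD
// optimizer (before any Update call) would serialize to:
//
//   lr=0.01f
//   opt_type=sgd
//   wd=0.f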
inline Optimizer* OptimizerRegistry::Find(const std::string& name) {
  if (cmap().empty()) {
    // Optimizers should only be registered once
    MXNETCPP_REGISTER_OPTIMIZER(sgd, SGDOptimizer);
    MXNETCPP_REGISTER_OPTIMIZER(ccsgd, SGDOptimizer);  // For backward compatibility
    MXNETCPP_REGISTER_OPTIMIZER(rmsprop, RMSPropOptimizer);
    MXNETCPP_REGISTER_OPTIMIZER(adam, AdamOptimizer);
    MXNETCPP_REGISTER_OPTIMIZER(adagrad, AdaGradOptimizer);
    MXNETCPP_REGISTER_OPTIMIZER(adadelta, AdaDeltaOptimizer);
    MXNETCPP_REGISTER_OPTIMIZER(signum, SignumOptimizer);
  }
  auto it = cmap().find(name);
  if (it == cmap().end())
    return nullptr;
  return it->second();
}

inline int OptimizerRegistry::__REGISTER__(const std::string& name, OptimizerCreator creator) {
  CHECK_EQ(cmap().count(name), 0) << name << " already registered";
  cmap().emplace(name, std::move(creator));
  return 0;
}

inline SGDOptimizer::SGDOptimizer(unsigned begin_num_update)
  : Optimizer(begin_num_update) {
  update_handle_ = op_map()->GetSymbolCreator("sgd_update");
  mom_update_handle_ = op_map()->GetSymbolCreator("sgd_mom_update");
}

inline std::string SGDOptimizer::GetType() const {
  return "sgd";
}

inline SGDOptimizer::~SGDOptimizer() {
  for (auto &it : states_) {
    delete it.second;
  }
}

inline void SGDOptimizer::Update(int index, NDArray weight, NDArray grad) {
  if (states_.count(index) == 0) {
    CreateState_(index, weight);
  }

  params_["lr"] = std::to_string(GetLR_(index));
  params_["wd"] = std::to_string(GetWD_(index));
  UpdateCount_(index);
  auto keys = GetParamKeys_();
  auto values = GetParamValues_();
  CHECK_EQ(keys.size(), values.size());

  NDArrayHandle inputs[3];
  inputs[0] = weight.GetHandle();
  inputs[1] = grad.GetHandle();

  int num_outputs = 1;
  NDArrayHandle output = weight.GetHandle();
  NDArrayHandle *outputs = &output;

  if (states_[index] == nullptr) {
    MXImperativeInvoke(update_handle_, 2, inputs,
        &num_outputs, &outputs,
        keys.size(), keys.data(), values.data(), nullptr);
  } else {
    inputs[2] = states_[index]->GetHandle();
    MXImperativeInvoke(mom_update_handle_, 3, inputs,
        &num_outputs, &outputs,
        keys.size(), keys.data(), values.data(), nullptr);
  }
}

inline void SGDOptimizer::CreateState_(int index, NDArray weight) {
  if (params_.count("momentum") == 0) {
    states_[index] = nullptr;
  } else {
    states_[index] = new NDArray(weight.GetShape(), weight.GetContext());
    *states_[index] = 0;
  }
}

// Implementation of the Signum optimizer.
inline SignumOptimizer::SignumOptimizer(unsigned begin_num_update)
  : Optimizer(begin_num_update) {
  update_handle_ = op_map()->GetSymbolCreator("signsgd_update");
  mom_update_handle_ = op_map()->GetSymbolCreator("signum_update");
}

inline std::string SignumOptimizer::GetType() const {
  return "signum";
}

inline SignumOptimizer::~SignumOptimizer() {
  for (auto &it : states_) {
    delete it.second;
  }
}

inline void SignumOptimizer::Update(int index, NDArray weight, NDArray grad) {
  if (states_.count(index) == 0) {
    CreateState_(index, weight);
  }

  params_["lr"] = std::to_string(GetLR_(index));
  params_["wd"] = std::to_string(GetWD_(index));
  UpdateCount_(index);
  auto keys = GetParamKeys_();
  auto values = GetParamValues_();
  CHECK_EQ(keys.size(), values.size());

  NDArrayHandle inputs[3];
  inputs[0] = weight.GetHandle();
  inputs[1] = grad.GetHandle();

  int num_outputs = 1;
  NDArrayHandle output = weight.GetHandle();
  NDArrayHandle *outputs = &output;

  if (states_[index] == nullptr) {
    MXImperativeInvoke(update_handle_, 2, inputs,
        &num_outputs, &outputs,
        keys.size(), keys.data(), values.data(), nullptr);
  } else {
    inputs[2] = states_[index]->GetHandle();
    MXImperativeInvoke(mom_update_handle_, 3, inputs,
        &num_outputs, &outputs,
        keys.size(), keys.data(), values.data(), nullptr);
  }
}

inline void SignumOptimizer::CreateState_(int index, NDArray weight) {
  if (params_.count("momentum") == 0) {
    states_[index] = nullptr;
  } else {
    states_[index] = new NDArray(weight.GetShape(), weight.GetContext());
    *states_[index] = 0;
  }
}
// End of the Signum optimizer implementation.
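// Usage sketch (illustrative; `exec` and `batch_size` are placeholders, not
// part of this file). A training loop typically resolves an optimizer by
// name, configures it, and calls Update() once per parameter array:
//
//   Optimizer* opt = OptimizerRegistry::Find("sgd");
//   opt->SetParam("momentum", 0.9)
//      ->SetParam("rescale_grad", 1.0 / batch_size)
//      ->SetParam("clip_gradient", 10.0);
//   for (size_t i = 0; i < exec->arg_arrays.size(); ++i) {
//     opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]);
//   }
//
// Because "momentum" is set, SGDOptimizer::Update dispatches to
// sgd_mom_update; without it, the stateless sgd_update kernel is used.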
inline RMSPropOptimizer::RMSPropOptimizer(unsigned begin_num_update)
  : Optimizer(begin_num_update) {
  update_handle_ = op_map()->GetSymbolCreator("rmsprop_update");
  alex_update_handle_ = op_map()->GetSymbolCreator("rmspropalex_update");
  SetParam("gamma1", 0.9f);
  SetParam("gamma2", 0.9f);
  SetParam("epsilon", 1e-8);
}

inline std::string RMSPropOptimizer::GetType() const {
  return "rmsprop";
}

inline RMSPropOptimizer::~RMSPropOptimizer() {
  for (auto &it : n_) {
    delete it.second;
  }
  for (auto &it : g_) {
    delete it.second;
  }
  for (auto &it : delta_) {
    delete it.second;
  }
}

inline void RMSPropOptimizer::Update(int index, NDArray weight, NDArray grad) {
  if (n_.count(index) == 0) {
    CreateState_(index, weight);
  }

  params_["lr"] = std::to_string(GetLR_(index));
  params_["wd"] = std::to_string(GetWD_(index));
  UpdateCount_(index);
  auto keys = GetParamKeys_();
  auto values = GetParamValues_();
  CHECK_EQ(keys.size(), values.size());

  NDArrayHandle inputs[5];
  inputs[0] = weight.GetHandle();
  inputs[1] = grad.GetHandle();
  inputs[2] = n_[index]->GetHandle();
  inputs[3] = g_[index]->GetHandle();
  inputs[4] = delta_[index]->GetHandle();

  int num_outputs = 1;
  NDArrayHandle output = weight.GetHandle();
  NDArrayHandle *outputs = &output;

  MXImperativeInvoke(alex_update_handle_, 5, inputs,
      &num_outputs, &outputs,
      keys.size(), keys.data(), values.data(), nullptr);
}

inline void RMSPropOptimizer::CreateState_(int index, NDArray weight) {
  n_[index] = new NDArray(weight.GetShape(), weight.GetContext());
  *n_[index] = 0;
  g_[index] = new NDArray(weight.GetShape(), weight.GetContext());
  *g_[index] = 0;
  delta_[index] = new NDArray(weight.GetShape(), weight.GetContext());
  *delta_[index] = 0;
}

inline AdamOptimizer::AdamOptimizer(unsigned begin_num_update)
  : Optimizer(begin_num_update) {
  update_handle_ = op_map()->GetSymbolCreator("adam_update");
  SetParam("beta1", 0.9f);
  SetParam("beta2", 0.999f);
  SetParam("epsilon", 1e-8);
}

inline std::string AdamOptimizer::GetType() const {
  return "adam";
}

inline AdamOptimizer::~AdamOptimizer() {
  for (auto &it : mean_) {
    delete it.second;
  }
  for (auto &it : var_) {
    delete it.second;
  }
}

inline void AdamOptimizer::Update(int index, NDArray weight, NDArray grad) {
  if (mean_.count(index) == 0) {
    CreateState_(index, weight);
  }

  params_["lr"] = std::to_string(GetLR_(index));
  params_["wd"] = std::to_string(GetWD_(index));
  UpdateCount_(index);

  // Fold the Adam bias correction into the serialized learning rate so that
  // adam_update receives the corrected value; the base rate is restored after
  // the invoke so the correction does not compound across steps.
  const std::string base_lr = params_["lr"];
  float lr = dmlc::stof(params_["lr"]);
  float b1 = dmlc::stof(params_["beta1"]);
  float b2 = dmlc::stof(params_["beta2"]);
  float t = count_[index];
  float coef1 = 1.0f - std::pow(b1, t);
  float coef2 = 1.0f - std::pow(b2, t);
  lr *= std::sqrt(coef2) / coef1;
  params_["lr"] = std::to_string(lr);

  auto keys = GetParamKeys_();
  auto values = GetParamValues_();
  CHECK_EQ(keys.size(), values.size());

  NDArrayHandle inputs[4];
  inputs[0] = weight.GetHandle();
  inputs[1] = grad.GetHandle();
  inputs[2] = mean_[index]->GetHandle();
  inputs[3] = var_[index]->GetHandle();

  int num_outputs = 1;
  NDArrayHandle output = weight.GetHandle();
  NDArrayHandle *outputs = &output;

  MXImperativeInvoke(update_handle_, 4, inputs,
      &num_outputs, &outputs,
      keys.size(), keys.data(), values.data(), nullptr);

  params_["lr"] = base_lr;
}

inline void AdamOptimizer::CreateState_(int index, NDArray weight) {
  mean_[index] = new NDArray(weight.GetShape(), weight.GetContext());
  *mean_[index] = 0;
  var_[index] = new NDArray(weight.GetShape(), weight.GetContext());
  *var_[index] = 0;
}

inline AdaGradOptimizer::AdaGradOptimizer(unsigned begin_num_update)
  : Optimizer(begin_num_update) {
  SetParam("eps", 1e-7);
}

inline std::string AdaGradOptimizer::GetType() const {
  return "adagrad";
}
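// Sketch of the dense AdaGrad rule implemented below, in informal notation
// (h is the per-weight gradient history, g the rescaled/clipped gradient):
//
//   h <- h + g * g
//   w <- w - lr * (g / sqrt(h + eps) + wd * w)
//
// Unlike the SGD, Signum, RMSProp, and Adam updaters above, this optimizer
// composes the rule from imperative NDArray operations (_clip, _sqrt, and
// the arithmetic overloads) instead of calling a fused update kernel.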
inline void AdaGradOptimizer::Update(int index, NDArray weight, NDArray grad) {
  if (history_.count(index) == 0) {
    CreateState_(index, weight);
  }
  float eps = dmlc::stof(params_["eps"]);
  float lr = GetLR_(index);
  float wd = GetWD_(index);
  UpdateCount_(index);
  if (params_.count("rescale_grad") > 0) {
    grad *= dmlc::stof(params_["rescale_grad"]);
  }
  if (params_.count("clip_gradient") > 0) {
    _clip(grad, dmlc::stof(params_["clip_gradient"]));
  }
  auto& history = *history_[index];
  history += grad * grad;
  weight -= (grad / _sqrt(history + eps) + weight * wd) * lr;
}

inline AdaGradOptimizer::~AdaGradOptimizer() {
  for (auto& it : history_) {
    delete it.second;
  }
}

inline void AdaGradOptimizer::CreateState_(int index, NDArray weight) {
  history_[index] = new NDArray(weight.GetShape(), weight.GetContext());
  *history_[index] = 0;
}

inline AdaDeltaOptimizer::AdaDeltaOptimizer(unsigned begin_num_update)
  : Optimizer(begin_num_update) {
  SetParam("rho", 0.90f);
  SetParam("epsilon", 1e-5);
}

inline std::string AdaDeltaOptimizer::GetType() const {
  return "adadelta";
}

inline void AdaDeltaOptimizer::Update(int index, NDArray weight, NDArray grad) {
  if (acc_g_.count(index) == 0) {
    CreateState_(index, weight);
  }
  float rho = dmlc::stof(params_["rho"]);
  float epsilon = dmlc::stof(params_["epsilon"]);
  float wd = GetWD_(index);
  UpdateCount_(index);

  if (params_.count("rescale_grad") > 0) {
    grad *= dmlc::stof(params_["rescale_grad"]);
  }
  if (params_.count("clip_gradient") > 0) {
    _clip(grad, dmlc::stof(params_["clip_gradient"]));
  }

  auto& acc_g = *acc_g_[index];
  auto& acc_delta = *acc_delta_[index];
  acc_g *= rho;
  acc_g += grad * grad * (1.0f - rho);

  auto delta = _sqrt(acc_delta + epsilon) / _sqrt(acc_g + epsilon) * grad;
  acc_delta *= rho;
  acc_delta += delta * delta * (1.0f - rho);
  weight *= 1.0f - wd;
  weight -= delta;
}

inline AdaDeltaOptimizer::~AdaDeltaOptimizer() {
  for (auto& it : acc_g_) {
    delete it.second;
  }
  for (auto& it : acc_delta_) {
    delete it.second;
  }
}

inline void AdaDeltaOptimizer::CreateState_(int index, NDArray weight) {
  acc_g_[index] = new NDArray(weight.GetShape(), weight.GetContext());
  *acc_g_[index] = 0;
  acc_delta_[index] = new NDArray(weight.GetShape(), weight.GetContext());
  *acc_delta_[index] = 0;
}

}  // namespace cpp
}  // namespace mxnet

#endif  // MXNET_CPP_OPTIMIZER_HPP_