/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file optimizer.h * \brief definition of optimizer * \author Chuntao Hong, Zhang Chen */ #ifndef MXNET_CPP_OPTIMIZER_H_ #define MXNET_CPP_OPTIMIZER_H_ #include #include #include #include #include #include #include "mxnet-cpp/base.h" #include "dmlc/logging.h" #include "mxnet-cpp/ndarray.h" #include "mxnet-cpp/op_map.h" #include "mxnet-cpp/lr_scheduler.h" namespace mxnet { namespace cpp { /*! * \brief Optimizer interface */ class Optimizer { public: /*! * \brief constructor * \param beign_num_update The initial number of updates */ explicit Optimizer(unsigned begin_num_update); /*! * \brief get optimizer type * \return string of optimizer type */ virtual std::string GetType() const = 0; /*! * \brief destructor */ virtual ~Optimizer(); /*! * \brief set config parameters * \param name name of the config parameter * \param value value of the config parameter * \return reference of self */ template Optimizer* SetParam(const std::string& name, const T& value) { std::string value_str; std::stringstream ss; ss << value; ss >> value_str; params_[name] = value_str; return this; } /*! * \bried set the lr scheduler * \param lrScheduler lr scheduler used for this optimizer * \return reference if self */ Optimizer* SetLRScheduler(std::unique_ptr lrScheduler) { CHECK(lrScheduler); lrScheduler_ = std::move(lrScheduler); lrScheduler_->SetLR(dmlc::stof(params_["lr"])); return this; } /*! * \brief Update a weight with gradient. * \param index the unique index for the weight. * \param weight the weight to update. * \param grad gradient for the weight. */ virtual void Update(int index, NDArray weight, NDArray grad) = 0; // TODO(zhangcheng-qinyinghua) // implement Update a list of arrays, maybe in the form of map // void Update(int index, std::vector weights, std::vector // grad, mx_float lr); /*! * \brief Serialize the optimizer parameters to a string. * \return serialization */ std::string Serialize() const; protected: std::map params_; static OpMap*& op_map(); const std::vector GetParamKeys_() const; const std::vector GetParamValues_() const; std::map count_; unsigned begin_num_update_, num_update_; unsigned UpdateCount_(int index); float GetLR_(int index); float GetWD_(int index); virtual void CreateState_(int index, NDArray weight); std::unique_ptr lrScheduler_ = nullptr; }; typedef std::function OptimizerCreator; class OptimizerRegistry { public: static Optimizer* Find(const std::string& name); static int __REGISTER__(const std::string& name, OptimizerCreator creator); private: static std::map& cmap(); OptimizerRegistry() = delete; ~OptimizerRegistry() = delete; }; #define MXNETCPP_REGISTER_OPTIMIZER(Name, OptimizerType) \ OptimizerRegistry::__REGISTER__(#Name, []() { return new OptimizerType(); }) class SGDOptimizer : public Optimizer { public: explicit SGDOptimizer(unsigned begin_num_update = 0); std::string GetType() const override; void Update(int index, NDArray weight, NDArray grad) override; private: virtual ~SGDOptimizer(); void CreateState_(int index, NDArray weight) override; std::map states_; AtomicSymbolCreator update_handle_; AtomicSymbolCreator mom_update_handle_; }; class SignumOptimizer : public Optimizer { public: explicit SignumOptimizer(unsigned begin_num_update = 0); std::string GetType() const override; void Update(int index, NDArray weight, NDArray grad) override; private: virtual ~SignumOptimizer(); void CreateState_(int index, NDArray weight) override; std::map states_; AtomicSymbolCreator update_handle_; AtomicSymbolCreator mom_update_handle_; }; class RMSPropOptimizer : public Optimizer { public: explicit RMSPropOptimizer(unsigned begin_num_update = 0); std::string GetType() const override; void Update(int index, NDArray weight, NDArray grad) override; private: virtual ~RMSPropOptimizer(); void CreateState_(int index, NDArray weight) override; std::map n_, g_, delta_; AtomicSymbolCreator update_handle_; AtomicSymbolCreator alex_update_handle_; }; class AdamOptimizer : public Optimizer { public: explicit AdamOptimizer(unsigned begin_num_update = 0); std::string GetType() const override; void Update(int index, NDArray weight, NDArray grad) override; private: virtual ~AdamOptimizer(); void CreateState_(int index, NDArray weight) override; std::map mean_; std::map var_; AtomicSymbolCreator update_handle_; }; class AdaGradOptimizer : public Optimizer { public: explicit AdaGradOptimizer(unsigned begin_num_update = 0); std::string GetType() const override; void Update(int index, NDArray weight, NDArray grad) override; private: virtual ~AdaGradOptimizer(); void CreateState_(int index, NDArray weight) override; std::map history_; }; class AdaDeltaOptimizer : public Optimizer { public: explicit AdaDeltaOptimizer(unsigned begin_num_update = 0); std::string GetType() const override; void Update(int index, NDArray weight, NDArray grad) override; private: virtual ~AdaDeltaOptimizer(); void CreateState_(int index, NDArray weight) override; std::map acc_g_, acc_delta_; }; } // namespace cpp } // namespace mxnet #endif // MXNET_CPP_OPTIMIZER_H_