/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file operator.h * \brief definition of operator * \author Chuntao Hong, Zhang Chen */ #ifndef MXNET_CPP_OPERATOR_H_ #define MXNET_CPP_OPERATOR_H_ #include #include #include #include "mxnet-cpp/base.h" #include "mxnet-cpp/op_map.h" #include "mxnet-cpp/symbol.h" namespace mxnet { namespace cpp { class Mxnet; /*! * \brief Operator interface */ class Operator { public: /*! * \brief Operator constructor * \param operator_name type of the operator */ explicit Operator(const std::string& operator_name); Operator& operator=(const Operator& rhs); /*! * \brief set config parameters * \param name name of the config parameter * \param value value of the config parameter * \return reference of self */ template Operator& SetParam(const std::string& name, const T& value) { std::string value_str; std::stringstream ss; ss << value; ss >> value_str; params_[name] = value_str; return *this; } /*! * \brief set config parameters from positional inputs * \param pos the position of parameter * \param value value of the config parameter * \return reference of self */ template Operator& SetParam(int pos, const T& value) { std::string value_str; std::stringstream ss; ss << value; ss >> value_str; params_[arg_names_[pos]] = value_str; return *this; } /*! * \brief add an input symbol * \param name name of the input symbol * \param symbol the input symbol * \return reference of self */ Operator& SetInput(const std::string& name, const Symbol& symbol); /*! * \brief add an input symbol * \param symbol the input symbol */ template void PushInput(const Symbol& symbol) { input_symbols_.push_back(symbol.GetHandle()); } /*! * \brief add input symbols * \return reference of self */ Operator& operator()() { return *this; } /*! * \brief add input symbols * \param symbol the input symbol * \return reference of self */ Operator& operator()(const Symbol& symbol) { input_symbols_.push_back(symbol.GetHandle()); return *this; } /*! * \brief add a list of input symbols * \param symbols the vector of the input symbols * \return reference of self */ Operator& operator()(const std::vector& symbols) { for (auto& s : symbols) { input_symbols_.push_back(s.GetHandle()); } return *this; } /*! * \brief create a Symbol from the current operator * \param name the name of the operator * \return the operator Symbol */ Symbol CreateSymbol(const std::string& name = ""); /*! * \brief add an input ndarray * \param name name of the input ndarray * \param ndarray the input ndarray * \return reference of self */ Operator& SetInput(const std::string& name, const NDArray& ndarray); /*! * \brief add an input ndarray * \param ndarray the input ndarray */ template Operator& PushInput(const NDArray& ndarray) { input_ndarrays_.push_back(ndarray.GetHandle()); return *this; } /*! * \brief add positional inputs */ template Operator& PushInput(const T& t, Args... args) { SetParam(N, t); PushInput(args...); return *this; } /*! * \brief add the last positional input */ template Operator& PushInput(const T& t) { SetParam(N, t); return *this; } /*! * \brief add input ndarrays * \param ndarray the input ndarray * \return reference of self */ Operator& operator()(const NDArray& ndarray) { input_ndarrays_.push_back(ndarray.GetHandle()); return *this; } /*! * \brief add a list of input ndarrays * \param ndarrays the vector of the input ndarrays * \return reference of self */ Operator& operator()(const std::vector& ndarrays) { for (auto& s : ndarrays) { input_ndarrays_.push_back(s.GetHandle()); } return *this; } /*! * \brief add input ndarrays * \return reference of self */ template Operator& operator()(Args... args) { PushInput(args...); return *this; } std::vector Invoke(); void Invoke(NDArray& output); void Invoke(std::vector& outputs); private: std::map params_desc_; bool variable_params_ = false; std::map params_; std::vector input_symbols_; std::vector input_ndarrays_; std::vector input_keys_; std::vector arg_names_; AtomicSymbolCreator handle_; static OpMap*& op_map(); }; } // namespace cpp } // namespace mxnet #endif // MXNET_CPP_OPERATOR_H_