/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file symbol.hpp * \brief implementation of the symbol * \author Zhang Chen, Chuntao Hong */ #ifndef MXNET_CPP_SYMBOL_HPP_ #define MXNET_CPP_SYMBOL_HPP_ #include #include #include #include #include "dmlc/logging.h" #include "mxnet-cpp/symbol.h" #include "mxnet-cpp/op_suppl.h" namespace mxnet { namespace cpp { inline OpMap*& Symbol::op_map() { static OpMap* op_map_ = new OpMap(); return op_map_; } inline Symbol::Symbol(SymbolHandle handle) { blob_ptr_ = std::make_shared(handle); } inline Symbol::Symbol(const char *name) { SymbolHandle handle; CHECK_EQ(MXSymbolCreateVariable(name, &(handle)), 0); blob_ptr_ = std::make_shared(handle); } inline Symbol::Symbol(const std::string &name) : Symbol(name.c_str()) {} inline Symbol Symbol::Variable(const std::string &name) { return Symbol(name); } inline Symbol Symbol::operator+(const Symbol &rhs) const { return _Plus(*this, rhs); } inline Symbol Symbol::operator-(const Symbol &rhs) const { return _Minus(*this, rhs); } inline Symbol Symbol::operator*(const Symbol &rhs) const { return _Mul(*this, rhs); } inline Symbol Symbol::operator/(const Symbol &rhs) const { return _Div(*this, rhs); } inline Symbol Symbol::operator%(const Symbol &rhs) const { return _Mod(*this, rhs); } inline Symbol Symbol::operator+(mx_float scalar) const { return _PlusScalar(*this, scalar); } inline Symbol Symbol::operator-(mx_float scalar) const { return _MinusScalar(*this, scalar); } inline Symbol Symbol::operator*(mx_float scalar) const { return _MulScalar(*this, scalar); } inline Symbol Symbol::operator/(mx_float scalar) const { return _DivScalar(*this, scalar); } inline Symbol Symbol::operator%(mx_float scalar) const { return _ModScalar(*this, scalar); } inline Symbol Symbol::operator[](int index) { SymbolHandle out; MXSymbolGetOutput(GetHandle(), index, &out); return Symbol(out); } inline Symbol Symbol::operator[](const std::string &index) { auto outputs = ListOutputs(); for (mx_uint i = 0; i < outputs.size(); ++i) { if (outputs[i] == index) { return (*this)[i]; } } LOG(FATAL) << "Cannot find output that matches name " << index; return (*this)[0]; } inline Symbol Symbol::Group(const std::vector &symbols) { SymbolHandle out; std::vector handle_list; for (const auto &t : symbols) { handle_list.push_back(t.GetHandle()); } MXSymbolCreateGroup(handle_list.size(), handle_list.data(), &out); return Symbol(out); } inline Symbol Symbol::Load(const std::string &file_name) { op_map(); SymbolHandle handle; CHECK_EQ(MXSymbolCreateFromFile(file_name.c_str(), &(handle)), 0); return Symbol(handle); } inline Symbol Symbol::LoadJSON(const std::string &json_str) { op_map(); SymbolHandle handle; CHECK_EQ(MXSymbolCreateFromJSON(json_str.c_str(), &(handle)), 0); return Symbol(handle); } inline void Symbol::Save(const std::string &file_name) const { CHECK_EQ(MXSymbolSaveToFile(GetHandle(), file_name.c_str()), 0); } inline std::string Symbol::ToJSON() const { const char *out_json; CHECK_EQ(MXSymbolSaveToJSON(GetHandle(), &out_json), 0); return std::string(out_json); } inline Symbol Symbol::GetInternals() const { SymbolHandle handle; CHECK_EQ(MXSymbolGetInternals(GetHandle(), &handle), 0); return Symbol(handle); } inline Symbol::Symbol(const std::string &operator_name, const std::string &name, std::vector input_keys, std::vector input_values, std::vector config_keys, std::vector config_values) { SymbolHandle handle; AtomicSymbolCreator creator = op_map()->GetSymbolCreator(operator_name); MXSymbolCreateAtomicSymbol(creator, config_keys.size(), config_keys.data(), config_values.data(), &handle); MXSymbolCompose(handle, operator_name.c_str(), input_keys.size(), input_keys.data(), input_values.data()); blob_ptr_ = std::make_shared(handle); } inline Symbol Symbol::Copy() const { SymbolHandle handle; CHECK_EQ(MXSymbolCopy(GetHandle(), &handle), 0); return Symbol(handle); } inline std::vector Symbol::ListArguments() const { std::vector ret; mx_uint size; const char **sarr; MXSymbolListArguments(GetHandle(), &size, &sarr); for (mx_uint i = 0; i < size; ++i) { ret.push_back(std::string(sarr[i])); } return ret; } inline std::vector Symbol::ListInputs() const { std::vector ret; mx_uint size; const char **sarr; NNSymbolListInputNames(GetHandle(), 0, &size, &sarr); for (mx_uint i = 0; i < size; ++i) { ret.push_back(std::string(sarr[i])); } return ret; } inline std::vector Symbol::ListOutputs() const { std::vector ret; mx_uint size; const char **sarr; MXSymbolListOutputs(GetHandle(), &size, &sarr); for (mx_uint i = 0; i < size; ++i) { ret.push_back(std::string(sarr[i])); } return ret; } inline std::vector Symbol::ListAuxiliaryStates() const { std::vector ret; mx_uint size; const char **sarr; MXSymbolListAuxiliaryStates(GetHandle(), &size, &sarr); for (mx_uint i = 0; i < size; ++i) { ret.push_back(std::string(sarr[i])); } return ret; } inline std::map Symbol::ListAttributes() const { mx_uint size; const char** pairs; CHECK_EQ(MXSymbolListAttrShallow(GetHandle(), &size, &pairs), 0); std::map attributes; for (mx_uint i = 0; i < size; ++i) { // pairs is 2 * size with key, value pairs according to // https://github.com/apache/incubator-mxnet/blob/master/include/mxnet/c_api.h#L1428 attributes[pairs[2 * i]] = pairs[2 * i + 1]; } return attributes; } inline void Symbol::SetAttribute(const std::string &key, const std::string &value) { CHECK_EQ(MXSymbolSetAttr(GetHandle(), key.c_str(), value.c_str()), 0); } inline void Symbol::SetAttributes(const std::map &attrs) { for (const auto& kv : attrs) { SetAttribute(kv.first, kv.second); } } inline mx_uint Symbol::GetNumOutputs() const { mx_uint numOutputs; CHECK_EQ(MXSymbolGetNumOutputs(GetHandle(), &numOutputs), 0); return numOutputs; } inline mxnet::cpp::Symbol Symbol::GetBackendSymbol(const std::string &backendName) const { SymbolHandle symbolHandle; CHECK_EQ(MXGenBackendSubgraph(GetHandle(), backendName.c_str(), &symbolHandle), 0); return mxnet::cpp::Symbol(symbolHandle); } inline std::string Symbol::GetName() const { int success; const char* out_name; CHECK_EQ(MXSymbolGetName(GetHandle(), &out_name, &success), 0); CHECK_EQ(success, 1); return std::string(out_name); } inline void Symbol::InferShape( const std::map > &arg_shapes, std::vector > *in_shape, std::vector > *aux_shape, std::vector > *out_shape) const { std::vector keys; std::vector arg_ind_ptr; std::vector arg_shape_data; for (const auto &arg : arg_shapes) { keys.push_back(arg.first.c_str()); arg_ind_ptr.push_back(arg_shape_data.size()); for (auto i : arg.second) { arg_shape_data.push_back(i); } } arg_ind_ptr.push_back(arg_shape_data.size()); mx_uint in_shape_size; const int *in_shape_ndim; const int **in_shape_data; mx_uint out_shape_size; const int *out_shape_ndim; const int **out_shape_data; mx_uint aux_shape_size; const int *aux_shape_ndim; const int **aux_shape_data; int complete; CHECK_EQ(MXSymbolInferShape(GetHandle(), keys.size(), keys.data(), arg_ind_ptr.data(), arg_shape_data.data(), &in_shape_size, &in_shape_ndim, &in_shape_data, &out_shape_size, &out_shape_ndim, &out_shape_data, &aux_shape_size, &aux_shape_ndim, &aux_shape_data, &complete), 0); if (complete) { for (mx_uint i = 0; i < in_shape_size; ++i) { in_shape->push_back(std::vector()); for (int j = 0; j < in_shape_ndim[i]; ++j) { (*in_shape)[i].push_back(in_shape_data[i][j]); } } for (mx_uint i = 0; i < aux_shape_size; ++i) { aux_shape->push_back(std::vector()); for (int j = 0; j < aux_shape_ndim[i]; ++j) { (*aux_shape)[i].push_back(aux_shape_data[i][j]); } } for (mx_uint i = 0; i < out_shape_size; ++i) { out_shape->push_back(std::vector()); for (int j = 0; j < out_shape_ndim[i]; ++j) { (*out_shape)[i].push_back(out_shape_data[i][j]); } } } } inline void Symbol::InferExecutorArrays( const Context &context, std::vector *arg_arrays, std::vector *grad_arrays, std::vector *grad_reqs, std::vector *aux_arrays, const std::map &args_map, const std::map &arg_grad_store, const std::map &grad_req_type, const std::map &aux_map) const { const auto arg_name_list = ListArguments(); std::vector > in_shapes, aux_shapes, out_shapes; std::map > arg_shapes; for (const auto &arg_name : arg_name_list) { auto iter = args_map.find(arg_name); if (iter != args_map.end()) { arg_shapes[arg_name] = iter->second.GetShape(); } } InferShape(arg_shapes, &in_shapes, &aux_shapes, &out_shapes); for (size_t i = 0; i < in_shapes.size(); ++i) { const auto &shape = in_shapes[i]; const auto &arg_name = arg_name_list[i]; auto iter_arg = args_map.find(arg_name); if (iter_arg != args_map.end()) { arg_arrays->push_back(iter_arg->second); } else { arg_arrays->push_back(NDArray(shape, context, false)); NDArray::SampleGaussian(0, 1, &arg_arrays->back()); } auto iter_grad = arg_grad_store.find(arg_name); if (iter_grad != arg_grad_store.end()) { grad_arrays->push_back(iter_grad->second); } else { grad_arrays->push_back(NDArray(shape, context, false)); } auto iter_req = grad_req_type.find(arg_name); if (iter_req != grad_req_type.end()) { grad_reqs->push_back(iter_req->second); } else if (arg_name.rfind("data") != std::string::npos || arg_name.rfind("label") != std::string::npos) { grad_reqs->push_back(OpReqType::kNullOp); } else { grad_reqs->push_back(OpReqType::kWriteTo); } } const auto aux_name_list = ListAuxiliaryStates(); for (size_t i = 0; i < aux_shapes.size(); ++i) { const auto &shape = aux_shapes[i]; const auto &aux_name = aux_name_list[i]; auto iter_aux = aux_map.find(aux_name); if (iter_aux != aux_map.end()) { aux_arrays->push_back(iter_aux->second); } else { aux_arrays->push_back(NDArray(shape, context, false)); NDArray::SampleGaussian(0, 1, &aux_arrays->back()); } } } inline void Symbol::InferArgsMap( const Context &context, std::map *args_map, const std::map &known_args) const { const auto arg_name_list = ListArguments(); std::vector > in_shapes, aux_shapes, out_shapes; std::map > arg_shapes; for (const auto &arg_name : arg_name_list) { auto iter = known_args.find(arg_name); if (iter != known_args.end()) { arg_shapes[arg_name] = iter->second.GetShape(); } } InferShape(arg_shapes, &in_shapes, &aux_shapes, &out_shapes); for (size_t i = 0; i < in_shapes.size(); ++i) { const auto &shape = in_shapes[i]; const auto &arg_name = arg_name_list[i]; auto iter_arg = known_args.find(arg_name); if (iter_arg != known_args.end()) { (*args_map)[arg_name] = iter_arg->second; } else { (*args_map)[arg_name] = NDArray(shape, context, false); NDArray::SampleGaussian(0, 1, &(*args_map)[arg_name]); } } } inline Executor *Symbol::SimpleBind( const Context &context, const std::map &args_map, const std::map &arg_grad_store, const std::map &grad_req_type, const std::map &aux_map) { std::vector arg_arrays; std::vector grad_arrays; std::vector grad_reqs; std::vector aux_arrays; InferExecutorArrays(context, &arg_arrays, &grad_arrays, &grad_reqs, &aux_arrays, args_map, arg_grad_store, grad_req_type, aux_map); return new Executor(*this, context, arg_arrays, grad_arrays, grad_reqs, aux_arrays); } inline Executor *Symbol::Bind(const Context &context, const std::vector &arg_arrays, const std::vector &grad_arrays, const std::vector &grad_reqs, const std::vector &aux_arrays, const std::map &group_to_ctx, Executor *shared_exec) { return new Executor(*this, context, arg_arrays, grad_arrays, grad_reqs, aux_arrays, group_to_ctx, shared_exec); } inline Symbol operator+(mx_float lhs, const Symbol &rhs) { return rhs + lhs; } inline Symbol operator-(mx_float lhs, const Symbol &rhs) { return mxnet::cpp::_RMinusScalar(lhs, rhs); } inline Symbol operator*(mx_float lhs, const Symbol &rhs) { return rhs * lhs; } inline Symbol operator/(mx_float lhs, const Symbol &rhs) { return mxnet::cpp::_RDivScalar(lhs, rhs); } inline Symbol operator%(mx_float lhs, const Symbol &rhs) { return mxnet::cpp::_RModScalar(lhs, rhs); } } // namespace cpp } // namespace mxnet #endif // MXNET_CPP_SYMBOL_HPP_