/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file torch_function.h * \brief Torch interface. * \author Junyuan Xie */ #ifndef PLUGIN_TORCH_TORCH_FUNCTION_H_ #define PLUGIN_TORCH_TORCH_FUNCTION_H_ #include "./torch_base.h" #include #include #include #include #include #include #include #include namespace mxnet { template void TorchRunOp(std::vector arr_in, std::vector arr_out, const std::map& param, RunContext ctx) { TorchState* torchState = TorchState::ThreadSharedLuaState(); torchState->SetStream(ctx.get_stream()); lua_State* L = torchState->L; lua_getglobal(L, "torch"); lua_getfield(L, -1, OP::fname); int idx = 0; std::vector arr(arr_out.begin(), arr_out.end()); arr.insert(arr.end(), arr_in.begin(), arr_in.end()); std::string format = param.at("format"); std::istringstream args(param.at("args")); for (size_t i = 0; i < format.size(); ++i) { std::string val; std::getline(args, val, ','); switch (format[i]) { case 'n': { CHECK(idx < arr.size()) << "Too few NDArray arguments for Torch." << OP::fname; luaT_pushudata(L, TorchTensor::TBlobToTHTensor(torchState, arr[idx].data()), TorchTensor::TensorType(arr[idx].data())); idx++; break; } case 'i': lua_pushinteger(L, std::stoi(val)); break; case 'f': lua_pushnumber(L, std::stof(val)); break; case 's': lua_pushstring(L, val.c_str()); break; case 'b': lua_pushboolean(L, std::stoi(val)); break; default: LOG(FATAL) << "Unknown argument type " << format[i] << " for Torch." << OP::fname; } } CHECK_EQ(lua_pcall(L, format.size(), 0, 0), 0) << "Lua Error: " << lua_tostring(L, -1); } template void TorchOp(NDArray** u, real_t* s, NDArray** out, const std::map& param) { std::vector shapes = OP::GetShape(u, param); CHECK_EQ(shapes.size(), OP::num_outputs) << "Too many output shapes for TorchOp " << OP::fname; Context ctx; int type_flag; if (OP::num_inputs) { ctx = u[0]->ctx(); type_flag = u[0]->dtype(); for (int i = 0; i < OP::num_inputs; ++i) { CHECK_EQ(ctx, u[i]->ctx()) << "Context of all oprands must be the same."; CHECK_EQ(type_flag, u[i]->dtype()) << "Data type of all oprands must be the same."; } } else { CHECK(param.count("ctx")) << "Must provide keyword argument ctx for TorchOp with 0 inputs"; std::string str_ctx(param.at("ctx")); int id; char tmp[4]; sscanf(str_ctx.c_str(), "%3s(%d)", tmp, &id); std::string dev(tmp); if (dev == "cpu") { ctx = Context::Create(Context::kCPU, id); } else if (dev == "gpu") { ctx = Context::Create(Context::kGPU, id); } else { LOG(FATAL) << "Unknown device type " << dev; } if (param.count("dtype")) { std::stringstream str_dtype(param.at("dtype")); str_dtype >> type_flag; } else { type_flag = mshadow::default_type_flag; } } std::vector arr_in, arr_out; std::vector var_in, var_out, var_const; for (int i = 0; i < OP::num_inputs; ++i) { arr_in.push_back(*(u[i])); var_in.push_back(u[i]->var()); } for (int i = 0; i < OP::num_outputs; ++i) { if (out[i]->is_none()) { *(out[i]) = NDArray(shapes[i], ctx, false, type_flag); } arr_out.push_back(*(out[i])); var_out.push_back(out[i]->var()); } std::sort(var_in.begin(), var_in.end()); var_in.resize(std::unique(var_in.begin(), var_in.end()) - var_in.begin()); std::sort(var_out.begin(), var_out.end()); var_out.resize(std::unique(var_out.begin(), var_out.end()) - var_out.begin()); std::set_difference(var_in.begin(), var_in.end(), var_out.begin(), var_out.end(), std::inserter(var_const, var_const.begin())); switch (ctx.dev_mask()) { case mshadow::cpu::kDevMask: { Engine::Get()->PushSync( [arr_in, arr_out, param](RunContext rctx) { TorchRunOp(arr_in, arr_out, param, rctx); }, ctx, var_const, var_out); break; } #if MXNET_USE_CUDA case gpu::kDevMask: { Engine::Get()->PushSync( [arr_in, arr_out, param](RunContext rctx) { TorchRunOp(arr_in, arr_out, param, rctx); }, ctx, var_const, var_out); break; } #endif default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; } } struct TorchFirstShape { static std::vector GetShape(NDArray** u, const std::map& param) { return {u[0]->shape()}; } }; struct TorchConstructorShape { static std::vector GetShape(NDArray** u, const std::map& param) { std::vector shape; std::string format = param.at("format"); std::istringstream args(param.at("args")); std::string val; std::getline(args, val, ','); CHECK_LE(format.size(), 5) << "Only support up to 4 dimensions."; for (size_t i = 1; i < format.size(); ++i) { CHECK_EQ(format[i], 'i') << "Only take integer arguments."; std::getline(args, val, ','); shape.push_back(std::stoi(val)); } mshadow::TShape tshape(shape.begin(), shape.end()); return {tshape}; } static const int num_inputs = 0; static const int num_outputs = 1; }; #define MXNET_REGISTER_TORCH_FUN(name, OP) \ MXNET_REGISTER_NDARRAY_FUN(name) \ .set_function(TorchOp) \ .set_num_use_vars(OP::num_inputs) \ .set_num_mutate_vars(OP::num_outputs) \ .set_type_mask(kAcceptEmptyMutateTarget) #define MXNET_REGISTER_TORCH_UNARY_FUN(name, func) \ struct TorchUnaryOpDesc_##name##_##func : public TorchFirstShape { \ static constexpr const char* fname = #func; \ static const int num_inputs = 1; \ static const int num_outputs = 1; \ }; \ MXNET_REGISTER_TORCH_FUN(name, TorchUnaryOpDesc_##name##_##func) \ .add_argument("x", "NDArray", "Input NDArray") #define MXNET_REGISTER_TORCH_BINARY_FUN(name, func) \ struct TorchBinaryOpDesc_##name##_##func : public TorchFirstShape { \ static constexpr const char* fname = #func; \ static const int num_inputs = 2; \ static const int num_outputs = 1; \ }; \ MXNET_REGISTER_TORCH_FUN(name, TorchBinaryOpDesc_##name##_##func) #define MXNET_REGISTER_TORCH_BINARY_FUN_WITH_ARG(name, func) \ MXNET_REGISTER_TORCH_BINARY_FUN(name, func) \ .add_argument("x1", "NDArray", "First Input NDArray") \ .add_argument("x2", "NDArray", "Second Input NDArray") #define MXNET_REGISTER_TORCH_TENARY_FUN(name, func) \ struct TorchTenaryOpDesc_##name##_##func : public TorchFirstShape { \ static constexpr const char* fname = #func; \ static const int num_inputs = 3; \ static const int num_outputs = 1; \ }; \ MXNET_REGISTER_TORCH_FUN(name, TorchTenaryOpDesc_##name##_##func) #define MXNET_REGISTER_TORCH_CONSTRUCTOR_FUN(name, func) \ struct TorchConstructorOpDesc_##name##_##func : public TorchConstructorShape { \ static constexpr const char* fname = #func; \ }; \ MXNET_REGISTER_TORCH_FUN(name, TorchConstructorOpDesc_##name##_##func) } // namespace mxnet #endif // PLUGIN_TORCH_TORCH_FUNCTION_H_