/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file torch_base.cc
 * \brief Registers Torch tensor functions under `_th_*` names using the
 *        MXNET_REGISTER_TORCH_* macros from torch_function.h.
 * \author Junyuan Xie
 */
#include "./torch_function.h"

namespace mxnet {

// Each registration below pairs an MXNet-side name (first macro argument,
// prefixed `_th_`) with a second identifier that is presumably the wrapped
// Torch function name. NOTE(review): the exact wrapping behaviour of each
// macro kind (CONSTRUCTOR / UNARY / BINARY / TENARY) lives in
// torch_function.h and is not visible from this file.

// Construction or extraction functions
MXNET_REGISTER_TORCH_CONSTRUCTOR_FUN(_th_eye, eye);
MXNET_REGISTER_TORCH_CONSTRUCTOR_FUN(_th_ones, ones);
MXNET_REGISTER_TORCH_CONSTRUCTOR_FUN(_th_rand, rand);
MXNET_REGISTER_TORCH_CONSTRUCTOR_FUN(_th_randn, randn);
MXNET_REGISTER_TORCH_CONSTRUCTOR_FUN(_th_randperm, randperm);
// tril/triu use the UNARY macro rather than the CONSTRUCTOR one, presumably
// because they extract from an existing input array instead of creating one.
MXNET_REGISTER_TORCH_UNARY_FUN(_th_tril, tril);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_triu, triu);
MXNET_REGISTER_TORCH_CONSTRUCTOR_FUN(_th_zeros, zeros);

// Element-wise Mathematical Operations
MXNET_REGISTER_TORCH_UNARY_FUN(_th_abs, abs);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_sign, sign);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_acos, acos);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_asin, asin);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_atan, atan);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_ceil, ceil);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_cos, cos);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_cosh, cosh);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_exp, exp);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_floor, floor);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_log, log);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_log1p, log1p);
MXNET_REGISTER_TORCH_UNARY_FUN(_th_pow, pow) .add_argument("n", "float", "pow(x, n) returns x^n, element-wise. " "pow(n, x) returns n^x, element-wise."); MXNET_REGISTER_TORCH_UNARY_FUN(_th_round, round); MXNET_REGISTER_TORCH_UNARY_FUN(_th_sin, sin); MXNET_REGISTER_TORCH_UNARY_FUN(_th_sinh, sinh); MXNET_REGISTER_TORCH_UNARY_FUN(_th_sqrt, sqrt); MXNET_REGISTER_TORCH_UNARY_FUN(_th_tan, tan); MXNET_REGISTER_TORCH_UNARY_FUN(_th_tanh, tanh); // Basic operations MXNET_REGISTER_TORCH_UNARY_FUN(_th_add_scalar, add) .add_argument("value", "float", "Add value to all elements in x"); MXNET_REGISTER_TORCH_BINARY_FUN_WITH_ARG(_th_add, add); MXNET_REGISTER_TORCH_BINARY_FUN(_th_add_axpy, add); // MXNET_REGISTER_TORCH_UNARY_FUN(_th_csub_scalar, csub); // MXNET_REGISTER_TORCH_BINARY_FUN_WITH_ARG(_th_csub, csub); MXNET_REGISTER_TORCH_UNARY_FUN(_th_mul_scalar, mul) .add_argument("value", "float", "Multiply value to all elements in x"); MXNET_REGISTER_TORCH_BINARY_FUN_WITH_ARG(_th_cmul, cmul); MXNET_REGISTER_TORCH_UNARY_FUN(_th_clamp, clamp); MXNET_REGISTER_TORCH_BINARY_FUN_WITH_ARG(_th_cpow, cpow); MXNET_REGISTER_TORCH_TENARY_FUN(_th_addcmul, addcmul); MXNET_REGISTER_TORCH_UNARY_FUN(_th_div_scalar, div) .add_argument("value", "float", "Divide all elements in x by value"); MXNET_REGISTER_TORCH_BINARY_FUN_WITH_ARG(_th_cdiv, cdiv); MXNET_REGISTER_TORCH_TENARY_FUN(_th_addcdiv, addcdiv); MXNET_REGISTER_TORCH_TENARY_FUN(_th_addmv, addmv); MXNET_REGISTER_TORCH_TENARY_FUN(_th_addr, addr); MXNET_REGISTER_TORCH_TENARY_FUN(_th_addmm, addmm); MXNET_REGISTER_TORCH_TENARY_FUN(_th_addbmm, addbmm); MXNET_REGISTER_TORCH_TENARY_FUN(_th_baddbmm, baddbmm); struct TorchMMShape { static std::vector GetShape(NDArray** u, const std::map& param) { CHECK_EQ(u[0]->shape().ndim(), 2); CHECK_EQ(u[1]->shape().ndim(), 2); CHECK_EQ(u[0]->shape()[1], u[1]->shape()[0]); index_t shape[] = {u[0]->shape()[0], u[1]->shape()[1]}; mshadow::TShape tshape(shape, shape + 2); return {tshape}; } static constexpr const char* 
fname = "mm"; static const int num_inputs = 2; static const int num_outputs = 1; }; MXNET_REGISTER_TORCH_FUN(_th_mm, TorchMMShape); struct TorchMVShape { static std::vector GetShape(NDArray** u, const std::map& param) { CHECK_EQ(u[0]->shape().ndim(), 2); CHECK_EQ(u[1]->shape().ndim(), 1); CHECK_EQ(u[0]->shape()[1], u[1]->shape()[0]); index_t shape[] = {u[0]->shape()[0]}; mshadow::TShape tshape(shape, shape + 1); return {tshape}; } static constexpr const char* fname = "mv"; static const int num_inputs = 2; static const int num_outputs = 1; }; MXNET_REGISTER_TORCH_FUN(_th_mv, TorchMVShape); struct TorchBMMShape { static std::vector GetShape(NDArray** u, const std::map& param) { CHECK_EQ(u[0]->shape().ndim(), 3); CHECK_EQ(u[1]->shape().ndim(), 3); CHECK_EQ(u[0]->shape()[0], u[1]->shape()[0]); CHECK_EQ(u[0]->shape()[2], u[1]->shape()[1]); index_t shape[] = {u[0]->shape()[1], u[1]->shape()[2]}; mshadow::TShape tshape(shape, shape + 2); return {tshape}; } static constexpr const char* fname = "bmm"; static const int num_inputs = 2; static const int num_outputs = 1; }; MXNET_REGISTER_TORCH_FUN(_th_bmm, TorchBMMShape); struct TorchGERShape { static std::vector GetShape(NDArray** u, const std::map& param) { CHECK_EQ(u[0]->shape().ndim(), 1); CHECK_EQ(u[1]->shape().ndim(), 1); index_t shape[] = {u[0]->shape()[0], u[1]->shape()[0]}; mshadow::TShape tshape(shape, shape + 2); return {tshape}; } static constexpr const char* fname = "ger"; static const int num_inputs = 2; static const int num_outputs = 1; }; MXNET_REGISTER_TORCH_FUN(_th_ger, TorchGERShape); } // namespace mxnet