/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * Copyright (c) 2017 by Contributors * \file monitor.h * \brief monitor definition * \author Xin Li */ #ifndef MXNET_CPP_MONITOR_H_ #define MXNET_CPP_MONITOR_H_ #include #include #include #include #include #include #include #include "mxnet-cpp/base.h" #include "mxnet-cpp/ndarray.h" #include "mxnet-cpp/executor.h" namespace mxnet { namespace cpp { /*! * \brief Default function for monitor that computes statistics of the input tensor, * which is the mean absolute |x|/size(x) * \param x The input tensor * \return The statistics of the input tensor */ NDArray _default_monitor_func(const NDArray &x); /*! * \brief Monitor interface */ class Monitor { public: typedef std::function StatFunc; typedef std::tuple Stat; /*! * \brief Monitor constructor * \param interval Number of batches between printing. * \param pattern A regular expression specifying which tensors to monitor. * \param stat_func A function that computes statistics of tensors. Defaults to mean * absolute value |x|/size(x). */ Monitor(int interval, std::regex pattern = std::regex(".*"), StatFunc stat_func = _default_monitor_func); /*! * \brief install callback to executor. Supports installing to multiple executors. * \param exe The executor to install to. */ void install(Executor *exe); /*! * \brief Start collecting stats for current batch. Call before calling forward. */ void tic(); /*! * \brief End collecting for current batch and return results. Call after computation * of current batch. */ std::vector toc(); /*! * \brief End collecting and print results. */ void toc_print(); protected: int interval; std::regex pattern; StatFunc stat_func; std::vector exes; int step; bool activated; std::vector stats; static void executor_callback(const char *name, NDArrayHandle ndarray, void *monitor_ptr); }; } // namespace cpp } // namespace mxnet #endif // MXNET_CPP_MONITOR_H_