/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file operator.hpp * \brief implementation of data iter * \author Zhang Chen */ #ifndef MXNET_CPP_IO_HPP_ #define MXNET_CPP_IO_HPP_ #include #include #include "mxnet-cpp/io.h" namespace mxnet { namespace cpp { inline MXDataIterMap*& MXDataIter::mxdataiter_map() { static MXDataIterMap* mxdataiter_map_ = new MXDataIterMap; return mxdataiter_map_; } inline MXDataIter::MXDataIter(const std::string &mxdataiter_type) { creator_ = mxdataiter_map()->GetMXDataIterCreator(mxdataiter_type); blob_ptr_ = std::make_shared(nullptr); } inline void MXDataIter::BeforeFirst() { int r = MXDataIterBeforeFirst(blob_ptr_->handle_); CHECK_EQ(r, 0); } inline bool MXDataIter::Next() { int out; int r = MXDataIterNext(blob_ptr_->handle_, &out); CHECK_EQ(r, 0); return out; } inline NDArray MXDataIter::GetData() { NDArrayHandle handle; int r = MXDataIterGetData(blob_ptr_->handle_, &handle); CHECK_EQ(r, 0); return NDArray(handle); } inline NDArray MXDataIter::GetLabel() { NDArrayHandle handle; int r = MXDataIterGetLabel(blob_ptr_->handle_, &handle); CHECK_EQ(r, 0); return NDArray(handle); } inline int MXDataIter::GetPadNum() { int out; int r = MXDataIterGetPadNum(blob_ptr_->handle_, &out); CHECK_EQ(r, 0); return out; } inline std::vector MXDataIter::GetIndex() { uint64_t *out_index, out_size; int r = MXDataIterGetIndex(blob_ptr_->handle_, &out_index, &out_size); CHECK_EQ(r, 0); std::vector ret; for (uint64_t i = 0; i < out_size; ++i) { ret.push_back(out_index[i]); } return ret; } inline MXDataIter MXDataIter::CreateDataIter() { std::vector param_keys; std::vector param_values; for (auto &data : params_) { param_keys.push_back(data.first.c_str()); param_values.push_back(data.second.c_str()); } MXDataIterCreateIter(creator_, param_keys.size(), param_keys.data(), param_values.data(), &blob_ptr_->handle_); return *this; } // MXDataIter MNIst } // namespace cpp } // namespace mxnet #endif // MXNET_CPP_IO_HPP_