/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * Copyright (c) 2017 by Contributors * \file threaded_engine_test.cc * \brief threaded engine tests */ #include #include #include #include #include #include #include #include #include #include #include #include "../src/engine/engine_impl.h" #include "../include/test_util.h" /** * present the following workload * n = reads.size() * data[write] = (data[reads[0]] + ... data[reads[n]]) / n * std::this_thread::sleep_for(std::chrono::microsecons(time)); */ struct Workload { std::vector reads; int write; int time; }; static uint32_t seed_ = 0xdeadbeef; /** * generate a list of workloads */ void GenerateWorkload(int num_workloads, int num_var, int min_read, int max_read, int min_time, int max_time, std::vector* workloads) { workloads->clear(); workloads->resize(num_workloads); for (int i = 0; i < num_workloads; ++i) { auto& wl = workloads->at(i); wl.write = rand_r(&seed_) % num_var; int r = rand_r(&seed_); int num_read = min_read + (r % (max_read - min_read)); for (int j = 0; j < num_read; ++j) { wl.reads.push_back(rand_r(&seed_) % num_var); } wl.time = min_time + rand_r(&seed_) % (max_time - min_time); } } /** * evaluate a single workload */ void EvaluateWorload(const Workload& wl, std::vector* data) { double tmp = 0; for (int i : wl.reads) tmp += data->at(i); data->at(wl.write) = tmp / (wl.reads.size() + 1); if (wl.time > 0) { std::this_thread::sleep_for(std::chrono::microseconds(wl.time)); } } /** * evaluate a list of workload, return the time used */ double EvaluateWorloads(const std::vector& workloads, mxnet::Engine* engine, std::vector* data) { using namespace mxnet; double t = dmlc::GetTime(); std::vector vars; if (engine) { for (size_t i = 0; i < data->size(); ++i) { vars.push_back(engine->NewVariable()); } } for (const auto& wl : workloads) { if (wl.reads.size() == 0) continue; if (engine == NULL) { EvaluateWorload(wl, data); } else { auto func = [wl, data](RunContext ctx, Engine::CallbackOnComplete cb) { EvaluateWorload(wl, data); cb(); }; std::vector reads; for (auto i : wl.reads) { if (i != wl.write) reads.push_back(vars[i]); } engine->PushAsync(func, Context::CPU(), reads, {vars[wl.write]}); } } if (engine) { engine->WaitForAll(); } return dmlc::GetTime() - t; } TEST(Engine, RandSumExpr) { std::vector workloads; int num_repeat = 5; const int num_engine = 4; std::vector t(num_engine, 0.0); std::vector engine(num_engine); engine[0] = NULL; engine[1] = mxnet::engine::CreateNaiveEngine(); engine[2] = mxnet::engine::CreateThreadedEnginePooled(); engine[3] = mxnet::engine::CreateThreadedEnginePerDevice(); for (int repeat = 0; repeat < num_repeat; ++repeat) { srand(time(NULL) + repeat); int num_var = 100; GenerateWorkload(10000, num_var, 2, 20, 1, 10, &workloads); std::vector> data(num_engine); for (int i = 0; i < num_engine; ++i) { data[i].resize(num_var, 1.0); t[i] += EvaluateWorloads(workloads, engine[i], &data[i]); } for (int i = 1; i < num_engine; ++i) { for (int j = 0; j < num_var; ++j) EXPECT_EQ(data[0][j], data[i][j]); } LOG(INFO) << "data: " << data[0][1] << " " << data[0][2] << "..."; } LOG(INFO) << "baseline\t\t" << t[0] << " sec"; LOG(INFO) << "NaiveEngine\t\t" << t[1] << " sec"; LOG(INFO) << "ThreadedEnginePooled\t" << t[2] << " sec"; LOG(INFO) << "ThreadedEnginePerDevice\t" << t[3] << " sec"; } void Foo(mxnet::RunContext, int i) { printf("The fox says %d\n", i); } TEST(Engine, basics) { auto&& engine = mxnet::Engine::Get(); auto&& var = engine->NewVariable(); std::vector oprs; // Test #1 printf("============= Test #1 ==============\n"); for (int i = 0; i < 10; ++i) { oprs.push_back(engine->NewOperator( [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) { Foo(ctx, i); std::this_thread::sleep_for(std::chrono::seconds{1}); cb(); }, {var}, {})); engine->Push(oprs.at(i), mxnet::Context{}); } engine->WaitForAll(); printf("Going to push delete\n"); // std::this_thread::sleep_for(std::chrono::seconds{1}); for (auto&& i : oprs) { engine->DeleteOperator(i); } engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var); engine->WaitForAll(); printf("============= Test #2 ==============\n"); var = engine->NewVariable(); oprs.clear(); for (int i = 0; i < 10; ++i) { oprs.push_back(engine->NewOperator( [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) { Foo(ctx, i); std::this_thread::sleep_for(std::chrono::milliseconds{500}); cb(); }, {}, {var})); engine->Push(oprs.at(i), mxnet::Context{}); } // std::this_thread::sleep_for(std::chrono::seconds{1}); engine->WaitForAll(); for (auto&& i : oprs) { engine->DeleteOperator(i); } engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var); printf("============= Test #3 ==============\n"); var = engine->NewVariable(); oprs.clear(); engine->WaitForVar(var); engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var); engine->WaitForAll(); printf("============= Test #4 ==============\n"); var = engine->NewVariable(); oprs.clear(); oprs.push_back(engine->NewOperator( [](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) { std::this_thread::sleep_for(std::chrono::seconds{2}); Foo(ctx, 42); cb(); }, {}, {var}, mxnet::FnProperty::kCopyFromGPU)); engine->Push(oprs.at(0), mxnet::Context{}); LOG(INFO) << "IO operator pushed, should wait for 2 seconds."; engine->WaitForVar(var); LOG(INFO) << "OK, here I am."; for (auto&& i : oprs) { engine->DeleteOperator(i); } engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var); engine->WaitForAll(); printf("============= Test #5 ==============\n"); var = engine->NewVariable(); oprs.clear(); oprs.push_back(engine->NewOperator( [](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) { Foo(ctx, 42); std::this_thread::sleep_for(std::chrono::seconds{2}); cb(); }, {var}, {})); engine->Push(oprs.at(0), mxnet::Context{}); LOG(INFO) << "Operator pushed, should not wait."; engine->WaitForVar(var); LOG(INFO) << "OK, here I am."; engine->WaitForAll(); LOG(INFO) << "That was 2 seconds."; for (auto&& i : oprs) { engine->DeleteOperator(i); } engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var); engine->WaitForAll(); var = nullptr; oprs.clear(); LOG(INFO) << "All pass"; } #ifdef _OPENMP struct TestSaveAndRestoreOMPState { TestSaveAndRestoreOMPState() { omp_set_dynamic(false); } ~TestSaveAndRestoreOMPState() { omp_set_num_threads(nthreads_); omp_set_dynamic(dynamic_); } const int nthreads_ = omp_get_max_threads(); const int dynamic_ = omp_get_dynamic(); }; /*! * \brief This test checks that omp_set_num_threads implementation has thread-scope */ TEST(Engine, omp_threading_count_scope) { TestSaveAndRestoreOMPState omp_state; const int THREAD_COUNT = 10; std::shared_ptr ready = std::make_shared(); std::shared_ptr threads = std::make_shared(); std::atomic counter(0), correct(0); omp_set_dynamic(0); for (int x = 0; x < THREAD_COUNT; ++x) { std::string name = "thread: "; name += std::to_string(x + 1); ++counter; threads->create(name, false, [x, &counter, &correct](std::shared_ptr ready_ptr) -> int { const int thread_count = x + 1; omp_set_num_threads(thread_count); --counter; ready_ptr->wait(); CHECK_EQ(omp_get_max_threads(), thread_count); #pragma omp parallel for for (int i = 0; i < 100; ++i) { if (i == 50) { const int current_threads = omp_get_num_threads(); if (current_threads == thread_count) { ++correct; } } } return 0; }, ready); } while (counter.load() > 0) { usleep(100); } ready->signal(); threads->join_all(); GTEST_ASSERT_EQ(correct.load(), THREAD_COUNT); } #endif // _OPENMP