2022-05-20 08:27:28 +08:00
|
|
|
#include "tests/cpp/ir/ndarray_kernel.h"
|
|
|
|
|
|
2022-09-24 12:16:53 +08:00
|
|
|
namespace taichi::lang {
|
2022-05-20 08:27:28 +08:00
|
|
|
|
|
|
|
|
std::unique_ptr<Kernel> setup_kernel1(Program *prog) {
|
|
|
|
|
IRBuilder builder1;
|
|
|
|
|
{
|
2023-07-10 19:22:06 +08:00
|
|
|
auto *arg = builder1.create_ndarray_arg_load(
|
|
|
|
|
/*arg_id=*/{0}, get_data_type<int>(), 1, 0);
|
2022-05-20 08:27:28 +08:00
|
|
|
auto *zero = builder1.get_int32(0);
|
|
|
|
|
auto *one = builder1.get_int32(1);
|
|
|
|
|
auto *two = builder1.get_int32(2);
|
|
|
|
|
auto *a1ptr = builder1.create_external_ptr(arg, {one});
|
|
|
|
|
builder1.create_global_store(a1ptr, one); // a[1] = 1
|
|
|
|
|
auto *a0 =
|
|
|
|
|
builder1.create_global_load(builder1.create_external_ptr(arg, {zero}));
|
|
|
|
|
auto *a2ptr = builder1.create_external_ptr(arg, {two});
|
|
|
|
|
auto *a2 = builder1.create_global_load(a2ptr);
|
|
|
|
|
auto *a0plusa2 = builder1.create_add(a0, a2);
|
|
|
|
|
builder1.create_global_store(a2ptr, a0plusa2); // a[2] = a[0] + a[2]
|
|
|
|
|
}
|
|
|
|
|
auto block = builder1.extract_ir();
|
|
|
|
|
auto ker1 = std::make_unique<Kernel>(*prog, std::move(block), "ker1");
|
2023-05-25 17:01:34 +08:00
|
|
|
ker1->insert_ndarray_param(get_data_type<int>(), /*total_dim=*/1);
|
2023-02-24 15:14:06 +08:00
|
|
|
ker1->finalize_params();
|
2023-04-06 11:17:56 +08:00
|
|
|
ker1->finalize_rets();
|
2022-05-20 08:27:28 +08:00
|
|
|
return ker1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<Kernel> setup_kernel2(Program *prog) {
|
|
|
|
|
IRBuilder builder2;
|
|
|
|
|
|
|
|
|
|
{
|
2023-07-10 19:22:06 +08:00
|
|
|
auto *arg0 = builder2.create_ndarray_arg_load(
|
|
|
|
|
/*arg_id=*/{0}, get_data_type<int>(), 1, 0);
|
2023-07-10 19:22:00 +08:00
|
|
|
auto *arg1 = builder2.create_arg_load(/*arg_id=*/{1}, get_data_type<int>(),
|
2023-07-10 19:22:06 +08:00
|
|
|
/*is_ptr=*/false, /*arg_depth=*/0);
|
2022-05-20 08:27:28 +08:00
|
|
|
auto *one = builder2.get_int32(1);
|
|
|
|
|
auto *a1ptr = builder2.create_external_ptr(arg0, {one});
|
|
|
|
|
builder2.create_global_store(a1ptr, arg1); // a[1] = arg1
|
|
|
|
|
}
|
|
|
|
|
auto block2 = builder2.extract_ir();
|
|
|
|
|
auto ker2 = std::make_unique<Kernel>(*prog, std::move(block2), "ker2");
|
2023-05-25 17:01:34 +08:00
|
|
|
ker2->insert_ndarray_param(get_data_type<int>(), /*total_dim=*/1);
|
2023-01-13 15:08:47 +08:00
|
|
|
ker2->insert_scalar_param(get_data_type<int>());
|
2023-02-24 15:14:06 +08:00
|
|
|
ker2->finalize_params();
|
2023-04-06 11:17:56 +08:00
|
|
|
ker2->finalize_rets();
|
2022-05-20 08:27:28 +08:00
|
|
|
return ker2;
|
|
|
|
|
}
|
2022-09-24 12:16:53 +08:00
|
|
|
} // namespace taichi::lang
|