2021-12-07 05:59:16 +00:00
|
|
|
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

#include "taichi/ir/ir_builder.h"
#include "taichi/ir/statements.h"
#include "taichi/program/program.h"
|
|
|
|
|
|
2025-05-05 11:21:11 -04:00
|
|
|
int main() {
|
2021-12-07 05:59:16 +00:00
|
|
|
/*
|
|
|
|
|
import taichi as ti, numpy as np
|
|
|
|
|
ti.init()
|
|
|
|
|
#ti.init(print_ir = True)
|
|
|
|
|
|
|
|
|
|
n = 10
|
|
|
|
|
place = ti.field(dtype = ti.i32)
|
|
|
|
|
ti.root.pointer(ti.i, n).place(place)
|
|
|
|
|
|
|
|
|
|
@ti.kernel
|
|
|
|
|
def init():
|
|
|
|
|
for index in range(n):
|
|
|
|
|
place[index] = index
|
|
|
|
|
|
|
|
|
|
@ti.kernel
|
|
|
|
|
def ret() -> ti.i32:
|
|
|
|
|
sum = 0
|
|
|
|
|
for index in place:
|
|
|
|
|
sum = sum + place[index]
|
|
|
|
|
return sum
|
|
|
|
|
|
|
|
|
|
@ti.kernel
|
|
|
|
|
def ext(ext_arr: ti.ext_arr()):
|
|
|
|
|
for index in place:
|
|
|
|
|
ext_arr[index] = place[index]
|
|
|
|
|
|
|
|
|
|
init()
|
|
|
|
|
print(ret())
|
|
|
|
|
ext_arr = np.zeros(n, np.int32)
|
|
|
|
|
ext(ext_arr)
|
|
|
|
|
#ext_arr = place.to_numpy()
|
|
|
|
|
print(ext_arr)
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
using namespace taichi;
|
|
|
|
|
using namespace lang;
|
2025-05-05 11:21:11 -04:00
|
|
|
auto program = Program(host_arch());
|
2023-01-18 19:06:57 +08:00
|
|
|
const auto &config = program.compile_config();
|
2021-12-07 05:59:16 +00:00
|
|
|
/*CompileConfig config_print_ir;
|
|
|
|
|
config_print_ir.print_ir = true;
|
|
|
|
|
prog_.config = config_print_ir;*/ // print_ir = True
|
|
|
|
|
|
|
|
|
|
int n = 10;
|
|
|
|
|
program.materialize_runtime();
|
|
|
|
|
auto *root = new SNode(0, SNodeType::root);
|
2023-07-19 19:12:11 +08:00
|
|
|
auto *pointer = &root->pointer(Axis(0), n);
|
2021-12-07 05:59:16 +00:00
|
|
|
auto *place = &pointer->insert_children(SNodeType::place);
|
|
|
|
|
place->dt = PrimitiveType::i32;
|
|
|
|
|
program.add_snode_tree(std::unique_ptr<SNode>(root), /*compile_only=*/false);
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<Kernel> kernel_init, kernel_ret, kernel_ext;
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
/*
|
|
|
|
|
@ti.kernel
|
|
|
|
|
def init():
|
|
|
|
|
for index in range(n):
|
|
|
|
|
place[index] = index
|
|
|
|
|
*/
|
|
|
|
|
IRBuilder builder;
|
|
|
|
|
auto *zero = builder.get_int32(0);
|
|
|
|
|
auto *n_stmt = builder.get_int32(n);
|
2022-01-24 19:38:25 +08:00
|
|
|
auto *loop = builder.create_range_for(zero, n_stmt, 0, 4);
|
2021-12-07 05:59:16 +00:00
|
|
|
{
|
|
|
|
|
auto _ = builder.get_loop_guard(loop);
|
|
|
|
|
auto *index = builder.get_loop_index(loop);
|
|
|
|
|
auto *ptr = builder.create_global_ptr(place, {index});
|
|
|
|
|
builder.create_global_store(ptr, index);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
kernel_init =
|
|
|
|
|
std::make_unique<Kernel>(program, builder.extract_ir(), "init");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
/*
|
|
|
|
|
@ti.kernel
|
|
|
|
|
def ret():
|
|
|
|
|
sum = 0
|
|
|
|
|
for index in place:
|
|
|
|
|
sum = sum + place[index];
|
|
|
|
|
return sum
|
|
|
|
|
*/
|
|
|
|
|
IRBuilder builder;
|
|
|
|
|
auto *sum = builder.create_local_var(PrimitiveType::i32);
|
2022-01-24 19:38:25 +08:00
|
|
|
auto *loop = builder.create_struct_for(pointer, 0, 4);
|
2021-12-07 05:59:16 +00:00
|
|
|
{
|
|
|
|
|
auto _ = builder.get_loop_guard(loop);
|
|
|
|
|
auto *index = builder.get_loop_index(loop);
|
|
|
|
|
auto *sum_old = builder.create_local_load(sum);
|
|
|
|
|
auto *place_index =
|
|
|
|
|
builder.create_global_load(builder.create_global_ptr(place, {index}));
|
|
|
|
|
builder.create_local_store(sum, builder.create_add(sum_old, place_index));
|
|
|
|
|
}
|
|
|
|
|
builder.create_return(builder.create_local_load(sum));
|
|
|
|
|
|
|
|
|
|
kernel_ret = std::make_unique<Kernel>(program, builder.extract_ir(), "ret");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
/*
|
|
|
|
|
@ti.kernel
|
|
|
|
|
def ext(ext: ti.ext_arr()):
|
|
|
|
|
for index in place:
|
|
|
|
|
ext[index] = place[index];
|
|
|
|
|
# ext = place.to_numpy()
|
|
|
|
|
*/
|
|
|
|
|
IRBuilder builder;
|
2022-01-24 19:38:25 +08:00
|
|
|
auto *loop = builder.create_struct_for(pointer, 0, 4);
|
2021-12-07 05:59:16 +00:00
|
|
|
{
|
|
|
|
|
auto _ = builder.get_loop_guard(loop);
|
|
|
|
|
auto *index = builder.get_loop_index(loop);
|
|
|
|
|
auto *ext = builder.create_external_ptr(
|
2023-07-10 19:22:06 +08:00
|
|
|
builder.create_arg_load({0}, PrimitiveType::i32, true, 0), {index});
|
2021-12-07 05:59:16 +00:00
|
|
|
auto *place_index =
|
|
|
|
|
builder.create_global_load(builder.create_global_ptr(place, {index}));
|
|
|
|
|
builder.create_global_store(ext, place_index);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
kernel_ext = std::make_unique<Kernel>(program, builder.extract_ir(), "ext");
|
2023-01-13 15:08:47 +08:00
|
|
|
kernel_ext->insert_arr_param(get_data_type<int>(), /*total_dim=*/1, {n});
|
2023-02-24 15:14:06 +08:00
|
|
|
kernel_ext->finalize_params();
|
2021-12-07 05:59:16 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto ctx_init = kernel_init->make_launch_context();
|
|
|
|
|
auto ctx_ret = kernel_ret->make_launch_context();
|
|
|
|
|
auto ctx_ext = kernel_ext->make_launch_context();
|
|
|
|
|
std::vector<int> ext_arr(n);
|
2023-07-10 19:22:00 +08:00
|
|
|
ctx_ext.set_arg_external_array_with_shape({0}, taichi::uint64(ext_arr.data()),
|
2022-07-29 18:14:54 +08:00
|
|
|
n, {n});
|
2021-12-07 05:59:16 +00:00
|
|
|
|
2023-04-21 13:02:26 +08:00
|
|
|
{
|
|
|
|
|
const auto &compiled_kernel_data =
|
|
|
|
|
program.compile_kernel(config, program.get_device_caps(), *kernel_init);
|
|
|
|
|
program.launch_kernel(compiled_kernel_data, ctx_init);
|
|
|
|
|
}
|
|
|
|
|
{
|
|
|
|
|
const auto &compiled_kernel_data =
|
|
|
|
|
program.compile_kernel(config, program.get_device_caps(), *kernel_ret);
|
|
|
|
|
program.launch_kernel(compiled_kernel_data, ctx_ret);
|
|
|
|
|
std::cout << program.fetch_result<int>(0) << std::endl;
|
|
|
|
|
}
|
|
|
|
|
{
|
|
|
|
|
const auto &compiled_kernel_data =
|
|
|
|
|
program.compile_kernel(config, program.get_device_caps(), *kernel_ext);
|
|
|
|
|
program.launch_kernel(compiled_kernel_data, ctx_ext);
|
|
|
|
|
for (int i = 0; i < n; i++)
|
|
|
|
|
std::cout << ext_arr[i] << " ";
|
|
|
|
|
std::cout << std::endl;
|
|
|
|
|
}
|
2021-12-07 05:59:16 +00:00
|
|
|
}
|