#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

#include "taichi/ir/ir_builder.h"
#include "taichi/ir/statements.h"
#include "taichi/program/program.h"

// Builds, by hand with IRBuilder, the IR for the following Taichi program and
// runs it on the host architecture (host_arch()):
//
//   import taichi as ti, numpy as np
//   ti.init()
//   #ti.init(print_ir = True)
//
//   n = 10
//   place = ti.field(dtype = ti.i32)
//   ti.root.pointer(ti.i, n).place(place)
//
//   @ti.kernel
//   def init():
//     for index in range(n):
//       place[index] = index
//
//   @ti.kernel
//   def ret() -> ti.i32:
//     sum = 0
//     for index in place:
//       sum = sum + place[index]
//     return sum
//
//   @ti.kernel
//   def ext(ext_arr: ti.ext_arr()):
//     for index in place:
//       ext_arr[index] = place[index]
//
//   init()
//   print(ret())
//   ext_arr = np.zeros(n, np.int32)
//   ext(ext_arr)
//   #ext_arr = place.to_numpy()
//   print(ext_arr)
//
// Output: the sum returned by `ret`, then the n values copied out by `ext`.
int main() {
  using namespace taichi;
  using namespace lang;

  auto program = Program(host_arch());
  const auto &config = program.compile_config();
  // To dump the IR (Python: ti.init(print_ir = True)), enable `print_ir` on
  // the CompileConfig used for compilation.
  // NOTE(review): an earlier version of this snippet assigned to
  // `prog_.config`, which is not a name in this scope — confirm the current
  // API before re-enabling it.

  const int n = 10;
  program.materialize_runtime();

  // ti.root.pointer(ti.i, n).place(place)
  // The tree is built through an owning pointer; the raw `pointer` / `place`
  // handles stay valid after ownership moves into the program.
  auto root = std::make_unique<SNode>(0, SNodeType::root);
  auto *pointer = &root->pointer(Axis(0), n);
  auto *place = &pointer->insert_children(SNodeType::place);
  place->dt = PrimitiveType::i32;
  program.add_snode_tree(std::move(root), /*compile_only=*/false);

  std::unique_ptr<Kernel> kernel_init;
  std::unique_ptr<Kernel> kernel_ret;
  std::unique_ptr<Kernel> kernel_ext;

  {
    // @ti.kernel
    // def init():
    //   for index in range(n):
    //     place[index] = index
    IRBuilder builder;
    auto *zero = builder.get_int32(0);
    auto *n_stmt = builder.get_int32(n);
    auto *loop = builder.create_range_for(zero, n_stmt, 0, 4);
    {
      auto _ = builder.get_loop_guard(loop);
      auto *index = builder.get_loop_index(loop);
      auto *ptr = builder.create_global_ptr(place, {index});
      builder.create_global_store(ptr, index);
    }
    kernel_init =
        std::make_unique<Kernel>(program, builder.extract_ir(), "init");
  }

  {
    // @ti.kernel
    // def ret() -> ti.i32:
    //   sum = 0
    //   for index in place:
    //     sum = sum + place[index]
    //   return sum
    IRBuilder builder;
    auto *sum = builder.create_local_var(PrimitiveType::i32);
    auto *loop = builder.create_struct_for(pointer, 0, 4);
    {
      auto _ = builder.get_loop_guard(loop);
      auto *index = builder.get_loop_index(loop);
      auto *sum_old = builder.create_local_load(sum);
      auto *place_index =
          builder.create_global_load(builder.create_global_ptr(place, {index}));
      builder.create_local_store(sum, builder.create_add(sum_old, place_index));
    }
    builder.create_return(builder.create_local_load(sum));
    kernel_ret = std::make_unique<Kernel>(program, builder.extract_ir(), "ret");
  }

  {
    // @ti.kernel
    // def ext(ext_arr: ti.ext_arr()):
    //   for index in place:
    //     ext_arr[index] = place[index]
    // (Python alternative: ext_arr = place.to_numpy())
    IRBuilder builder;
    auto *loop = builder.create_struct_for(pointer, 0, 4);
    {
      auto _ = builder.get_loop_guard(loop);
      auto *index = builder.get_loop_index(loop);
      // Argument 0 is the external array; its element pointer is indexed by
      // the loop index.
      auto *ext = builder.create_external_ptr(
          builder.create_arg_load({0}, PrimitiveType::i32, true, 0), {index});
      auto *place_index =
          builder.create_global_load(builder.create_global_ptr(place, {index}));
      builder.create_global_store(ext, place_index);
    }
    kernel_ext = std::make_unique<Kernel>(program, builder.extract_ir(), "ext");
    kernel_ext->insert_arr_param(get_data_type<int>(), /*total_dim=*/1, {n});
    kernel_ext->finalize_params();
  }

  auto ctx_init = kernel_init->make_launch_context();
  auto ctx_ret = kernel_ret->make_launch_context();
  auto ctx_ext = kernel_ext->make_launch_context();

  // Host buffer that `ext` writes into (Python: np.zeros(n, np.int32)).
  std::vector<int> ext_arr(n);
  // The size argument is the buffer size in bytes, matching what the Python
  // frontend passes (the ndarray's `nbytes`); the element count `n` would
  // under-report it by a factor of sizeof(int).
  ctx_ext.set_arg_external_array_with_shape(
      {0}, reinterpret_cast<taichi::uint64>(ext_arr.data()),
      ext_arr.size() * sizeof(int), {n});

  {
    // init()
    const auto &compiled_kernel_data = program.compile_kernel(
        config, program.get_device_caps(), *kernel_init);
    program.launch_kernel(compiled_kernel_data, ctx_init);
  }

  {
    // print(ret())
    const auto &compiled_kernel_data = program.compile_kernel(
        config, program.get_device_caps(), *kernel_ret);
    program.launch_kernel(compiled_kernel_data, ctx_ret);
    std::cout << program.fetch_result<int>(0) << std::endl;
  }

  {
    // ext(ext_arr); print(ext_arr)
    const auto &compiled_kernel_data = program.compile_kernel(
        config, program.get_device_caps(), *kernel_ext);
    program.launch_kernel(compiled_kernel_data, ctx_ext);
    for (const int value : ext_arr) {
      std::cout << value << " ";
    }
    std::cout << std::endl;
  }
}