2021-06-23 17:12:25 +08:00
|
|
|
import taichi as ti
|
2022-02-10 12:37:36 +08:00
|
|
|
from tests import test_utils
|
2021-06-23 17:12:25 +08:00
|
|
|
|
2020-11-12 20:32:56 +08:00
|
|
|
|
2022-02-10 12:37:36 +08:00
|
|
|
@test_utils.test(require=ti.extension.quant, debug=True)
def test_1D_quant_array():
    """Round-trip alternating 0/1 values through a 1D array of 1-bit ints.

    Thirty-two unsigned 1-bit elements are placed in a ``quant_array``
    created with ``max_num_bits=32``. One kernel writes ``k % 2`` to every
    slot and a second kernel reads each slot back. ``debug=True`` is passed,
    presumably so that the in-kernel ``assert`` is enforced.
    """
    n = 32
    u1 = ti.types.quant.int(1, False)  # 1-bit, unsigned
    bits = ti.field(dtype=u1)
    ti.root.quant_array(ti.i, n, max_num_bits=32).place(bits)

    @ti.kernel
    def write_parity():
        for k in range(n):
            bits[k] = k % 2

    @ti.kernel
    def check_parity():
        for k in range(n):
            assert bits[k] == k % 2

    write_parity()
    check_parity()
|
|
|
|
|
|
|
|
|
|
|
2022-06-24 08:55:49 +08:00
|
|
|
@test_utils.test(require=ti.extension.quant, debug=True)
def test_1D_quant_array_negative():
    """Store negative values in a 1D quant_array of 7-bit integers.

    ``quant.int(7)`` is used without the ``False`` flag that the unsigned
    tests pass, so the elements can hold negative numbers. Each element must
    read as zero before it is written. It is then set to its negated index and
    read back within the same kernel iteration.
    """
    size = 4
    i7 = ti.types.quant.int(7)
    vals = ti.field(dtype=i7)
    ti.root.quant_array(ti.i, size, max_num_bits=32).place(vals)

    @ti.kernel
    def negate_and_check():
        for k in range(size):
            # Freshly allocated storage is expected to be zeroed.
            assert vals[k] == 0
            vals[k] = -k
            assert vals[k] == -k

    negate_and_check()
|
|
|
|
|
|
|
|
|
|
|
2022-07-11 17:43:30 +08:00
|
|
|
@test_utils.test(require=ti.extension.quant, debug=True)
def test_1D_quant_array_fixed():
    """Round-trip multiples of 0.5 through a 1D quant_array of fixed-point values.

    The element type is an 8-bit quantized fixed-point number built with
    ``max_value=2``. Four elements receive ``k * 0.5`` (0.0, 0.5, 1.0, 1.5)
    and are compared back with exact equality.
    """
    # NOTE(review): the exact ``==`` check assumes these halves are
    # representable without rounding at this precision — confirm against
    # the fixed-point scale Taichi derives from bits/max_value.
    count = 4
    fx8 = ti.types.quant.fixed(bits=8, max_value=2)
    arr = ti.field(dtype=fx8)
    ti.root.quant_array(ti.i, count, max_num_bits=32).place(arr)

    @ti.kernel
    def fill_halves():
        for k in range(count):
            arr[k] = k * 0.5

    @ti.kernel
    def check_halves():
        for k in range(count):
            assert arr[k] == k * 0.5

    fill_halves()
    check_halves()
|
|
|
|
|
|
|
|
|
|
|
2022-02-10 12:37:36 +08:00
|
|
|
@test_utils.test(require=ti.extension.quant, debug=True)
def test_2D_quant_array():
    """Round-trip a parity pattern through a 2D quant_array of 1-bit ints.

    A 4 x 8 grid of unsigned 1-bit elements is placed with
    ``max_num_bits=32``. Element ``(r, c)`` holds the parity of its
    row-major flat index ``r * cols + c``.
    """
    rows, cols = 4, 8
    u1 = ti.types.quant.int(1, False)  # 1-bit, unsigned
    grid = ti.field(dtype=u1)
    ti.root.quant_array(ti.ij, (rows, cols), max_num_bits=32).place(grid)

    @ti.kernel
    def write_parity():
        for r in range(rows):
            for c in range(cols):
                grid[r, c] = (r * cols + c) % 2

    @ti.kernel
    def check_parity():
        for r in range(rows):
            for c in range(cols):
                assert grid[r, c] == (r * cols + c) % 2

    write_parity()
    check_parity()
|
2022-06-27 17:33:11 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@test_utils.test(require=ti.extension.quant, debug=True)
def test_quant_array_struct_for():
    """Struct-for over a quant_array nested under a sparse ``pointer`` SNode.

    Layout: ``pointer`` (N // block_size cells) -> ``dense`` (block_size // 4)
    -> ``quant_array`` (4 elements, ``max_num_bits=32``) holding 7-bit ints.
    Only even-numbered blocks of ``block_size`` elements are written, which
    activates them. A struct-for then decrements every element it visits.
    Afterwards, written elements must equal ``index - 1`` and unwritten ones
    must still read as 0.
    """
    block_size = 16
    N = 64
    blocks = ti.root.pointer(ti.i, N // block_size)
    i7 = ti.types.quant.int(7)
    x = ti.field(dtype=i7)
    blocks.dense(ti.i, block_size // 4).quant_array(ti.i, 4, max_num_bits=32).place(x)

    @ti.kernel
    def activate():
        for k in range(N):
            in_even_block = (k // block_size) % 2 == 0
            if in_even_block:
                x[k] = k

    @ti.kernel
    def decrement_active():
        # Struct-for: iterates only over the active elements of ``x``.
        for k in x:
            x[k] -= 1

    @ti.kernel
    def verify():
        for k in range(N):
            in_even_block = (k // block_size) % 2 == 0
            if in_even_block:
                assert x[k] == k - 1
            else:
                assert x[k] == 0

    activate()
    decrement_active()
    verify()
|