2020-07-16 11:26:03 +08:00
|
|
|
import numpy as np
|
2020-11-16 02:50:50 +08:00
|
|
|
import pytest
|
2019-06-20 18:37:28 -07:00
|
|
|
|
2021-06-23 17:12:25 +08:00
|
|
|
import taichi as ti
|
2022-02-10 12:37:36 +08:00
|
|
|
from tests import test_utils
|
2021-06-23 17:12:25 +08:00
|
|
|
|
2019-12-16 07:58:06 -05:00
|
|
|
|
2020-11-16 02:50:50 +08:00
|
|
|
@pytest.mark.parametrize("val", [0, 1])
@test_utils.test(ti.cpu)
def test_static_if(val):
    """A ``ti.static`` condition selects exactly one branch at compile time.

    ``val`` is a plain Python int captured by the kernel's closure, so
    ``val > 0.5`` is evaluated while the kernel is compiled. The kernel
    writes 1 when the condition holds and 0 otherwise. For the parametrized
    values (0 and 1) that written value equals ``val`` itself.
    """
    flag = ti.field(ti.i32)
    ti.root.dense(ti.i, 1).place(flag)

    @ti.kernel
    def write_flag():
        # Only the branch chosen by the compile-time condition is emitted.
        if ti.static(val > 0.5):
            flag[0] = 1
        else:
            flag[0] = 0

    write_flag()
    assert flag[0] == val
|
2019-06-20 18:37:28 -07:00
|
|
|
|
2019-11-27 14:24:41 +08:00
|
|
|
|
2022-02-10 12:37:36 +08:00
|
|
|
@test_utils.test(ti.cpu)
def test_static_if_error():
    """``ti.static`` must reject a condition that depends on a kernel argument.

    The kernel is called with a runtime ``float`` argument, so the condition
    passed to ``ti.static`` is not a compile-time constant. Compilation is
    expected to fail with ``ti.TaichiCompilationError``.
    """
    flag = ti.field(ti.i32)
    ti.root.dense(ti.i, 1).place(flag)

    @ti.kernel
    def write_flag(val: float):
        # `val` is only known when the kernel runs, so this static
        # condition cannot be folded and should trigger the error.
        if ti.static(val > 0.5):
            flag[0] = 1
        else:
            flag[0] = 0

    expected_error = ti.TaichiCompilationError
    with pytest.raises(expected_error, match="must be compile-time constants"):
        write_flag(42)
|
2020-03-21 13:21:49 -04:00
|
|
|
|
|
|
|
|
|
2022-02-10 12:37:36 +08:00
|
|
|
@test_utils.test()
def test_static_ndrange():
    """``ti.static(ti.ndrange(...))`` unrolls a 2-D loop at compile time.

    Because the unrolled indices are Python constants, they can index a plain
    Python list (``weights``) and pick a single component of each matrix
    element inside the kernel. Entry ``[i, j]`` of the matrix stored at
    ``[i, j]`` should end up equal to ``i + 2 * j``.
    """
    n = 3
    mats = ti.Matrix.field(n, n, dtype=ti.f32, shape=(n, n))

    @ti.kernel
    def fill():
        weights = [0, 1, 2]
        for i, j in ti.static(ti.ndrange(n, n)):
            mats[i, j][i, j] = weights[i] + weights[j] * 2

    fill()

    # Walk the same (i, j) pairs in row-major order with a single flat index.
    for flat in range(n * n):
        i, j = divmod(flat, n)
        assert mats[i, j][i, j] == i + j * 2
|
2020-07-16 11:26:03 +08:00
|
|
|
|
|
|
|
|
|
2022-02-10 12:37:36 +08:00
|
|
|
@test_utils.test(ti.cpu)
def test_static_break():
    """``break`` inside a ``ti.static`` loop stops the compile-time unrolling.

    Iterations 0, 1 and 2 are emitted; the ``break`` taken when ``i == 2``
    drops the rest. The test therefore expects ones in the first three
    slots and zeros in the last two.
    """
    marks = ti.field(ti.i32, 5)

    @ti.kernel
    def func():
        for i in ti.static(range(5)):
            marks[i] = 1
            if ti.static(i == 2):
                break

    func()

    expected = np.array([1] * 3 + [0] * 2)
    assert np.allclose(marks.to_numpy(), expected)
|
|
|
|
|
|
|
|
|
|
|
2022-02-10 12:37:36 +08:00
|
|
|
@test_utils.test(ti.cpu)
def test_static_continue():
    """``continue`` inside a ``ti.static`` loop skips one unrolled iteration.

    The write is omitted only for ``i == 2``. The test therefore expects
    every slot except index 2 to be set to 1.
    """
    marks = ti.field(ti.i32, 5)

    @ti.kernel
    def func():
        for i in ti.static(range(5)):
            if ti.static(i == 2):
                continue
            marks[i] = 1

    func()

    expected = np.array([1, 1] + [0] + [1, 1])
    assert np.allclose(marks.to_numpy(), expected)
|