代码中我们常会使用continue跳过一些异常情况,如下的结构中else
往往都会被省略。
if ...: continue
# else:
但我debug了一晚上发现,在taichi kernel中省略此处的else
有时会导致奇怪的bug出现。。
一开始我是发现在kernel中加入完全没使用到的print()
函数之后,bug就消失了。。甚至什么都不print都行。。排除自己的问题之后,最后怀疑是不是编译器出了什么问题,最终定位在if .. continue
的结构上,发现问题在于此处省略了else
。。
重现脚本如下,是一个简单的ray tracing程序,场景中有几个axis aligned box。
首先发射光线,再对每束光线和AABB相交测试。
正常情况下应该能看到4个黑方块,bug的话是黑屏。
省略else 有 print 或者 不省略else都是正常情况
省略else 无 print 是黑屏情况
import taichi as ti
ti.init(arch=ti.cuda, debug=True)
inf = 1e10
vec3 = ti.types.vector(3, ti.f32)
@ti.func
def ray_aabb_intersection_f(box_min, box_max, o, d):
intersect = 1
near_int = -inf
far_int = inf
for i in ti.static(range(3)):
if d[i] == 0:
if o[i] < box_min[i] or o[i] > box_max[i]:
intersect = 0
else:
i1 = (box_min[i] - o[i]) / d[i]
i2 = (box_max[i] - o[i]) / d[i]
new_far_int = ti.max(i1, i2)
new_near_int = ti.min(i1, i2)
far_int = ti.min(new_far_int, far_int)
near_int = ti.max(new_near_int, near_int)
if near_int > far_int:
intersect = 0
return intersect, near_int, far_int
@ti.data_oriented
class Test:
def __init__(self):
self.ray_dir = ti.Vector.field(3, ti.f32, shape=(1024, 720))
self.ray_depth = ti.field(ti.f32, shape=(1024, 720))
self.ray_depth.fill(inf)
self.grid = ti.Vector.field(3, ti.f32, shape = 4) # 4 cube
self.grid[0] = vec3(1.2, 1.2, 3.)
self.grid[1] = vec3(-1.2, -1.2, 4.)
self.grid[2] = vec3(-1.2, 1.2, 5.)
self.grid[3] = vec3(1.2, -1.2, 6.)
self.block_size = vec3(1.1, 1.1, 1.1)
@ti.kernel
def block_intersect_k(self, n_block:ti.i32, gridinfo:ti.template(), block_size:vec3):
o = vec3(0., 0., 0.)
for I in ti.grouped(self.ray_dir):
d = self.ray_dir[I]
for n in range(n_block):
box_min = gridinfo[n]
box_max = box_min + block_size
itx, n_itx, f_itx = ray_aabb_intersection_f(box_min, box_max, o, d)
if itx == 0 or f_itx < 0:
continue
# else: ##### CANNOT BE OMITTED!!! or you need print('')... #####
if n_itx < 0: # ray origin inside the block
print('') ######## MAGIC LINE!! comment it to see the bug. #########
self.ray_depth[I] = 0
else:
if n_itx < self.ray_depth[I]:
self.ray_depth[I] = 0
@ti.kernel
def cast_rays_k(self, fx:ti.f32, fy:ti.f32, cx:ti.f32, cy:ti.f32):
for i,j in self.ray_dir:
# i for w, j for h
x = (i - cx + 0.5) / fx
y = (j - cy + 0.5) / fy
dir = ti.Vector([x, y, 1.0])
self.ray_dir[i, j] = dir.normalized()
win = ti.ui.Window('test', (1024, 720))
canvas = win.get_canvas()
t = Test()
print(t.grid)
while win.running:
t.cast_rays_k(1024, 720, 512, 360)
t.block_intersect_k(4, t.grid, t.block_size)
canvas.set_image(t.ray_depth)
win.show()
复现环境
[Taichi] version 1.2.0, llvm 10.0.0, commit f189fd79, win, python 3.8.8