if ... continue 后的else竟然不能省略?!编译器似乎有点小问题…

代码中我们常会使用continue跳过一些异常情况,如下的结构中else往往都会被省略。

if ...: continue
# else:

但我debug了一晚上发现,在taichi kernel中省略此处的else有时会导致奇怪的bug出现。。 :face_holding_back_tears: :face_holding_back_tears: :face_holding_back_tears:

一开始我是发现在kernel中加入完全没使用到的print()函数之后,bug就消失了。。甚至什么都不print都行。。排除自己的问题之后,最后怀疑是不是编译器出了什么问题,最终定位在if .. continue 的结构上,发现问题在于此处省略了else。。 :face_with_head_bandage: :face_with_head_bandage: :face_with_head_bandage:

重现脚本如下,是一个简单的ray tracing程序,场景中有几个axis aligned box。
首先发射光线,再对每束光线和AABB相交测试。
正常情况下应该能看到4个黑方块,bug的话是黑屏。

省略else 有 print 或者 不省略else都是正常情况
省略else 无 print 是黑屏情况

import taichi as ti

ti.init(arch=ti.cuda, debug=True)

inf = 1e10
vec3 = ti.types.vector(3, ti.f32)

@ti.func
def ray_aabb_intersection_f(box_min, box_max, o, d):
    intersect = 1

    near_int = -inf
    far_int = inf

    for i in ti.static(range(3)):
        if d[i] == 0:
            if o[i] < box_min[i] or o[i] > box_max[i]:
                intersect = 0
        else:
            i1 = (box_min[i] - o[i]) / d[i]
            i2 = (box_max[i] - o[i]) / d[i]

            new_far_int = ti.max(i1, i2)
            new_near_int = ti.min(i1, i2)

            far_int = ti.min(new_far_int, far_int)
            near_int = ti.max(new_near_int, near_int)

    if near_int > far_int:
        intersect = 0
    return intersect, near_int, far_int

@ti.data_oriented
class Test:
    def __init__(self):
        self.ray_dir = ti.Vector.field(3, ti.f32, shape=(1024, 720))
        self.ray_depth = ti.field(ti.f32, shape=(1024, 720))
        self.ray_depth.fill(inf)
        self.grid = ti.Vector.field(3, ti.f32, shape = 4) # 4 cube
        self.grid[0] = vec3(1.2, 1.2, 3.)
        self.grid[1] = vec3(-1.2, -1.2, 4.)
        self.grid[2] = vec3(-1.2, 1.2, 5.)
        self.grid[3] = vec3(1.2, -1.2, 6.)
        self.block_size = vec3(1.1, 1.1, 1.1)

    @ti.kernel
    def block_intersect_k(self, n_block:ti.i32, gridinfo:ti.template(), block_size:vec3):
        o = vec3(0., 0., 0.)
        for I in ti.grouped(self.ray_dir):
            d = self.ray_dir[I]
            for n in range(n_block):
                box_min = gridinfo[n]
                box_max = box_min + block_size
                itx, n_itx, f_itx = ray_aabb_intersection_f(box_min, box_max, o, d)
                if itx == 0 or f_itx < 0:
                    continue
                # else:                 ##### CANNOT BE OMITTED!!!  or you need print('')... #####
                if n_itx < 0:           # ray origin inside the block
                    print('')           ######## MAGIC LINE!! comment it to see the bug. #########
                    self.ray_depth[I] = 0
                else:
                    if n_itx < self.ray_depth[I]:
                        self.ray_depth[I] = 0
    @ti.kernel
    def cast_rays_k(self, fx:ti.f32, fy:ti.f32, cx:ti.f32, cy:ti.f32):
        for i,j in self.ray_dir:
            # i for w, j for h
            x = (i - cx + 0.5) / fx
            y = (j - cy + 0.5) / fy
            dir = ti.Vector([x, y, 1.0])
            self.ray_dir[i, j] = dir.normalized()

win = ti.ui.Window('test', (1024, 720))
canvas = win.get_canvas()
t = Test()
print(t.grid)
while win.running:
    t.cast_rays_k(1024, 720, 512, 360)
    t.block_intersect_k(4, t.grid, t.block_size)
    canvas.set_image(t.ray_depth)
    win.show()

复现环境

[Taichi] version 1.2.0, llvm 10.0.0, commit f189fd79, win, python 3.8.8

我使用taichi 1.2.1注释掉那行print好像也看到了那4个黑方块,你可以试一下升级到1.2.1看看这个bug有没有被修好?另外最后一行的win.show应该是小写的?

诶 升级到1.2.1确实就好了…
最后一行确实是typo,已修改。。

1 个赞