if-else语句 相关的不连续 在taichi的可微计算中,是可微的么,如何实现的?

我在尝试写一个可微光线跟踪的代码,需要处理物体之间的阴影遮挡
发射很多根光线
kernel函数中:
for each ray:
if ray is shadowed:
radiance=0
else:
radiance=f(ray,point)#继续迭代计算能量

我想知道在taichi中if-else语句相关的不连续问题在自动微分中是可微的么?上面这段伪代码在Taichi中是可微的么?

我搜了下资料,在一个文档里看到Flatten Branching这个概念,有没有懂的大佬帮我解释下这个方法 :pray: :pray: 先扁平化分支,再消除局部变量?这如何使if-else可微的呀?

我也遇到相同的问题,在使用自动微分的时候,如果使用if语句处理碰撞,会导致梯度输出nan,但是我注意到difftaichi的example中的diffmpm的kernel:grid_op()使用了if语句,并且这个kernel也是需要计算kernel.grad的。

@ti.kernel
def grid_op():
    for i, j in grid_m_in:
        inv_m = 1 / (grid_m_in[i, j] + 1e-10)
        v_out = inv_m * grid_v_in[i, j]
        v_out[1] -= dt * gravity
        if i < bound and v_out[0] < 0:
            v_out[0] = 0
            v_out[1] = 0
        if i > n_grid - bound and v_out[0] > 0:
            v_out[0] = 0
            v_out[1] = 0
        if j < bound and v_out[1] < 0:
            v_out[0] = 0
            v_out[1] = 0
            normal = ti.Vector([0.0, 1.0])
            lsq = (normal**2).sum()
            if lsq > 0.5:
                if ti.static(coeff < 0):
                    v_out[0] = 0
                    v_out[1] = 0
                else:
                    lin = v_out.dot(normal)
                    if lin < 0:
                        vit = v_out - lin * normal
                        lit = vit.norm() + 1e-10
                        if lit + coeff * lin <= 0:
                            v_out[0] = 0
                            v_out[1] = 0
                        else:
                            v_out = (1 + coeff * lin / lit) * vit
        if j > n_grid - bound and v_out[1] > 0:
            v_out[0] = 0
            v_out[1] = 0

        grid_v_out[i, j] = v_out

我不知道if语句如何能在autodiff中正确使用