Help! Too many if branches, and performance is blowing up!

In my ray tracing code I define an SDF object:

@ti.dataclass
class SDFObject:
    type: int
    transform: Transform
    material: Material

Here type is the shape: sphere, cylinder, disc, square plane, and so on.
So I wrote this dispatch function:

@ti.func
def sdf_all(index: int, p: vec3, b: vec3) -> float:
    result = MAX_DIS
    if index == SHAPE.SPHERE:
        result = sd_sphere(p, b)
    elif index == SHAPE.BOX:
        result = sd_box(p, b)
    elif index == SHAPE.CYLINDER:
        result = sd_cylinder(p, b)
    elif index == SHAPE.CONE:
        result = sd_cone(p, b)
    elif index == SHAPE.PLANE_INFINITY:
        result = sd_plane_infinity(p, b)
    elif index == SHAPE.PLANE_CIRCLE:
        result = sd_plane_circle(p, b)
    elif index == SHAPE.PLANE_TRIANGLE:
        result = sd_plane_triangle(p, b)
    elif index == SHAPE.PLANE_TRIANGLE_R:
        result = sd_plane_right_triangle(p, b)
    elif index == SHAPE.PLANE_SQUARE:
        result = sd_plane_square(p, b)
    return result

With this many if branches, my runtime went up by a factor of ten to a hundred!

So I tried to optimize it. First with a list:

SHAPE_FUNC = [
    sd_none,
    sd_sphere,
    sd_box,
    sd_cylinder,
    sd_cone,
    sd_plane_infinity,
    sd_plane_circle,
    sd_plane_triangle,
    sd_plane_right_triangle,
    sd_plane_square
]


@ti.func
def sdf_all(index: int, p: vec3, b: vec3) -> float:
    result = MAX_DIS
    for i in ti.static(range(len(SHAPE_FUNC))):
        if index == i:
            result = ti.static(SHAPE_FUNC[i])(p, b)
    return result

No change.
Then I tried a binary-search layout:

@ti.func
def sdf_all(index: int, p: vec3, b: vec3) -> float:
    result = MAX_DIS

    # Binary-search style branching, to reduce the average number of comparisons
    if index < 5:
        if index < 3:
            if index == 0:
                result = sd_none(p, b)
            elif index == 1:
                result = sd_sphere(p, b)
            else:  # index == 2
                result = sd_box(p, b)
        else:  # index >= 3 and index < 5
            if index == 3:
                result = sd_cylinder(p, b)
            else:  # index == 4
                result = sd_cone(p, b)
    else:  # index >= 5
        if index < 8:
            if index == 5:
                result = sd_plane_infinity(p, b)
            elif index == 6:
                result = sd_plane_circle(p, b)
            else:  # index == 7
                result = sd_plane_triangle(p, b)
        else:  # index >= 8
            if index == 8:
                result = sd_plane_right_triangle(p, b)
            else:  # index == 9
                result = sd_plane_square(p, b)

    return result

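Performance improved by a tiny bit.

I'm going crazy. What should I do? I'm sure this block of ifs is the cause: when I use only two shapes and comment out the branches for the unused ones, the performance comes back.

You can use the list like this: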

@ti.func
def sdf_all(index: ti.template(), p, b) -> float:
    result = MAX_DIS
    if 0 <= index < ti.static(len(SHAPE_FUNC)):
        result = SHAPE_FUNC[index](p, b)
    return result

That doesn't work; it raises an error. SHAPE_FUNC lives in Python scope while index is a kernel-scope (runtime) value, so indexing SHAPE_FUNC[index] directly fails.
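
To make the scope point concrete, here is a plain-Python model of what a compile-time (static) loop over the list amounts to. It is only an analogy, not Taichi's implementation: the list is read once while the function source is being built, and the index, known only at run time, has to be compared against each literal. The 1-D toy distance functions, the MAX_DIS value and the helper name unroll_dispatch are assumptions made up for this sketch.

```python
# Plain-Python analogy of compile-time unrolling; NOT how Taichi is implemented.
MAX_DIS = 1e9  # assumed placeholder value


# Toy 1-D stand-ins for the real vec3 distance functions (hypothetical bodies).
def sd_none(p, b):
    return MAX_DIS


def sd_sphere(p, b):
    return abs(p) - b


def sd_box(p, b):
    return abs(p) - 0.5 * b


SHAPE_FUNC = [sd_none, sd_sphere, sd_box]


def unroll_dispatch(funcs):
    """Build the source of sdf_all with one `if` per list entry."""
    lines = ["def sdf_all(index, p, b):", "    result = MAX_DIS"]
    # This loop runs once, while the code is generated ("compile time").
    # The list is only consulted here, never when sdf_all is later called.
    for i, f in enumerate(funcs):
        lines.append(f"    if index == {i}:")
        lines.append(f"        result = {f.__name__}(p, b)")
    lines.append("    return result")
    return "\n".join(lines)


source = unroll_dispatch(SHAPE_FUNC)

# "Compile" the generated source; the runtime function sees only the if chain.
namespace = {"MAX_DIS": MAX_DIS}
namespace.update({f.__name__: f for f in SHAPE_FUNC})
exec(source, namespace)
sdf_all = namespace["sdf_all"]
```

Printing source shows one `if index == i:` line per list entry, which would be consistent with the list version above performing the same as the hand-written chain: the list changes how the code is written, not the branches that end up in the generated function.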