使用 exec 函数生成 taichi kernel 无法执行

例如,我想做一个计算得分的函数,其结果是一系列子函数的求和,我想通过一个配置文件比较灵活地指定,运行哪些子函数来进行求和。我尝试了通过numba,和taichi来实现jit加速。

因为 numba 和 taichi 中都无法使用一个 value 为函数的字典,因此考虑根据配置文件,动态地生成这个函数的源代码,然后使用 exec 来执行,以避免在jit中出现查找对应函数的操作。

我写了一个简单的 case 来进行测试,通过一个简单的数组,例如 [1, 1, 2],来依次调用 func1 和 func2
目前,numba 的代码已经跑通了,如下:

import numba as nb


@nb.njit(nogil=True)
def func1(a):
    print(a + 1)


@nb.njit(nogil=True)
def func2(a):
    print(a + 2)


functions = {1: func1, 2: func2}

CODE_FORMAT = """
@nb.njit(nogil=True)
def _show(base: int):
    {}
"""


def build_show(indexes):
    func_call = "\n    ".join([f"{functions[i].__name__}(base)" for i in indexes])
    code = CODE_FORMAT.format(func_call)
    print(code)
    exec(code)
    return eval("_show")


show = build_show([1, 1, 2])
show(0)

其命令行的输出结果如下:


@nb.njit(nogil=True)
def _show(base: int):
    func1(base)
    func1(base)
    func2(base)

1
1
2

可以看到,已经根据配置文件([1, 1, 2])生成了正确的代码,并且成功运行了

然后,我尝试使用 taichi来写一个类似的代码:

import taichi as ti
ti.init(ti.cpu)


@ti.func
def func1(a: int):
    print(a + 1)


@ti.func
def func2(a: int):
    print(a + 2)


functions = {1: func1, 2: func2}

CODE_FORMAT = """
@ti.kernel
def _show(base: int):
    {}
"""


def build_show(indexes):
    func_call = "\n    ".join([f"{functions[i].__name__}(base)" for i in indexes])
    code = CODE_FORMAT.format(func_call)
    print(code)
    exec(code)
    return eval("_show")


show = build_show([1, 1, 2])
show(0)

发现运行时报了 AttributeError: 'Import' object has no attribute 'args'
其运行结果如下:

[Taichi] version 1.1.2, llvm 10.0.0, commit f25cf4a2, win, python 3.8.8
[Taichi] Starting on arch=x64

@ti.kernel
def _show(base: int):
    func1(base)
    func1(base)
    func2(base)

Traceback (most recent call last):
  File "E:\code\taichi_simple_fluid_solver\ti_dict2.py", line 33, in <module>
    show(0)
  File "C:\ProgramData\Anaconda3\lib\site-packages\taichi\lang\kernel_impl.py", line 918, in wrapped
    return primal(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\lib\site-packages\taichi\lang\kernel_impl.py", line 844, in __call__
    key = self.ensure_compiled(*args)
  File "C:\ProgramData\Anaconda3\lib\site-packages\taichi\lang\kernel_impl.py", line 819, in ensure_compiled
    self.materialize(key=key, args=args, arg_features=arg_features)
  File "C:\ProgramData\Anaconda3\lib\site-packages\taichi\lang\kernel_impl.py", line 507, in materialize
    tree, ctx = _get_tree_and_ctx(
  File "C:\ProgramData\Anaconda3\lib\site-packages\taichi\lang\kernel_impl.py", line 119, in _get_tree_and_ctx
    for i, arg in enumerate(func_body.args.args):
AttributeError: 'Import' object has no attribute 'args'

进程已结束,退出代码1

从输出结果来看,生成的代码字符串是没有问题的,然后我进行了对照实验:

def build_show(indexes):
    @ti.kernel
    def _show(base: int):
        func1(base)
        func1(base)
        func2(base)
    return eval("_show")

将上文的 taichi 代码中,将 build_show 函数进行替换,直接执行命令行中输出的代码,发现是可以正常运行的。
输出结果为:

[Taichi] version 1.1.2, llvm 10.0.0, commit f25cf4a2, win, python 3.8.8
[Taichi] Starting on arch=x64
1
1
2

我现在没有找到,使用 exec 函数 定义 taichi kernel 无法运行,问题到底出在那里

hi,确实没办法这样运行,但是在讨论实现之前,我想知道为什么要在kernel里这样找func呢?其实有很多办法可以做的更简单呀

比如把func1和func2都弄成Taichi kernel,然后host根据参数来选择具体的func
或者_show收一个参数,然后用if else就行,一般这种并不会带来过于显著的开销,实在担心这个开销也可以把参数声明成ti.template(),会jit出多个kernel来执行。

所以我想了解一下你的需求,我们一起看看是否有更好的写法

是对一系列候选项目进行筛选,类似于推荐系统中的召回这种操作。
一开始实现的时候,是使用的一系列 filter 进行嵌套,而且针对不同的策略,有不同的filter组合,所以不能写死。示例代码如下:

filters = {1: filter1, 2: filter2, 3: filter3}
data = [d1, d2, d3, d4]
filter_index = [1, 3]
for i in index:
    data = filters[i](data)  # 对于python而言,这里只是构造迭代器,没有计算
data = list(data). # 遍历嵌套的迭代器,实际执行所有的计算

这种只对单个filter做jit的话,因为有header开销,加速效果不大,不到2倍。如果对整个循环进行JIT,也不是很好JIT。

后来为了更好的筛选效果,改为了让每条数据,计算所有选中的filter,然后加权算得分,最后根据得分进行筛选。这样确实也更好做并行化了,给每个filter单独写一个 for 循环做 jit 就行了。

但是第一种情况如何进行 JIT 因为没想出来,平时没事就会想一下怎么写。这次有点思路,想试试通过生成函数代码字符串的方式,来动态的生成这个函数(虽然是可读性极差的“奇技淫巧”:innocent:),避免查找函数的问题(numba 中 jit加速的函数无法作为字典的value)。numba确实过了,一测试,taichi没过,而且报的错很奇怪,就发论坛了。

这种只对单个filter做jit的话,因为有header开销,加速效果不大,不到2倍。如果对整个循环进行JIT,也不是很好JIT。

这一块我没太看懂,这个header开销指的是什么呢?

如果是python dict开销很大,反正你的key是int,直接用list来写?

或者如果你对整个函数的开销非常非常敏感,可以考虑用AOT提前compile好,然后直接调用就行

可能是我表述不清楚。
之前提过一个issue,得到的回复是

For very small kernels, it only runs for very short time so the launch overhead dominates the performance

如果只将 单个 filter 写成 taichi kernel,for 循环中每一次调用都有一个 laugh overhead,就导致加速效果有限。
而且外层的for循环是可以并行的,所以能直接加速外层的for循环,既能避免 大量的 laugh overhead,又可以并行化处理,加速效果就会比较明显

例如,如果需要的filter是固定的,直接写死在函数中,可以这么写:

@ti.kernel
def foo(data: ti.template(), mask: ti.template()):
    for i in data:
        d = ti.static(data[i])
        mask[i] = filter1(d) and filter2(d) and filter3(d) 

而如果是只把 filter1,filter2 等写成taichi kernel,在python scope的循环中调用,就会慢很多了。

我遇到的问题主要就是,如果调用的filter是需要调用者的输入灵活变换的,就不是很清楚,应该怎么实现,保证仍能JIT。

我又去仔细看了一遍taichi的官方文档,或许编译时展开可以解决这个问题?

使用 ti.static 对编译时分支展开

理解了!但是有个问题,data[i]事实上并不是static的,所以这里无法在编译期展开。

那么为什么不用if else来对data[i]做判断呢?会有特别多的filter吗?

d = ti.static(data[i])

这里我是看那个 起别名的语法糖,想少打几个字来着,不是想做展开。
没有在实际代码里尝试过这种写法,直接写成 filter1(data[i]) 这种,就没歧义了吧