例如,我想做一个计算得分的函数,其结果是一系列子函数的求和,我想通过一个配置文件比较灵活地指定,运行哪些子函数来进行求和。我尝试了通过numba,和taichi来实现jit加速。
因为 numba 和 taichi 中都无法使用一个 value 为函数的字典,因此考虑根据配置文件,动态地生成这个函数的源代码,然后使用 exec 来执行,以避免在jit中出现查找对应函数的操作。
我写了一个简单的 case 来进行测试,通过一个简单的数组,例如 [1, 1, 2],来依次调用 func1 和 func2
目前,numba 的代码已经跑通了,如下:
import numba as nb
@nb.njit(nogil=True)
def func1(a):
print(a + 1)
@nb.njit(nogil=True)
def func2(a):
print(a + 2)
functions = {1: func1, 2: func2}
CODE_FORMAT = """
@nb.njit(nogil=True)
def _show(base: int):
{}
"""
def build_show(indexes):
func_call = "\n ".join([f"{functions[i].__name__}(base)" for i in indexes])
code = CODE_FORMAT.format(func_call)
print(code)
exec(code)
return eval("_show")
show = build_show([1, 1, 2])
show(0)
其命令行的输出结果如下:
@nb.njit(nogil=True)
def _show(base: int):
func1(base)
func1(base)
func2(base)
1
1
2
可以看到,已经根据配置文件([1, 1, 2])生成了正确的代码,并且成功运行了
然后,我尝试使用 taichi来写一个类似的代码:
import taichi as ti
ti.init(ti.cpu)
@ti.func
def func1(a: int):
print(a + 1)
@ti.func
def func2(a: int):
print(a + 2)
functions = {1: func1, 2: func2}
CODE_FORMAT = """
@ti.kernel
def _show(base: int):
{}
"""
def build_show(indexes):
func_call = "\n ".join([f"{functions[i].__name__}(base)" for i in indexes])
code = CODE_FORMAT.format(func_call)
print(code)
exec(code)
return eval("_show")
show = build_show([1, 1, 2])
show(0)
发现运行时报了 AttributeError: 'Import' object has no attribute 'args'
其运行结果如下:
[Taichi] version 1.1.2, llvm 10.0.0, commit f25cf4a2, win, python 3.8.8
[Taichi] Starting on arch=x64
@ti.kernel
def _show(base: int):
func1(base)
func1(base)
func2(base)
Traceback (most recent call last):
File "E:\code\taichi_simple_fluid_solver\ti_dict2.py", line 33, in <module>
show(0)
File "C:\ProgramData\Anaconda3\lib\site-packages\taichi\lang\kernel_impl.py", line 918, in wrapped
return primal(*args, **kwargs)
File "C:\ProgramData\Anaconda3\lib\site-packages\taichi\lang\kernel_impl.py", line 844, in __call__
key = self.ensure_compiled(*args)
File "C:\ProgramData\Anaconda3\lib\site-packages\taichi\lang\kernel_impl.py", line 819, in ensure_compiled
self.materialize(key=key, args=args, arg_features=arg_features)
File "C:\ProgramData\Anaconda3\lib\site-packages\taichi\lang\kernel_impl.py", line 507, in materialize
tree, ctx = _get_tree_and_ctx(
File "C:\ProgramData\Anaconda3\lib\site-packages\taichi\lang\kernel_impl.py", line 119, in _get_tree_and_ctx
for i, arg in enumerate(func_body.args.args):
AttributeError: 'Import' object has no attribute 'args'
进程已结束,退出代码1
从输出结果来看,生成的代码字符串是没有问题的,然后我进行了对照实验:
def build_show(indexes):
@ti.kernel
def _show(base: int):
func1(base)
func1(base)
func2(base)
return eval("_show")
将上文的 taichi 代码中,将 build_show 函数进行替换,直接执行命令行中输出的代码,发现是可以正常运行的。
输出结果为:
[Taichi] version 1.1.2, llvm 10.0.0, commit f25cf4a2, win, python 3.8.8
[Taichi] Starting on arch=x64
1
1
2
我现在没有找到,使用 exec 函数 定义 taichi kernel 无法运行,问题到底出在那里