关于在kernel中使用pytorch模型的问题

我希望在taichi的kernel中使用一个预训练好的pytorch模型,在将模型的输入转化为tensor的时候出现了这个问题,似乎是不能将一个kernel里面的变量转化为tensor,因此也无法使用这个模型,请问有什么可以解决这个问题的办法吗?

@ti.kernel
    def Func(self):
        a = [1, 2, 3]
        b = torch.tensor(a)
Traceback (most recent call last):
  File "E:\SSF\main.py", line 38, in <module>
    main(running)
  File "E:\SSF\main.py", line 32, in main
    nsf.run()
  File "E:\SSF\MachingLearning\NormalModel\NeuralSSF.py", line 130, in run
    self.Func()
  File "C:\Users\VRG716\.conda\envs\SSF\lib\site-packages\taichi\lang\kernel_impl.py", line 1035, in __call__
    raise type(e)("\n" + str(e)) from None
taichi.lang.exception.TaichiCompilationError: 
File "E:\SSF\MachingLearning\NormalModel\NeuralSSF.py", line 75, in Func:
        b = torch.tensor(a)
            ^^^^^^^^^^^^^^^
Traceback (most recent call last):
  File "C:\Users\VRG716\.conda\envs\SSF\lib\site-packages\taichi\lang\ast\ast_transformer_utils.py", line 27, in __call__
    return method(ctx, node)
  File "C:\Users\VRG716\.conda\envs\SSF\lib\site-packages\taichi\lang\ast\ast_transformer.py", line 581, in build_Call
    node.ptr = func(*args, **keywords)
RuntimeError: Could not infer dtype of Expr

框架不兼容,数据结构不能直接转换。可以试试用numpy架桥,但是效率会很低。

目前应该是没有一个很好的解决方案,我的解决办法是将pytorch训练好的网络参数导出,在taichi里面使用matrix预读入,之后自己写一个网络forward的流程。
缺点是一层网络的参数往往大于32,编译非常耗时。我的网络中最大的一层是2000+参数,每次编译需要约几分钟的时间,当一层网络的参数继续增大时,大概5000个参数左右,几十分钟也编译不好。
如果需要运行小规模的网络,我的这种思路可行,更大规模的网络就必须考虑脱离taichi框架使用了,非常遗憾! :face_with_spiral_eyes: