新手求助 关于在python scope中交换field而taichi scope中未同步更新的问题

我尝试实现Kogge-Stone网络计算数组前缀和。其代码类似:

@ti.data_orianted
class KoggeStone:
    def __init__(self):
         # double buffer swap after every pass
         self.input = ti.i32.field()
         self.output = ti.i32. field()
    
    @ti.kernel
    def scan(self):
          # 每个pass打印相同值,预期打印值为上一个pass的输出
          print(self.input) 
          for i in range(self.input):
               self.output[i] = do_something(self.input, i)

    def add(self):
        # for multiple passes
        for i in range(n_pass):
              self.scan()
              # swap double buffer
              tmp = self.input
              self.input = self.output
              self.output = tmp

我发现每个pass中input和output 未成功交换
但当我将scan改写为用ti.template()实现时,能达到预期结果:

@ti.data_orianted
class KoggeStone:
    def __init__(self):
         # double buffer swap after every pass
         self.input = ti.i32.field()
         self.output = ti.i32. field()
    
    @ti.kernel
    def scan(self, input: ti.template(), output: ti.template()):
          # 达到预期结果
          print(input) 
          for i in range(input):
               output[i] = do_something(input, i)

    def add(self):
        # for multiple passes
        for i in range(n_pass):
              self.scan(self.input, self.output)
              # swap double buffer
              tmp = self.input
              self.input = self.output
              self.output = tmp

我的问题有两个

  1. 是什么导致了这一现象?
  2. template在每次调用时都会实例化一次,我后面这种实现方式会给性能带来多大额外开销?有更好的double buffer实现方式吗?

你好:wave: 导致这个现象的原因在于不加参数的时候 kernel 只会实例化一次,而有参数之后会实例化多次,展开来说就是:

  1. 如果参数是空的,scan() 只知道编译时 inputoutput 的地址,无法感知到后面 inputoutput 的地址已经发生了交换,所以在 add() 里一直在重复做同样的事
  2. 如果参数列表里有 ti.template(),kernel 就会实例化成多个 kernel,这样就达到预期效果了

使用 ti.template() 其实并不会将其指向的对象完全拷贝,而是只是进行指针操作,所以应该没多大额外开销。这边请 @jimyang 确认一下~

参考:

感谢解答 ,现在弄懂为什么了:pray:

补充一下,使用 ti.template() 由于会编译多个 kernel 实例,所以编译时间的确会增加,这个是个额外开销。