Taichi的Autodiff对于需要做累加计算的函数,是否一定要对for循环做ti.static展开?


import taichi as ti
import numpy as np

ti.init(arch=ti.gpu)

N = 3
t = 0.2
g = ti.field(dtype=ti.f32, shape=(N), needs_grad=True)
v = ti.field(dtype=ti.f32, shape=(N), needs_grad=True)
x = ti.field(dtype=ti.f32, shape=(N), needs_grad=True)

loss = ti.field(dtype=ti.f32, shape=(), needs_grad=True)

@ti.kernel
def allocate():
g[0] = 0.3
g[1] = 0.3
g[2] = 0.3

@ti.kernel
def compute_loss():
loss[None] = 0.1x[0]+0.1x[1]+0.1*x[2]

@ti.kernel
def compute_v():
for i in range(N):
vi=0.0
# for j in ti.static(range(N)):
for j in range(N):
vi+=g[j]*t
v[i] = vi

@ti.kernel
def compute_x():
for i in range(N):
xi=0.0
for j in range(N):
xi += v[j]*t
x[i] = xi

def forward():
compute_v()
compute_x()

allocate()

with ti.ad.Tape(loss=loss):
forward()
compute_loss()

print(g.grad)

以上是测试代码,损失函数是关于自变量g的一个计算时间复杂度为O(n^2)的函数。如果使用自动微分,计算v(compute_v())时需要把内层for循环展开计算,否则计算的导数是错误的。

不知道是我的写法有问题,还是taichi对于累加运算一定要用ti.static展开for循环。我看到官方的diffmpm.py涉及到累加操作时也用到了ti.static,但是展开的次数很少所以没有影响计算效率。在我们的例子中,N的规模比较大,展开的计算量是无法承受的。