大家好,最近在学习diffmpm_checkpointing这个例子,对其中用@ti.ad.grad_replaced和@ti.ad.grad_for(substep)来替换gradient function的操作有点困惑。
在Taichi文档中,这两个decorator是这样使用的:
@ti.kernel
def multiply(a: ti.float32):
for I in ti.grouped(x):
y[I] = x[I] * a
@ti.kernel
def multiply_grad(a: ti.float32):
for I in ti.grouped(x):
x.grad[I] = y.grad[I] / a
@ti.ad.grad_replaced
def foo(a):
multiply(a)
@ti.ad.grad_for(foo)
def foo_grad(a):
multiply_grad(a)
我的理解是,在multiply_grad中里对x.grad和y.grad之间的关系进行了重新定义。
在diffmpm_checkpointing这个例子中,是这样使用两个decorator的:
@ti.ad.grad_replaced
def substep(s):
clear_grid()
p2g(s)
grid_op()
g2p(s)
@ti.ad.grad_for(substep)
def substep_grad(s):
clear_grid()
p2g(s)
grid_op()
g2p.grad(s)
grid_op.grad()
p2g.grad(s)
在这个例子中,我没有看到像文档中那样明显的重新定义操作,所以不太明白是怎样cutomize gradient function的。谢谢大家的解答!