关于Autodiff中替换gradient function的疑问

大家好,最近在学习diffmpm_checkpointing这个例子,对其中用@ti.ad.grad_replaced和@ti.ad.grad_for(substep)来替换gradient function的操作有点困惑。

在Taichi文档中,这两个decorator是这样使用的:

@ti.kernel
def multiply(a: ti.float32):
    for I in ti.grouped(x):
        y[I] = x[I] * a

@ti.kernel
def multiply_grad(a: ti.float32):
    for I in ti.grouped(x):
        x.grad[I] = y.grad[I] / a

@ti.ad.grad_replaced
def foo(a):
    multiply(a)

@ti.ad.grad_for(foo)
def foo_grad(a):
    multiply_grad(a)

我的理解是,在multiply_grad中里对x.grad和y.grad之间的关系进行了重新定义。

在diffmpm_checkpointing这个例子中,是这样使用两个decorator的:

@ti.ad.grad_replaced
def substep(s):
    clear_grid()
    p2g(s)
    grid_op()
    g2p(s)

@ti.ad.grad_for(substep)
def substep_grad(s):
    clear_grid()
    p2g(s)
    grid_op()

    g2p.grad(s)
    grid_op.grad()
    p2g.grad(s)

在这个例子中,我没有看到像文档中那样明显的重新定义操作,所以不太明白是怎样cutomize gradient function的。谢谢大家的解答!

Hi @JinliBot , 被ti.ad.grad_for修饰的函数就是原函数的求导函数,文档里那个例子因为比较简单,所以完全手推导数,直接写出了x.grad和y.grad的函数式; 而下面diffmpm_checkpointing例子里的substep_grad中,没有完全去手推导数,而是借助了用autodiff生成的求导函数,并将其按照与原函数相反的顺序执行 (reverse mode/反向传播)来求导数

    g2p.grad(s)
    grid_op.grad()
    p2g.grad(s)

而这部分则是服务于checkpointing的re-computation,即为了省内存,把一些求导要用到的中间变量在原函数计算中不作记录,而是通过在求导函数中重新计算来获得

    clear_grid()
    p2g(s)
    grid_op()
1 个赞