自动微分过程中,为什么某个变量+=一个0向量之后能grad出来,没有这个步骤却grad不出来

使用的是taichi 1.1.4版本
我在自动微分过程中发现了一个很奇怪的事情。

我将pos,vel,acc,force,damping,spring_y全部设置为了needs_grad=True了。

如果力的计算过程中没有那一行被标记的计算步骤,那么返回来的grad值除了loss.grad和pos.grad有数值之外,其他所有变量的grad值全部是0。

但如果有了那一步骤,哪怕实质上force += 了一个全是0的向量,那么其他变量的grad值都有数。

我不能理解为什么会有这种情况发生。所以来问一问大佬这是这么回事

@ti.kernel
def simulation(t):
  for i,j in ti.ndrange(num,num):
    #以下计算力
    p1=pos[t-1]
    v1=vel[t-1]
    ###force[t] += - vel[t-1]* 0###此步为疑惑点
    for n in range(spring_data.shape[2]):
      ... ...
      force[t] += -spring_y[None]*(dist/spring_rest_length -1)*d
      force[t] += -dashpot_damping[None] * dv.dot(d) * d

        #以下更新加速度,速度和位置
  if ((i==0 and j==0) or (i==35 and j==0)):#此步为筛选移动点
      acc[t]=force[t]/mass
      vel[t]=trajectory_data[t]
      pos[t]=pos[t-1]+vel[t]*dt
    else:
      acc[t]=force[t]/mass
      vel[t]=(vel[t-1]+acc[t]*dt)*ti.exp(-dt*drag_damping[None])
      pos[t]=pos[t-1]+vel[t]*dt
      collide_with_table(t,i,j)

@ti.kernel
def loss_n():
   loss=(pos[..,..,..] - target[..,..,..]).norm

hello~ @WTC , 能发一个小一些能跑的复现script么?可能是求导chain 被覆盖了,但是得看一下代码~

你好,我把简化后的代码写在下面了

import taichi as ti
ti.init(arch=ti.cuda,device_memory_fraction=0.5)

dt=1e-3
step=1600
v_number=36
v_res=v_number-1
cloth_length=4.0

vec=lambda : ti.Vector.field(3,dtype=ti.f32)
scale=lambda: ti.field(dtype=ti.f32)

pos=vec()
force=vec()
vel=vec()

stiffness=scale()
drag_damping=scale()
loss_n=scale()

ti.root.dense(ti.ijk,(step,v_number,v_number)).place(pos,force,vel)
ti.root.place(loss_n,stiffness,drag_damping)
ti.root.lazy_grad()

spring_date=ti.Vector.field(3,dtype=ti.f32,shape=(v_number,v_number,4))
traj=ti.Vector.field(3,dtype=ti.f32,shape=(2,step))

@ti.kernel
def cloth_init():
    stiffness[None]=1000.0
    drag_damping[None]=1.0
    for t,i,j in pos:
        pos[t,i,j]=ti.Vector([1.0+i*(cloth_length/v_res),1.0,1.0+j*(cloth_length/v_res)])
        vel[t,i,j]=ti.Vector([0.0,0.0,0.0])
        force[t,i,j]=ti.Vector([0.0,-9.8,0.0]) * 1.0
    spring_init()

@ti.func
def get_x(n:ti.i32) ->ti.i32:
    ax=0
    if (n==0):
        ax=1
    elif (n==2):
        ax=-1
    else:
        ax=0
    return ax

@ti.func
def get_y(n:ti.i32)->ti.i32:
    ax=0
    if (n==1) :
        ax=-1
    elif (n==3):
        ax= 1
    else:
        ax = 0
    return ax

@ti.func
def spring_init():
    for i,j,k in spring_date:
        spring_coord = ti.Vector([get_x(k),get_y(k)])
        coord_neigh = spring_coord + ti.Vector([i,j])
        if (coord_neigh.x>=0) and (coord_neigh.x<=v_res) and (coord_neigh.y>=0) and (coord_neigh.y<=v_res): 
            spring_date[i,j,k]=ti.Vector([cloth_length/v_res,coord_neigh.x,coord_neigh.y])
        else:
            spring_date[i,j,k]=ti.Vector([0.0,0.0,0.0])

@ti.func
def spring_force(t:ti.i32,i:ti.i32,j:ti.i32):
    p1=pos[t-1,i,j]
    ######################################
    force[t,i,j] += - vel[t-1,i,j] * 0.0
    ######################################
    for n in range(spring_date.shape[2]):
        spring_length=spring_date[i,j,n][0]
        if spring_length != 0.0 :
            x=int(spring_date[i,j,n][1])
            y=int(spring_date[i,j,n][2])
            p2=pos[t-1,x,y]
            dp=p1-p2
            force[t,i,j] += -stiffness[None]*(dp.norm()/spring_length -1)*dp.normalized()

@ti.kernel
def simulation(t:ti.i32):
    for i,j in ti.ndrange(v_number,v_number):
        spring_force(t,i,j)
        if ((i == 0 and j == 0 ) or (i == v_res and j == 0 )): 
            pass
        else:
            vel[t,i,j]=(vel[t-1,i,j]+(force[t,i,j]/1.0)*dt)*ti.exp(-dt * drag_damping[None])
            pos[t,i,j]=pos[t-1,i,j]+vel[t,i,j]*dt

@ti.kernel
def compute_loss(t:ti.i32):
    loss_n[None]=(pos[t,int(v_res/2),0]-pos[t,int(v_res/2),v_res]).norm()

cloth_init()
with ti.ad.Tape(loss=loss_n,clear_gradients=True):
    for j in range(1,step):
        simulation(j)
    compute_loss(step-1)      
print("loss=",loss_n[None])
print("vel",vel.grad[1595,int(v_res/2),3])
print("force",force.grad[1595,int(v_res/2),3])
print("pos",pos.grad[1595,int(v_res/2),3])
print("stiffness",stiffness.grad[None])
print("drag_damping",drag_damping.grad[None])

已经复现了这个问题,确实是个bug,可以先用添加atomic add的版本作为workaround。目前正在修了 [autodiff] Fix missing global load stmt in independent blocks by erizmr · Pull Request #6662 · taichi-dev/taichi · GitHub

1 Like

谢谢,这个方法甚至帮我解决了在误差向后传导过程中,误差不下降的问题。以前我的误差总是不下降,改成atomic add之后就可以正常下降。

2 Likes