Hi,
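Never mind — I think I just shouldn't overwrite any variables (fields) inside the `ti.Tape` block: Taichi's autodiff needs each intermediate result to be written to its own field rather than reusing a field that was already written and read.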
import taichi as ti
@ti.data_oriented
class ProductTest:
    """Differentiate ``a*b + a*c`` with respect to ``a``, ``b`` and ``c`` using ``ti.Tape``.

    Every quantity is a 0-D (``shape=()``) ``ti.f32`` field, read and written
    with the ``[None]`` index.
    """

    def __init__(self):
        # Inputs of the expression.
        self.a = ti.var(ti.f32, shape=())
        self.b = ti.var(ti.f32, shape=())
        self.c = ti.var(ti.f32, shape=())
        # Intermediates: one dedicated field per partial product
        # (product = a*b, product1 = a*c), so no field recorded on the tape
        # is ever written twice.
        self.product = ti.var(ti.f32, shape=())
        self.product1 = ti.var(ti.f32, shape=())
        # Final result; used as the loss of the tape in `test`.
        self.sum = ti.var(ti.f32, shape=())
        # Place the .grad (adjoint) fields for the fields declared above;
        # `test` reads a.grad, b.grad and c.grad.
        ti.root.lazy_grad()
        self.a[None] = 1
        self.b[None] = 2
        self.c[None] = 3
        self.product[None] = 0
        self.product1[None] = 0
        self.sum[None] = 0

    # Kernel: result = a * b, where all three are 0-D fields.
    @ti.kernel
    def multiply(self, a: ti.template(), b: ti.template(), result: ti.template()):
        result[None] = a[None] * b[None]

    # Kernel: result = a + b, where all three are 0-D fields.
    @ti.kernel
    def add(self, a: ti.template(), b: ti.template(), result: ti.template()):
        result[None] = a[None] + b[None]

    def test(self):
        """Run the forward and backward pass of ``a*b + a*c`` and print the results.

        Prints the value of ``sum`` and the gradients of ``a``, ``b`` and ``c``.
        """
        with ti.Tape(self.sum):
            # Each kernel writes to a field that no other kernel on the tape
            # writes. The original code reused `self.sum` both as the a*b
            # intermediate and as the final output (`add(sum, product, sum)`).
            # That read-then-rewrite of a recorded field breaks Taichi's
            # autodiff data access rules and yields wrong gradients.
            self.multiply(self.a, self.b, self.product)      # product  = a*b
            self.multiply(self.a, self.c, self.product1)     # product1 = a*c
            self.add(self.product, self.product1, self.sum)  # sum = a*b + a*c
        # With the initial values from __init__ (a=1, b=2, c=3) this should print
        # sum = 5, a.grad = b + c = 5, b.grad = a = 1, c.grad = a = 1.
        print(
            'product', self.sum[None],
            'a.grad', self.a.grad[None],
            'b.grad', self.b.grad[None],
            'c.grad', self.c.grad[None],
        )
if __name__ == '__main__':
    # Script entry point: initialise Taichi on the GPU backend, then build the
    # fields and run the forward/backward check once.
    ti.init(arch=ti.gpu)
    demo = ProductTest()
    demo.test()