import torch
import taichi as ti
import time
# Initialize the Taichi runtime with the CUDA backend; this runs at import
# time, before any of the kernels below are defined or launched.
ti.init(arch=ti.cuda)
class Timer:
    """Context manager that prints the wall time spent inside its ``with`` body.

    Attributes set while the manager is in use:
        start:   ``time.time()`` timestamp taken on entry.
        elapsed: duration of the body in seconds, available after exit.

    Only host-side time is measured; nothing here waits for asynchronous
    device work, so callers must synchronize themselves if they need that.
    """

    def __enter__(self):
        """Record the entry time and return the timer itself.

        Returning ``self`` makes ``with Timer() as t:`` bind the timer
        instead of ``None``.
        """
        self.start = time.time()
        # A monotonic clock is used for the interval itself so that system
        # clock adjustments cannot produce a negative or skewed duration.
        self._t0 = time.perf_counter()
        self.elapsed = None
        return self

    def __exit__(self, *exc_info):
        """Store and print the elapsed time; exceptions are not suppressed."""
        self.elapsed = time.perf_counter() - self._t0
        print(f'Elapsed {self.elapsed} seconds')
@ti.kernel
def taichi_oper_forward(x: ti.types.ndarray(field_dim=2),
                        out: ti.types.ndarray(field_dim=2)):
    """Write ``3 * x**2`` elementwise into ``out``.

    Args:
        x:   2-D input array; its index space drives the loop.
        out: 2-D array that receives the result at the same indices as ``x``.
    """
    for row, col in x:
        value = x[row, col]
        out[row, col] = value * value * 3
@ti.kernel
def taichi_oper_backward(x: ti.types.ndarray(field_dim=2),
                         gout: ti.types.ndarray(field_dim=2),
                         gx: ti.types.ndarray(field_dim=2)):
    """Write the input gradient of ``3 * x**2`` elementwise into ``gx``.

    Args:
        x:    2-D array holding the forward input; its index space drives the loop.
        gout: 2-D array holding the gradient with respect to the forward output.
        gx:   2-D array that receives ``gout * 6 * x`` at each index.
    """
    for row, col in x:
        upstream = gout[row, col]
        gx[row, col] = upstream * 6 * x[row, col]
class my_oper(torch.autograd.Function):
    """Autograd function computing ``3 * x**2`` elementwise with Taichi kernels.

    The forward pass runs ``taichi_oper_forward`` and the backward pass runs
    ``taichi_oper_backward``; both write into freshly allocated tensors.
    """

    @staticmethod
    def forward(ctx, x):
        """Run the forward kernel on ``x`` and return the result tensor.

        Args:
            ctx: autograd context used to stash the input for ``backward``.
            x:   input tensor; the kernel signature requires it to be 2-D.

        Returns:
            A new contiguous tensor shaped like ``x``.
        """
        # The kernel indexes the tensor's memory directly, so hand it a
        # contiguous layout (gout already gets the same treatment in backward).
        # NOTE(review): Taichi's torch interop is documented as requiring
        # contiguous tensors -- confirm for the Taichi version in use.
        x = x.contiguous()
        # save_for_backward (rather than a plain ctx attribute) lets autograd
        # detect in-place modification of x and manage the tensor's lifetime.
        ctx.save_for_backward(x)
        out = torch.empty_like(x, memory_format=torch.contiguous_format)
        taichi_oper_forward(x, out)
        # Wait for the Taichi kernel before PyTorch consumes `out`.
        ti.sync()
        return out

    @staticmethod
    def backward(ctx, gout):
        """Run the backward kernel and return the gradient with respect to ``x``.

        Args:
            ctx:  autograd context populated by ``forward``.
            gout: gradient with respect to the forward output.

        Returns:
            A new contiguous tensor shaped like the saved input.
        """
        (x,) = ctx.saved_tensors
        gx = torch.empty_like(x, memory_format=torch.contiguous_format)
        taichi_oper_backward(x, gout.contiguous(), gx)
        # Wait for the Taichi kernel before PyTorch consumes `gx`.
        ti.sync()
        return gx
# Driver: push a tensor through the custom operator and time both passes.
# NOTE(review): the original paste lost its indentation; each timed region is
# assumed to wrap only the single call that follows its `with` line -- confirm.

# Leaf tensor of ones on the GPU that records gradients.
x = torch.ones(
    5000,
    5000,
    dtype=torch.float32,
    device="cuda",
    requires_grad=True,
)
print(x)

# Regular PyTorch op ahead of the custom one, so the operator input is a non-leaf.
y = x + 2
print(y)

# Forward pass through the Taichi-backed operator.
with Timer():
    z = my_oper.apply(y.cuda())

# Reduce to a scalar so backward() can be called without an explicit gradient.
out = z.mean()

# Backward pass; this is where my_oper.backward gets invoked.
with Timer():
    out.backward()

print(x.grad)
直接执行上述代码会出现如下错误:
out = z.mean()
RuntimeError: CUDA error: invalid argument
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
假如把 print(x)、print(y) 两行删除,代码可正常运行,输出如下:
[Taichi] version 1.1.2, llvm 10.0.0, commit f25cf4a2, linux, python 3.7.11
[Taichi] Starting on arch=cuda
Elapsed 0.00536799430847168 seconds
Elapsed 0.0071604251861572266 seconds
tensor([[7.2000e-07, 7.2000e-07, 7.2000e-07, ..., 7.2000e-07, 7.2000e-07,
7.2000e-07],
[7.2000e-07, 7.2000e-07, 7.2000e-07, ..., 7.2000e-07, 7.2000e-07,
7.2000e-07],
[7.2000e-07, 7.2000e-07, 7.2000e-07, ..., 7.2000e-07, 7.2000e-07,
7.2000e-07],
...,
[7.2000e-07, 7.2000e-07, 7.2000e-07, ..., 7.2000e-07, 7.2000e-07,
7.2000e-07],
[7.2000e-07, 7.2000e-07, 7.2000e-07, ..., 7.2000e-07, 7.2000e-07,
7.2000e-07],
[7.2000e-07, 7.2000e-07, 7.2000e-07, ..., 7.2000e-07, 7.2000e-07,
7.2000e-07]], device='cuda:0')