在调用 Taichi kernel 之前 print PyTorch tensor,会导致 CUDA error

import torch
import taichi as ti
import time
# Initialize Taichi with the CUDA backend.
ti.init(arch=ti.cuda)

class Timer:
    """Context manager that prints how long its ``with`` body took.

    The start time is stored on ``self.start`` and the elapsed time, in
    seconds, is printed when the block exits.
    """

    def __enter__(self):
        """Record the start time and return the timer itself.

        Returning ``self`` makes ``with Timer() as t:`` bind the timer
        instead of ``None``.
        """
        # perf_counter() is monotonic, so the measured interval cannot be
        # distorted by system clock adjustments the way time.time() can.
        self.start = time.perf_counter()
        return self

    def __exit__(self, *exc_info):
        """Print the elapsed seconds.

        Implicitly returns ``None``, so an exception raised inside the
        ``with`` body is not suppressed.
        """
        print(f'Elapsed {time.perf_counter() - self.start} seconds')

@ti.kernel
def taichi_oper_forward(
    x: ti.types.ndarray(field_dim=2),
    out: ti.types.ndarray(field_dim=2),
):
    """Store ``x * x * 3`` element-wise into ``out``.

    Args:
        x: 2-D input array; the loop visits every index of ``x``.
        out: 2-D output array, written at each index visited in ``x``
            (presumably expected to have the same shape as ``x``).
    """
    for row, col in x:
        value = x[row, col]
        out[row, col] = value * value * 3

@ti.kernel
def taichi_oper_backward(
    x: ti.types.ndarray(field_dim=2),
    gout: ti.types.ndarray(field_dim=2),
    gx: ti.types.ndarray(field_dim=2),
):
    """Store ``gout * 6 * x`` element-wise into ``gx``.

    This is the incoming gradient multiplied by the derivative of
    ``3 * x**2`` with respect to ``x``.

    Args:
        x: 2-D array holding the forward input; the loop visits every
            index of ``x``.
        gout: 2-D array holding the gradient w.r.t. the forward output.
        gx: 2-D array that receives the gradient w.r.t. ``x``.
    """
    for row, col in x:
        upstream = gout[row, col]
        gx[row, col] = upstream * 6 * x[row, col]
        
class my_oper(torch.autograd.Function):
    """Autograd function computing ``3 * x**2`` element-wise via Taichi kernels.

    ``forward`` runs ``taichi_oper_forward`` and ``backward`` runs
    ``taichi_oper_backward``; each calls ``ti.sync()`` before returning its
    result.
    """

    @staticmethod
    def forward(ctx, x):
        """Compute the output for ``x`` and save ``x`` for the backward pass.

        Args:
            ctx: autograd context used to carry ``x`` over to ``backward``.
            x: input tensor handed to the 2-D Taichi kernel.

        Returns:
            A new contiguous tensor shaped like ``x`` holding the result.
        """
        # backward() already hands the kernel a contiguous ``gout``; do the
        # same for ``x`` so both kernels see the memory layout they index.
        # NOTE(review): Taichi is understood to require contiguous torch
        # tensors for ndarray arguments - confirm against the Taichi docs.
        x = x.contiguous()
        # save_for_backward is the supported way to keep tensors for
        # backward(), rather than stashing them as plain ctx attributes.
        ctx.save_for_backward(x)
        out = torch.empty_like(x, memory_format=torch.contiguous_format)
        taichi_oper_forward(x, out)
        ti.sync()
        return out

    @staticmethod
    def backward(ctx, gout):
        """Compute the gradient with respect to the forward input.

        Args:
            ctx: autograd context holding the tensor saved by ``forward``.
            gout: gradient with respect to the forward output.

        Returns:
            A new contiguous tensor shaped like the saved input holding the
            gradient with respect to it.
        """
        (x,) = ctx.saved_tensors
        gx = torch.empty_like(x, memory_format=torch.contiguous_format)
        taichi_oper_backward(x, gout.contiguous(), gx)
        ti.sync()
        return gx


# Repro input: a 5000x5000 float32 tensor of ones on the GPU that tracks gradients.
x = torch.ones(5000,5000,device="cuda",requires_grad=True,dtype=torch.float32)
# NOTE: per the report, these print calls (made before the first Taichi kernel
# launch) are what trigger the CUDA error; removing both lets the script run.
print(x)
y = x+2
print(y)

# Time the forward pass of the custom op (its forward calls ti.sync() before returning).
with Timer():
    # NOTE(review): y is derived from a CUDA tensor, so .cuda() is presumably a no-op here.
    z = my_oper.apply(y.cuda())
out = z.mean()
# Time the backward pass through mean() and the custom op.
with Timer():
    out.backward()
    
print(x.grad)

直接执行会出现如下错误:

    out = z.mean()
RuntimeError: CUDA error: invalid argument
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

假如把 print(x)、print(y) 删除,代码可正常运行,输出如下:

[Taichi] version 1.1.2, llvm 10.0.0, commit f25cf4a2, linux, python 3.7.11
[Taichi] Starting on arch=cuda
Elapsed 0.00536799430847168 seconds
Elapsed 0.0071604251861572266 seconds
tensor([[7.2000e-07, 7.2000e-07, 7.2000e-07,  ..., 7.2000e-07, 7.2000e-07,
         7.2000e-07],
        [7.2000e-07, 7.2000e-07, 7.2000e-07,  ..., 7.2000e-07, 7.2000e-07,
         7.2000e-07],
        [7.2000e-07, 7.2000e-07, 7.2000e-07,  ..., 7.2000e-07, 7.2000e-07,
         7.2000e-07],
        ...,
        [7.2000e-07, 7.2000e-07, 7.2000e-07,  ..., 7.2000e-07, 7.2000e-07,
         7.2000e-07],
        [7.2000e-07, 7.2000e-07, 7.2000e-07,  ..., 7.2000e-07, 7.2000e-07,
         7.2000e-07],
        [7.2000e-07, 7.2000e-07, 7.2000e-07,  ..., 7.2000e-07, 7.2000e-07,
         7.2000e-07]], device='cuda:0')

Hi @xjun !欢迎来到 Taichi 社区。
可以先更新到最新版本的 Taichi 再试试吗?