Data transfer with pytorch is the bottleneck

Hi, I want to use Taichi together with PyTorch, but I find that the data transfer between Taichi and PyTorch takes most of the time. Does anyone have suggestions?

Here is a simplified version of my script: the data transfer takes 4.79 s, while the Taichi computation itself takes only 2.20 s.

import taichi as ti
ti.init(arch=ti.cuda, kernel_profiler=True)
import torch
import time

# Problem sizes — presumably batch, channels, points, neighbors, judging by
# how b/c/n/k are used in the field shapes below; TODO confirm.
b, c, n, k = 16, 128, 2048, 32
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Two random feature maps of shape (b, c, n), moved to the GPU when available.
fmap1, fmap2 = torch.rand([b,c,n]), torch.rand([b,c,n])
fmap1, fmap2 = fmap1.to(device), fmap2.to(device)
# NOTE(review): `compute_indice` is not defined anywhere in this snippet.
# Presumably it returns an integer index tensor of shape (b, k, n) so that it
# matches `indice_ti` below — verify against the full script.
indice = compute_indice(fmap1, fmap2)


iter_num = 1000

# Dense Taichi fields mirroring the torch tensors; data is copied between
# the torch tensors and these fields on every iteration of the loop below.
fmap1_ti = ti.field(ti.f32)
fmap2_ti = ti.field(ti.f32)
indice_ti = ti.field(ti.i32)
output_ti = ti.field(ti.f32)

ti.root.dense(ti.ijk, (b,c,n)).place(fmap1_ti)
ti.root.dense(ti.ijk, (b,c,n)).place(fmap2_ti)
ti.root.dense(ti.ijk, (b,k,n)).place(indice_ti)  
ti.root.dense(ti.ijk, (b,k,n)).place(output_ti)  

@ti.kernel
def clean_output():
    """Reset every element of the global `output_ti` field to zero.

    The original post left the body as a placeholder comment
    ("set output_ti to zeros"), which is a SyntaxError; this implements
    the stated intent with a parallel struct-for over the field.
    """
    for I in ti.grouped(output_ti):
        output_ti[I] = 0.0

@ti.kernel
def taichi_compute():
    """Placeholder for the actual computation kernel.

    NOTE(review): the body was elided in the post ("# do something").
    As written, the function had no statements — a SyntaxError; this
    docstring doubles as a valid body so the snippet at least parses.
    """
    # do something

# Accumulated wall-clock seconds: t2 = Taichi compute, t2_copy = transfers.
t2 = 0
t2_copy = 0

for idx in range(iter_num):

    # --- transfer in: torch tensor -> taichi field ---
    # `from_torch` copies the tensor contents into the field each iteration;
    # this is the cost the poster measures as the bottleneck.
    # NOTE(review): Taichi kernels can take torch tensors directly as
    # `ti.types.ndarray()` arguments, which avoids these per-iteration
    # field round-trips entirely — worth trying instead of from_torch/to_torch.
    t2_start = time.time()
    fmap1_ti.from_torch(fmap1)
    fmap2_ti.from_torch(fmap2)
    indice_ti.from_torch(indice)
    ti.sync()  # wait for the async copies so the timing is accurate
    t2_copy += (time.time() - t2_start)

    # --- compute ---
    t2_start = time.time()
    clean_output()
    taichi_compute()
    ti.sync()  # kernels launch asynchronously; sync before stopping the clock
    t2 += (time.time() - t2_start)

    # --- transfer out: taichi field -> torch tensor ---
    t2_start = time.time()    
    output = output_ti.to_torch(device=device)
    t2_copy += (time.time() - t2_start)

print(t2)
print(t2_copy)