从是否能在串行的循环中执行并行的循环?继续讨论:
在许多并行算法中,算法需要“分层”完成,如FFT以及上述的SCAN算法;每一层完成时需要所有线程(或线程组/block/group)进行同步,防止下一层中使用了上一层还未修改的数据。
最简单的方法将各层独立为一个kernel call,另一种方法是在并行循环层中执行串行循环
@ti.func
def log2_int(n:ti.u32):
res = 0
if n&ti.u32(0xffff0000):
res += 16
n >>= 16
if n&ti.u32(0x0000ff00):
res += 8
n >>= 8
if n&ti.u32(0x000000f0):
res += 4
n >>= 4
if n&ti.u32(0x0000000c):
res += 2
n >>= 2
if n&ti.u32(0x00000002):
res += 1
n >>= 1
return res
@ti.kernel
def prefix_sum_kernel(x:ti.template(),y:ti.template()):
n = x.shape[0]
for i in x:
y[i] = x[i]
total_step = log2_int(n)
threads = n // 2
for i in range(threads): #parallel
#loop
for t in range(total_step):
src = ((i>>t)<<(t+1)) + (1<<t) - 1
dst = src + 1 + (i & ((1<<t) - 1))
y[dst] += y[src]
# sync all threads
ti.sync()
按照 ti.sync()
的描述,其起到的应该是一个全局同步的作用
Blocks the calling thread until all the previously launched Taichi kernels have completed.
但是经过实际的测试(CUDA on NVIDIA GeForce MX450)后发现,在1024长度的数据下只有前512个计算结果是正确的,也就是说超出前512个线程外的线程并没有同步,甚至去除了 ti.sync()
后仍然得到同样的结果, ti.sync()
似乎不起到任何作用,可能是我没正确使用?
因为全局同步本身是一件很麻烦的事情,cuda中也没有能提供直接的全局同步方法
需要通过Cooperative Groups进行同步
不知道应该如何在taichi中进行类似于Cooperative Groups的同步的方法
一个粗糙的方法是使用线程组中的一个线程进行原子计数来实现线程组间(全局)同步
一个简单的实现如下
@ti.kernel
def prefix_sum_kernel(x:ti.template(),y:ti.template()):
n = x.shape[0]
for i in x:
y[i] = x[i]
total_step = log2_int(n)
threads = n // 2
# block_dim = 32
block_count = threads // 32
mutex = 0
ti.loop_config(block_dim=32)
for i in range(threads): #parallel
thread_gid = ti.global_thread_idx()
block_id = thread_gid // 32
thread_id = thread_gid % 32 # ti.simt.block.thread_idx() unavailable for cuda
#loop
for t in range(total_step):
src = ((i>>t)<<(t+1)) + (1<<t) - 1
dst = src + 1 + (i & ((1<<t) - 1))
y[dst] += y[src]
# sync all threads
if thread_id == 0:
ti.atomic_add(mutex,1)
while mutex < block_count * (t+1):
pass
ti.simt.block.sync()
但是奇怪的是,这会导致死锁
打印检测发现
if thread_id == 0:
ti.atomic_add(mutex,1)
while mutex < block_count * (t+1):
pass
print(f'passed mutex = {mutex} block_count = {block_count * (t+1)} block_id = {block_id}')
passed mutex = 16 block_count = 16 block_id = 0
passed mutex = 16 block_count = 16 block_id = 14
只有两个block(32x2 一个warp)结束了循环,其余都卡在了循环里?
但是如果在while循环中打印信息,这个死锁神奇地消失了
while mutex < block_count * (t+1):
print(f'mutex = {mutex} block_count = {block_count * (t+1)} block_id = {block_id}')
pass
这里到底发生了什么?猜测是如果没有print(对mutex变量的访问)时,不同warp间的缓存不同步了?但是如果不是print而只是读取mutex变量依然会死锁。这是taichi的bug吗?以及要如何解决?