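# Taichi's flat ti.ndrange range-for loop is wonderfully elegant! (It is a range-for, not a struct-for; struct-fors iterate over fields.)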
import taichi as ti
@ti.kernel
def taichi_dilate(out: ti.types.ndarray(), input: ti.types.ndarray(), element_struct: ti.types.ndarray()):
    """3D binary dilation of a batched, multi-channel mask.

    A single flat ``ti.ndrange`` loop visits every (output voxel,
    structuring-element offset) pair. Many pairs target the same output
    voxel, so contributions are merged with ``ti.atomic_or``.

    For structuring-element entry (i, j, k), output voxel (h, w, d) is OR-ed
    with ``input[h + i - Imid, w + j - Jmid, d + k - Kmid] * element_struct[i, j, k]``,
    where (Imid, Jmid, Kmid) is the element's center. This is dilation with
    the structuring element reflected through its center. For centrally
    symmetric elements (cubes, crosses, balls) it equals the textbook
    definition. Source voxels outside the volume are ignored (zero padding).

    Args:
        out: Output tensor (Batch, Channel, H, W, D). Should be zeros; results
            are OR-ed into it, never overwritten. Its shape also defines the
            iteration bounds for ``input``.
        input: Input mask (Batch, Channel, H, W, D), same shape as ``out``.
        element_struct: Structuring element (I, J, K). Each side length must be
            odd so the element has a well-defined center.

    NOTE(review): ``ti.atomic_or`` is a bitwise OR, so ``out``, ``input`` and
    ``element_struct`` presumably must be integer arrays holding 0/1 values;
    confirm with callers.
    """
    B, C, H, W, D = out.shape[0], out.shape[1], out.shape[2], out.shape[3], out.shape[4]
    I, J, K = element_struct.shape[0], element_struct.shape[1], element_struct.shape[2]
    Imid, Jmid, Kmid = (I - 1) // 2, (J - 1) // 2, (K - 1) // 2
    for b, c, h, w, d, i, j, k in ti.ndrange(B, C, H, W, D, I, J, K):
        # Source voxel read through this structuring-element offset.
        sh = h + i - Imid
        sw = w + j - Jmid
        sd = d + k - Kmid
        # Valid indices run 0..H-1 (resp. W-1, D-1); an upper bound of H - 1
        # would wrongly skip the last slice along every axis.
        if 0 <= sh < H and 0 <= sw < W and 0 <= sd < D:
            ti.atomic_or(out[b, c, h, w, d], input[b, c, sh, sw, sd] * element_struct[i, j, k])