Implementing 3D morphological dilation with Taichi

Taichi's parallel for loops (here a range-for over ti.ndrange) are really neat. So elegant!

import taichi as ti

ti.init(arch=ti.cpu)  # or ti.gpu

@ti.kernel
def taichi_dilate(out: ti.types.ndarray(), input: ti.types.ndarray(), element_struct: ti.types.ndarray()):
    """3D binary dilation.

    Args:
        out: Output tensor of shape (Batch, Channel, H, W, D); must be zero-initialized.
        input: Input binary mask of shape (Batch, Channel, H, W, D).
        element_struct: Structuring element of shape (I, J, K); every dimension must be odd.
    The arrays should use an integer dtype (e.g. int32), since ti.atomic_or is a bitwise operation.
    """
    B, C, H, W, D = out.shape[0], out.shape[1], out.shape[2], out.shape[3], out.shape[4]
    I, J, K = element_struct.shape[0], element_struct.shape[1], element_struct.shape[2]
    # Index of the structuring element's centre along each axis
    Imid, Jmid, Kmid = (I - 1) // 2, (J - 1) // 2, (K - 1) // 2
    # One flat parallel loop over every (output voxel, structuring-element offset) pair
    for b, c, h, w, d, i, j, k in ti.ndrange(B, C, H, W, D, I, J, K):
        # Valid indices run from 0 to H - 1, so the upper bound is "< H", not "< H - 1"
        if 0 <= h + i - Imid < H and 0 <= w + j - Jmid < W and 0 <= d + k - Kmid < D:
            ti.atomic_or(out[b, c, h, w, d],
                         input[b, c, h + i - Imid, w + j - Jmid, d + k - Kmid] * element_struct[i, j, k])
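One indexing detail worth checking: the kernel reads input at h + i - Imid, i.e. each output voxel gathers from its neighbours. The textbook definition A ⊕ B = {a + b} instead scatters every set input voxel by every offset of B, which is the same as gathering with the reflected element. For symmetric elements (cubes, balls) the two agree; for asymmetric ones the result is mirrored. The 1D sketch below (hypothetical helpers `offsets`, `gather_dilate`, `scatter_dilate`, plain Python, no Taichi) shows the difference:

```python
def offsets(se):
    """Offsets relative to the centre at which a 1D structuring element is set."""
    mid = (len(se) - 1) // 2
    return [i - mid for i, v in enumerate(se) if v]


def gather_dilate(x, se):
    # Same indexing as the kernel: out[p] = OR of x[p + o] over the set offsets o
    n = len(x)
    return [int(any(0 <= p + o < n and x[p + o] for o in offsets(se))) for p in range(n)]


def scatter_dilate(x, se):
    # Textbook definition A ⊕ B = {a + b}: shift every set input voxel by every offset
    n = len(x)
    out = [0] * n
    for p, v in enumerate(x):
        if v:
            for o in offsets(se):
                if 0 <= p + o < n:
                    out[p + o] = 1
    return out


x = [0, 0, 0, 1, 0, 0, 0]
print(gather_dilate(x, [1, 1, 0]))   # [0, 0, 0, 1, 1, 0, 0]
print(scatter_dilate(x, [1, 1, 0]))  # [0, 0, 1, 1, 0, 0, 0]
print(gather_dilate(x, [0, 1, 1]))   # reflected element: [0, 0, 1, 1, 0, 0, 0]
```

If an asymmetric element is ever needed, flipping element_struct along all three axes before calling the kernel should recover the textbook convention.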

Could you post what the result looks like when you run it? :slight_smile:

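It turns out the version above was wrong :smiling_face_with_tear:. Many parallel iterations were writing to the same output value, so I had to move the structuring-element indices into a separate inner loop. Thanks for the pointer, haha.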
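To make the difference between the two loop structures concrete, here is a plain-Python model (hypothetical helper `parallel_writers`, no Taichi involved). It assumes, as Taichi documents, that only the outermost for loop of a kernel is parallelized, and counts how many parallel iterations target each output element:

```python
from collections import Counter
from itertools import product


def parallel_writers(out_shape, se_shape, fused):
    """Count how many parallel iterations write each output element.

    fused=True  models one ndrange over (b, c, h, w, d, i, j, k): every
                (output, offset) pair is its own parallel iteration.
    fused=False models a parallel loop over the output indices only, with
                the (i, j, k) loop running serially inside each iteration.
    """
    writers = Counter()
    for out_idx in product(*(range(n) for n in out_shape)):
        if fused:
            # One separate parallel iteration per structuring-element offset
            for _ in product(*(range(n) for n in se_shape)):
                writers[out_idx] += 1
        else:
            # A single parallel iteration owns this output element
            writers[out_idx] += 1
    return writers


print(max(parallel_writers((1, 1, 4, 4, 4), (3, 3, 3), fused=True).values()))   # 27
print(max(parallel_writers((1, 1, 4, 4, 4), (3, 3, 3), fused=False).values()))  # 1
```

With a 3x3x3 element the flat loop has 27 concurrent writers per voxel, all going through the atomic; in the nested form each voxel has exactly one owner, so the atomic is no longer strictly required there.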

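import taichi as ti

ti.init(arch=ti.cpu)  # or ti.gpu

@ti.kernel
def taichi_dilate(out: ti.types.ndarray(), input: ti.types.ndarray(), element_struct: ti.types.ndarray()):
    """3D binary dilation.

    Args:
        out: Output tensor of shape (Batch, Channel, H, W, D); must be zero-initialized.
        input: Input binary mask of shape (Batch, Channel, H, W, D).
        element_struct: Structuring element of shape (I, J, K); every dimension must be odd.
    The arrays should use an integer dtype (e.g. int32), since ti.atomic_or is a bitwise operation.
    """
    B, C, H, W, D = out.shape[0], out.shape[1], out.shape[2], out.shape[3], out.shape[4]
    I, J, K = element_struct.shape[0], element_struct.shape[1], element_struct.shape[2]
    # Index of the structuring element's centre along each axis
    Imid, Jmid, Kmid = (I - 1) // 2, (J - 1) // 2, (K - 1) // 2

    # Only this outermost loop is parallelized: one iteration per output voxel
    for b, c, h, w, d in ti.ndrange(B, C, H, W, D):
        # Runs serially inside each parallel iteration
        for i, j, k in ti.ndrange(I, J, K):
            # Valid indices run from 0 to H - 1, so the upper bound is "< H", not "< H - 1"
            if 0 <= h + i - Imid < H and 0 <= w + j - Jmid < W and 0 <= d + k - Kmid < D:
                ti.atomic_or(out[b, c, h, w, d],
                             input[b, c, h + i - Imid, w + j - Jmid, d + k - Kmid] * element_struct[i, j, k])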
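For checking the kernel's output on small inputs, a slow plain-Python reference can help. The sketch below (hypothetical helper `dilate3d_reference`) handles a single (H, W, D) volume, i.e. one batch and one channel, stored as nested lists, and uses the same gather indexing as the kernel with the inclusive index range 0..H-1:

```python
def dilate3d_reference(volume, se):
    """Plain-Python binary dilation of one (H, W, D) volume of 0/1 ints.

    out[h][w][d] is 1 if any volume[h+i-Imid][w+j-Jmid][d+k-Kmid] is 1
    at a position where se[i][j][k] is 1 (same indexing as the kernel).
    """
    H, W, D = len(volume), len(volume[0]), len(volume[0][0])
    I, J, K = len(se), len(se[0]), len(se[0][0])
    Imid, Jmid, Kmid = (I - 1) // 2, (J - 1) // 2, (K - 1) // 2
    out = [[[0] * D for _ in range(W)] for _ in range(H)]
    for h in range(H):
        for w in range(W):
            for d in range(D):
                acc = 0
                for i in range(I):
                    for j in range(J):
                        for k in range(K):
                            hh, ww, dd = h + i - Imid, w + j - Jmid, d + k - Kmid
                            # Inclusive range 0..H-1, so the last slice is not skipped
                            if 0 <= hh < H and 0 <= ww < W and 0 <= dd < D:
                                acc |= volume[hh][ww][dd] & se[i][j][k]
                out[h][w][d] = acc
    return out


vol = [[[0] * 5 for _ in range(5)] for _ in range(5)]
vol[4][4][4] = 1  # a voxel in the last slice along every axis
cube = [[[1] * 3 for _ in range(3)] for _ in range(3)]
result = dilate3d_reference(vol, cube)
print(sum(v for plane in result for row in plane for v in row))  # 8
```

A voxel in the corner grows into a 2x2x2 block (8 voxels). With the original "< H - 1" bounds the kernel would never read index 4, so that voxel would be lost entirely, which makes this a handy regression case.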

Here is the result: