[分享] kernel中取出numpy切片(矩阵某一列)的临时解决方案

最近想要取出来矩阵A(大小为100x100的ndarray)中的某一列,尝试了下直接在kernel中用切片会报错。

import taichi as ti
import numpy as np
ti.init()
A = ti.ndarray(shape=(100,100), dtype=ti.f32)
@ti.kernel
def fill(A:ti.types.ndarray()):
    for i in range(100):
        for j in range(100):
            A[i,j] = i*100 + j
@ti.kernel
def test(A:ti.types.ndarray()):
    for i in range(100):
        col = A[:,i]
fill(A)
test(A)

报错

TaichiSyntaxError:
File "C:\Users\GRBJ200045\AppData\Local\Temp\ipykernel_75048\2181333545.py", line 17, in test:
        col = A[:,i]
              ^^^^^^
The type  do not support index of slice type

所以想了个临时解决办法,先在外面定义个ndarr然后在使用的时候用A填充过去。

import taichi as ti
import numpy as np

ti.init()
N = 5
M = 6
A = ti.ndarray(shape=(N,M), dtype=ti.f32)
col_j = ti.ndarray(shape=(A.shape[0]), dtype=ti.f32)

@ti.kernel
def fill(A: ti.types.ndarray()):
    for i in range(N):
        for j in range(M):
            A[i,j] = i*M + j

@ti.kernel
def test(A: ti.types.ndarray(), col_j: ti.types.ndarray()):
    get_col(A, 4, col_j)

@ti.func
def get_col(A: ti.types.ndarray(), j: ti.i32, ret: ti.types.ndarray()):
    for i in range(A.shape[0]):
        ret[i] = A[i,j]

fill(A)
test(A, col_j)
print(A.to_numpy())
print(col_j.to_numpy())

输出

[Taichi] Starting on arch=x64
[[ 0.  1.  2.  3.  4.  5.]
 [ 6.  7.  8.  9. 10. 11.]
 [12. 13. 14. 15. 16. 17.]
 [18. 19. 20. 21. 22. 23.]
 [24. 25. 26. 27. 28. 29.]]
[ 4. 10. 16. 22. 28.]

为了以后有遇到相同问题的同学的方便,在此分享一下上面的临时之策。

当然希望taichi团队能推出支持slice的功能,毕竟还是经常会遇到的需求…

1 个赞