最近想要取出来矩阵A(大小为100x100的ndarray)中的某一列,尝试了下直接在kernel中用切片会报错。
import taichi as ti
import numpy as np
ti.init()
A = ti.ndarray(shape=(100,100), dtype=ti.f32)
@ti.kernel
def fill(A:ti.types.ndarray()):
for i in range(100):
for j in range(100):
A[i,j] = i*100 + j
@ti.kernel
def test(A:ti.types.ndarray()):
for i in range(100):
col = A[:,i]
fill(A)
test(A)
报错
TaichiSyntaxError:
File "C:\Users\GRBJ200045\AppData\Local\Temp\ipykernel_75048\2181333545.py", line 17, in test:
col = A[:,i]
^^^^^^
The type do not support index of slice type
所以想了个临时解决办法,先在外面定义个ndarr然后在使用的时候用A填充过去。
import taichi as ti
import numpy as np
ti.init()
N = 5
M = 6
A = ti.ndarray(shape=(N,M), dtype=ti.f32)
col_j = ti.ndarray(shape=(A.shape[0]), dtype=ti.f32)
@ti.kernel
def fill(A: ti.types.ndarray()):
for i in range(N):
for j in range(M):
A[i,j] = i*M + j
@ti.kernel
def test(A: ti.types.ndarray(), col_j: ti.types.ndarray()):
get_col(A, 4, col_j)
@ti.func
def get_col(A: ti.types.ndarray(), j: ti.i32, ret: ti.types.ndarray()):
for i in range(A.shape[0]):
ret[i] = A[i,j]
fill(A)
test(A, col_j)
print(A.to_numpy())
print(col_j.to_numpy())
输出
[Taichi] Starting on arch=x64
[[ 0. 1. 2. 3. 4. 5.]
[ 6. 7. 8. 9. 10. 11.]
[12. 13. 14. 15. 16. 17.]
[18. 19. 20. 21. 22. 23.]
[24. 25. 26. 27. 28. 29.]]
[ 4. 10. 16. 22. 28.]
为了以后有遇到相同问题的同学的方便,在此分享一下上面的临时之策。
当然希望taichi团队能推出支持slice的功能,毕竟还是经常会遇到的需求…