Invalid constant scalar data type: <class 'taichi.lang.any_array.AnyArray'>

import numpy as np

import taichi as ti
ti.init(arch = ti.cpu)

data = np.array([[1, 1, 1, 0, 0, 0],
                 [0, 1, 1, 1, 0, 0],
                 [0, 1, 1, 1, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 1, 1, 0]])

@ti.kernel
def calculate_d(row:ti.types.ndarray(),arr:ti.types.ndarray()):
    a = 0
    for i in ti.grouped(arr):
        if (row == i).all():
            a +=1
    return a

def calculate_p(nn):
    nn_copy = np.copy(nn).astype(np.float64) 

    for i in range(nn.shape[1]):
        n_del = np.delete(nn,i,axis = 1)
        for j in range(nn.shape[0]):
            numerator = calculate_d(n_del[j],n_del)
            denominator = calculate_d(nn[j],nn)
            p = denominator/numerator
            # print(i,'---',j,'---',denominator,'---',numerator,'---',p)
            nn_copy[j,i] = p

    return nn_copy

calculate_p(data)

上面是我的代码,下面是报错执行之后会报错,但是我不太清楚应该怎么调整,是判断两个列表相等错了么?在taichi里面应该怎么判断……

File “/tmp/ipykernel_2952/1084500900.py”, line 16, in calculate_d:
if (row == i).all():
Invalid constant scalar data type: <class ‘taichi.lang.any_array.AnyArray’>

Hi @maweijiao, 在Taichi scope里写的代码和普通Python代码是有很多不同的地方。在Taichi scope你写的代码是Taichi编译器进行编译处理的。

你的代码里有两个问题:

  1. (row == i).all() Taichi暂时是无法解析的。
  2. ti.kernel如果有返回值,需要在函数声明的时候指定类型,比如def calculate_d(row:ti.types.ndarray(),arr:ti.types.ndarray()) -> ti.i32:
import numpy as np

import taichi as ti
ti.init(arch = ti.cpu)

data = np.array([[1, 1, 1, 0, 0, 0],
                 [0, 1, 1, 1, 0, 0],
                 [0, 1, 1, 1, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 1, 1, 0]])

@ti.kernel
def calculate_d(row:ti.types.ndarray(),arr:ti.types.ndarray()) -> ti.i32:
    a = 0
    for i in row:
        if row[i] == i:
            a +=1
    return a

def calculate_p(nn):
    nn_copy = np.copy(nn).astype(np.float64)

    for i in range(nn.shape[1]):
        n_del = np.delete(nn,i,axis = 1)
        for j in range(nn.shape[0]):
            numerator = calculate_d(n_del[j],n_del)
            denominator = calculate_d(nn[j],nn)
            p = denominator/numerator
            # print(i,'---',j,'---',denominator,'---',numerator,'---',p)
            nn_copy[j,i] = p

    return nn_copy

calculate_p(data)