碰到有个问题使用了多次np.argsort不知道怎么移植到taichi

遇到一个问题,不知道怎么操作才可以使用taichi加速,因为会用到很多次np.argsort和np.where,而且切片以后拼接又有着一定的先后顺序,感觉并行也有难度。还是说我一开始方向就错了。

数据是一些(x, y)坐标大约2000*1000,目的是把他们分成若干行排序的状态,左右相邻两个数据的y坐标偏差很小可以通过数值区分,但是离得远了就不一定可以区分了。

我把坐标点先根据x坐标大小argsort再切片循环,就相当于很多竖条。然后再给这些竖条进行y坐标大小的argsort再切片循环,就是一个个方格里处理,方格的局部数据根据x、y的大小接近程度就能区分行列。但是我这么实现超级慢…

import numpy as np

X, Y = 0, 1


class Line:
    def __init__(self, init_data):
        """
        这个init_data已经是排好序的,这边不处理排序了
        :param init_data:
        """
        self.line_data = init_data
        self.left = self.line_data[0]
        self.right = self.line_data[-1]

    def add(self, new_line_data):
        """
        只往右边加,因为现在从左往右读的数据
        :param new_line_data:
        :return:
        """
        self.line_data = np.vstack([self.line_data, new_line_data])
        self.right = new_line_data[-1]


class Panel:
    LINE_Y_TOL = 1
    LINE_X_TOL = 50

    def __init__(self):
        self.lines = []
        self.temp_lines = []

    def add_line_data(self, new_line_data):
        """添加新的line数据,如果可以拼接到已有line数据末尾则拼接,不能拼接说明是新的line,加入到缓冲区,在一次循环之后加入到正式的line中"""
        flag = False
        for line_instance in self.lines:
            if abs(line_instance.right[Y] - new_line_data[0][Y]) < self.LINE_Y_TOL and abs(line_instance.right[X] - new_line_data[0][X]) < self.LINE_X_TOL:
                line_instance.add(new_line_data)
                flag = True
                break

        if not flag:
            self.temp_lines.append(Line(new_line_data))

    def line_flush(self):
        """把缓存区的line添加到正式的line中去"""
        self.lines.extend(self.temp_lines)
        self.temp_lines = []


def GetPanelData(data, depart_num=20):
    p = Panel()
    data = data[1:]  # data[0]是背景的中心点需要去掉
    data_sort_by_x = data[np.argsort(data[:, X])]  # sort1:把全部数据按x坐标排序

    step = data.shape[0] // depart_num
    for i in range(depart_num):
        one_step = data_sort_by_x[step * i:] if i == depart_num - 1 else data_sort_by_x[step * i:step * (i + 1)]
        real_step_length = len(one_step)
        one_step_sort_by_y = one_step[np.argsort(one_step[:, Y])]  # sort2: 把一大列数据按照y坐标排序
        delta_y = one_step_sort_by_y[1:real_step_length, Y] - one_step_sort_by_y[0:real_step_length - 1, Y]
        line_threshold = max(np.max(delta_y) / 4, 1)
        last_idx = 0
        for idx in np.where(delta_y > line_threshold)[0]:
            line = one_step_sort_by_y[last_idx:idx + 1]
            line_sorted = line[np.argsort(line[:, X])]  # sort3:这边得到一小行数据,把这一小行数据按照x坐标排序,得到有序的一行
            p.add_line_data(line_sorted)
            last_idx = idx + 1
        line = one_step_sort_by_y[last_idx:]
        line_sorted = line[np.argsort(line[:, X])]
        p.add_line_data(line_sorted)
        p.line_flush()
    return p


def ProduceFakeData(height, width, gap_y, gap_x, bias):
    y_1D = np.arange(0, height * gap_y, gap_y)
    x_1D = np.arange(0, width * gap_x, gap_x)
    x_2D, y_2D = np.meshgrid(x_1D, y_1D)
    line_noise = np.cumsum(np.random.random((height, width)) * bias, axis=1)
    return np.stack((x_2D, y_2D + line_noise), axis=2).reshape((-1, 2))


if __name__ == "__main__":
    fake_data = ProduceFakeData(height=2000, width=1000, gap_y=2, gap_x=1, bias=0.005)
    np.random.shuffle(fake_data)
    panel = GetPanelData(data=fake_data)
    print(len(panel.lines))

我建议把脚本修改成一个可以运行的,比如data用随机数生成好,里面的变量p也说清楚,这样应该好测试一点。

好,我改一下

修改了