遇到一个问题,不知道怎么操作才可以使用taichi加速,因为会用到很多次np.argsort和np.where,而且切片以后拼接又有着一定的先后顺序,感觉并行也有难度。还是说我一开始方向就错了。
数据是一些(x, y)坐标大约2000*1000,目的是把他们分成若干行排序的状态,左右相邻两个数据的y坐标偏差很小可以通过数值区分,但是离得远了就不一定可以区分了。
我把坐标点先根据x坐标大小argsort再切片循环,就相当于很多竖条。然后再给这些竖条进行y坐标大小的argsort再切片循环,就是一个个方格里处理,方格的局部数据根据x、y的大小接近程度就能区分行列。但是我这么实现超级慢…
import numpy as np
X, Y = 0, 1
class Line:
def __init__(self, init_data):
"""
这个init_data已经是排好序的,这边不处理排序了
:param init_data:
"""
self.line_data = init_data
self.left = self.line_data[0]
self.right = self.line_data[-1]
def add(self, new_line_data):
"""
只往右边加,因为现在从左往右读的数据
:param new_line_data:
:return:
"""
self.line_data = np.vstack([self.line_data, new_line_data])
self.right = new_line_data[-1]
class Panel:
LINE_Y_TOL = 1
LINE_X_TOL = 50
def __init__(self):
self.lines = []
self.temp_lines = []
def add_line_data(self, new_line_data):
"""添加新的line数据,如果可以拼接到已有line数据末尾则拼接,不能拼接说明是新的line,加入到缓冲区,在一次循环之后加入到正式的line中"""
flag = False
for line_instance in self.lines:
if abs(line_instance.right[Y] - new_line_data[0][Y]) < self.LINE_Y_TOL and abs(line_instance.right[X] - new_line_data[0][X]) < self.LINE_X_TOL:
line_instance.add(new_line_data)
flag = True
break
if not flag:
self.temp_lines.append(Line(new_line_data))
def line_flush(self):
"""把缓存区的line添加到正式的line中去"""
self.lines.extend(self.temp_lines)
self.temp_lines = []
def GetPanelData(data, depart_num=20):
p = Panel()
data = data[1:] # data[0]是背景的中心点需要去掉
data_sort_by_x = data[np.argsort(data[:, X])] # sort1:把全部数据按x坐标排序
step = data.shape[0] // depart_num
for i in range(depart_num):
one_step = data_sort_by_x[step * i:] if i == depart_num - 1 else data_sort_by_x[step * i:step * (i + 1)]
real_step_length = len(one_step)
one_step_sort_by_y = one_step[np.argsort(one_step[:, Y])] # sort2: 把一大列数据按照y坐标排序
delta_y = one_step_sort_by_y[1:real_step_length, Y] - one_step_sort_by_y[0:real_step_length - 1, Y]
line_threshold = max(np.max(delta_y) / 4, 1)
last_idx = 0
for idx in np.where(delta_y > line_threshold)[0]:
line = one_step_sort_by_y[last_idx:idx + 1]
line_sorted = line[np.argsort(line[:, X])] # sort3:这边得到一小行数据,把这一小行数据按照x坐标排序,得到有序的一行
p.add_line_data(line_sorted)
last_idx = idx + 1
line = one_step_sort_by_y[last_idx:]
line_sorted = line[np.argsort(line[:, X])]
p.add_line_data(line_sorted)
p.line_flush()
return p
def ProduceFakeData(height, width, gap_y, gap_x, bias):
y_1D = np.arange(0, height * gap_y, gap_y)
x_1D = np.arange(0, width * gap_x, gap_x)
x_2D, y_2D = np.meshgrid(x_1D, y_1D)
line_noise = np.cumsum(np.random.random((height, width)) * bias, axis=1)
return np.stack((x_2D, y_2D + line_noise), axis=2).reshape((-1, 2))
if __name__ == "__main__":
fake_data = ProduceFakeData(height=2000, width=1000, gap_y=2, gap_x=1, bias=0.005)
np.random.shuffle(fake_data)
panel = GetPanelData(data=fake_data)
print(len(panel.lines))