# The two input images (tsukuba-imL.png and tsukuba-imR.png) come from the Tsukuba stereo dataset.
import taichi as ti
import numpy as np
from PIL import Image
def _load_gray(path):
    """Load the image at *path* as a 2-D int32 grayscale array of shape (rows, cols).

    ``Image.getdata()`` yields a flattened pixel sequence (RGB triples for
    colour images), so its NumPy view does not have an (h, w) shape.
    Converting to mode "L" and handing the image to NumPy directly keeps
    the row/column layout that the (h, w) Taichi fields expect.
    """
    with Image.open(path) as img:
        return np.asarray(img.convert("L"), dtype=np.int32)


nd_left = _load_gray("tsukuba-imL.png")
nd_right = _load_gray("tsukuba-imR.png")
# Both views are copied into fields of the same (h, w) shape, so they must match.
if nd_left.shape != nd_right.shape:
    raise ValueError(
        f"left/right image sizes differ: {nd_left.shape} vs {nd_right.shape}"
    )
h, w = nd_left.shape
# Belief-propagation schedule and label space.
NUM_ITER = 40          # iteration count for the (currently commented-out) BP loop
NUM_LABEL = 16         # disparity labels 0 .. NUM_LABEL - 1

# Slot indices along the third axis of grid_pixel_messages.
DATA, UP, DOWN, LEFT, RIGHT = range(5)

# Cost-function parameters.
WIN_RADIUS = 2         # half-width of the square matching window
LAMBDA = 20            # weight applied to the smoothness term
SMOOTHNESS_TRUNC = 2   # cap on |label difference| in the smoothness term
ti.init(debug=True)

# Left/right input intensities, indexed [row, col].
ti_left = ti.field(dtype=ti.i32, shape=(h, w))
ti_right = ti.field(dtype=ti.i32, shape=(h, w))

# Message table indexed [row, col, slot, label]; slot is one of
# DATA / UP / DOWN / LEFT / RIGHT (5 slots).
grid_pixel_messages = ti.field(dtype=ti.i32, shape=(h, w, 5, NUM_LABEL))
# Presumably the per-pixel chosen label; not written anywhere in this file yet.
grid_pixel_best_assignment = ti.field(dtype=ti.u8, shape=(h, w))

# Per-label scratch fields.
# NOTE(review): none of these is referenced by the kernels in this file.
label_temp = ti.field(dtype=ti.i32, shape=(NUM_LABEL,))
min_val = ti.field(dtype=ti.i64, shape=(NUM_LABEL,))
msg_tmp = ti.field(dtype=ti.i64, shape=(NUM_LABEL,))
# Data cost of assigning disparity k to pixel (i, j): the mean absolute
# intensity difference between the (2*WIN_RADIUS+1)^2 window centred on
# (i, j) in the left image and the same window shifted k columns to the left
# in the right image. Integer division is used for the mean.
# No bounds clamping is done: callers must keep rows i +/- WIN_RADIUS and
# columns j - WIN_RADIUS - k .. j + WIN_RADIUS inside the image.
@ti.func
def data_cost_stereo(i, j, k):
    cost = 0
    for wh, ww in ti.ndrange(
        (i - WIN_RADIUS, i + WIN_RADIUS + 1),
        (j - WIN_RADIUS, j + WIN_RADIUS + 1),
    ):
        cost += ti.abs(ti_left[wh, ww] - ti_right[wh, ww - k])
    # The window has 2*WIN_RADIUS+1 pixels per axis, so its area is
    # (2*WIN_RADIUS+1)**2. Dividing by (WIN_RADIUS+1)**2 would overstate the mean.
    return cost // (2 * WIN_RADIUS + 1) ** 2
# Truncated linear (L1) penalty between two labels: LAMBDA times
# |i - j|, with the difference capped at SMOOTHNESS_TRUNC.
@ti.func
def smoothness_cost_l1(i, j):
    label_gap = ti.abs(i - j)
    capped_gap = ti.min(label_gap, SMOOTHNESS_TRUNC)
    return LAMBDA * capped_gap
# Zero every message slot (DATA through RIGHT) for every pixel and label, then
# fill the DATA slot with data_cost_stereo for each disparity label at pixels
# away from the image border. Border pixels keep a zero data cost.
@ti.kernel
def init_data_cost():
    # ndrange upper bounds are exclusive, so the slot range must end at
    # RIGHT + 1 for the RIGHT slot to be cleared too.
    for i, j, m, n in ti.ndrange((0, h), (0, w), (DATA, RIGHT + 1), (0, NUM_LABEL)):
        grid_pixel_messages[i, j, m, n] = 0
    # data_cost_stereo reads right-image column ww - k, with ww >= j - WIN_RADIUS
    # and k <= NUM_LABEL - 1. Starting at column NUM_LABEL - 1 + WIN_RADIUS
    # keeps that column >= 0. Starting at NUM_LABEL would reach column -1.
    for i, j, k in ti.ndrange(
        (NUM_LABEL, h - NUM_LABEL),
        (NUM_LABEL - 1 + WIN_RADIUS, w - NUM_LABEL),
        (0, NUM_LABEL),
    ):
        grid_pixel_messages[i, j, DATA, k] = data_cost_stereo(i, j, k)
# Send a min-sum message from pixel (x, y) (row x, column y) to its neighbour
# in `direction` (LEFT, RIGHT, UP or DOWN).
#
# For each target label l_to, the message is the minimum over l_from of:
#     smoothness_cost_l1(l_to, l_from)
#   + the DATA cost at (x, y) for l_from
#   + every incoming message at (x, y) for l_from, except the one stored in
#     the `direction` slot (the message that came from the target neighbour).
# It is written into the neighbour's slot that faces back toward (x, y).
# For example, sending LEFT writes (x, y - 1)'s RIGHT slot.
#
# The caller must ensure the neighbour exists; BP's sweeps never send off
# the image edge. Any other `direction` value computes but stores nothing.
# NOTE(review): messages are not normalised (e.g. by subtracting their
# minimum), so values grow with every sweep in i32 storage. Check for
# overflow if many iterations are run.
@ti.kernel
def send_msg(x: ti.i32, y: ti.i32, direction: ti.i32):
    # The outermost loop is parallel. Each l_to iteration reads only (x, y)
    # and writes a distinct entry at the neighbour, so there is no data race
    # and no shared scratch buffer is needed.
    for l_to in range(NUM_LABEL):
        best = 0x7FFFFFFF
        for l_from in range(NUM_LABEL):
            msg = smoothness_cost_l1(l_to, l_from)
            msg += grid_pixel_messages[x, y, DATA, l_from]
            # Skip the message that arrived from the neighbour being sent to.
            if direction != LEFT:
                msg += grid_pixel_messages[x, y, LEFT, l_from]
            if direction != RIGHT:
                msg += grid_pixel_messages[x, y, RIGHT, l_from]
            if direction != UP:
                msg += grid_pixel_messages[x, y, UP, l_from]
            if direction != DOWN:
                msg += grid_pixel_messages[x, y, DOWN, l_from]
            best = ti.min(best, msg)
        if direction == LEFT:
            grid_pixel_messages[x, y - 1, RIGHT, l_to] = best
        elif direction == RIGHT:
            grid_pixel_messages[x, y + 1, LEFT, l_to] = best
        elif direction == UP:
            grid_pixel_messages[x - 1, y, DOWN, l_to] = best
        elif direction == DOWN:
            grid_pixel_messages[x + 1, y, UP, l_to] = best
def BP(direction):
    """Run one sequential message-passing sweep over the image in ``direction``.

    Horizontal sweeps (LEFT/RIGHT) go row by row. Vertical sweeps (UP/DOWN)
    go column by column. The scan runs against the message direction's source
    side, so each pixel sends toward a neighbour that exists: LEFT scans
    columns w-1..1, RIGHT scans 0..w-2, UP scans rows h-1..1 and DOWN scans
    0..h-2. Any other ``direction`` value does nothing.

    Each pixel triggers a separate ``send_msg`` kernel launch, so one sweep
    issues roughly h*w launches.
    """
    if direction in (LEFT, RIGHT):
        cols = range(w - 1, 0, -1) if direction == LEFT else range(0, w - 1)
        for row in range(h):
            for col in cols:
                send_msg(row, col, direction)
    elif direction in (UP, DOWN):
        rows = range(h - 1, 0, -1) if direction == UP else range(0, h - 1)
        for col in range(w):
            for row in rows:
                send_msg(row, col, direction)
# Copy both views into their Taichi fields.
ti_left.from_numpy(nd_left)
ti_right.from_numpy(nd_right)

# Clear the message table and fill its DATA slot with matching costs.
init_data_cost()

# Only a single rightward sweep runs for now. The full schedule (NUM_ITER
# rounds of RIGHT, LEFT, UP and DOWN sweeps) is kept commented out below.
# for it in range(NUM_ITER):
#     BP(RIGHT)
#     BP(LEFT)
#     BP(UP)
#     BP(DOWN)
BP(RIGHT)