# 利用taichi实现linear bvh

``````
from utils import *
from scene import Scene
from sort import *

# One node of the flat BVH array; the field order below is the struct layout.
BvhNode = ti.types.struct(
    bound=Bound,       # bound of the node; leaves copy the primitive's bound, internal nodes merge their children
    left=ti.i32,       # index of the left child in the node array
    right=ti.i32,      # index of the right child in the node array
    primitive=ti.i32,  # primitive index for leaves, -1 for internal nodes
    parent=ti.i32,     # index of the parent node, -1 for the root
)

@ti.func
def is_leaf(node: BvhNode):
    """Return whether `node` is a leaf, i.e. it references a primitive.

    Internal nodes carry ``primitive == -1``; any other value marks a leaf.
    """
    marker = node.primitive
    return marker != -1

@ti.func
def left_3_shift(x: ti.u32):
    """Spread the low 10 bits of `x` so that two zero bits separate each of them.

    Port of pbrt's ``LeftShift3``: input bit k ends up at output bit 3k, so three
    such values can be interleaved into a 30-bit Morton code. An input of exactly
    1024 (``1 << 10``) is first mapped to 1023 so it still fits in 10 bits.
    """
    v = x
    if v == 1024:
        v -= 1
    v = (v | (v << 16)) & 0x030000FF  # ---- --98 ---- ---- ---- ---- 7654 3210
    v = (v | (v << 8)) & 0x0300F00F   # ---- --98 ---- ---- 7654 ---- ---- 3210
    v = (v | (v << 4)) & 0x030C30C3   # ---- --98 ---- 76-- --54 ---- 32-- --10
    v = (v | (v << 2)) & 0x09249249   # ---- 9--8 --7- -6-- 5--4 --3- -2-- 1--0
    return v

@ti.func
def generate_morton_code(c: float3, ext: Bound):
    """Return the 30-bit Morton code of point `c` relative to the bound `ext`.

    Each axis of `c` is normalised against ``ext`` onto a 1024-cell grid, truncated
    to a 10-bit integer, and the three results are bit-interleaved with
    ``left_3_shift`` (x in bit 0, y in bit 1, z in bit 2 of every 3-bit group).
    """
    # A zero-width axis (e.g. all geometry lying in one plane) would otherwise
    # give 0 / 0 = NaN, whose conversion to an integer is undefined.
    total_ext = ti.max(ext.max - ext.min, 1e-8)
    # Clamp so points slightly outside `ext` (floating-point error) still land in a
    # valid cell instead of producing a negative index that wraps when cast to u32.
    x = ti.math.clamp((c - ext.min) * 1024.0 / total_ext, 0.0, 1023.0)
    ix = ti.math.floor(x[0], dtype=ti.i32)
    iy = ti.math.floor(x[1], dtype=ti.i32)
    iz = ti.math.floor(x[2], dtype=ti.i32)
    return left_3_shift(ix) | (left_3_shift(iy) << 1) | (left_3_shift(iz) << 2)

@ti.data_oriented
class Bvh:
    """Linear BVH built with Karras' parallel radix-tree construction.

    Reference: T. Karras, "Maximizing Parallelism in the Construction of BVHs,
    Octrees, and k-d Trees", HPG 2012.

    Layout of ``self.bvh``: indices ``[0, primitive_count - 1)`` hold the internal
    nodes (node 0 is the root) and the next ``primitive_count`` entries hold the
    leaves; leaf ``internal_node_count + k`` references primitive ``k`` of
    ``self.primitives``.

    Args:
        scene_ext: bound enclosing the scene; primitive centroids are normalised
            against it when computing Morton codes.
        primitive_count: number of primitives; must be at least 1.
        primitives: field of primitives whose elements expose ``.bound``. After
            construction ``self.primitives`` is replaced by the copy produced in
            ``sort``.
        sorter: sorting backend; only ``"bi_gpu"`` is handled.

    Raises:
        ValueError: if ``primitive_count < 1`` or ``sorter`` is not supported.
    """

    def __init__(self, scene_ext, primitive_count, primitives, sorter="bi_gpu"):
        if primitive_count < 1:
            raise ValueError(f"Bvh needs at least one primitive, got {primitive_count}")
        self.primitives = primitives
        self.internal_node_count = primitive_count - 1
        self.primitive_count = primitive_count

        self.bvh = BvhNode.field(shape=[self.primitive_count + self.internal_node_count])
        self.codes = ti.field(dtype=ti.i32, shape=[self.primitive_count])
        self.scene_ext = scene_ext
        self.sorter = sorter
        # One visit counter per internal node for assign_bound; keep at least one
        # element so a single-primitive tree does not request a zero-sized field.
        self.atomic_counter = ti.field(dtype=ti.i32, shape=[max(self.internal_node_count, 1)])

        self.cook_bvh_gpu()
        # The counters are only needed during construction.
        self.atomic_counter = None

    @ti.kernel
    def generate_morton_code(self, primitives: ti.template()):
        """Store the 30-bit Morton code of each primitive's bound centroid in ``self.codes``."""
        for i in range(self.primitive_count):
            bound = primitives[i].bound
            centroid = (bound.max + bound.min) * 0.5

            # Guard zero-width axes (e.g. planar scenes), which would give 0 / 0 = NaN.
            total_ext = ti.max(self.scene_ext.max - self.scene_ext.min, 1e-8)
            # Map onto [0, 1023]; clamping keeps centroids slightly outside the scene
            # extent (floating-point error) from producing negative cell indices.
            x = ti.math.clamp((centroid - self.scene_ext.min) * 1024.0 / total_ext, 0.0, 1023.0)

            ix = ti.math.floor(x[0], dtype=ti.i32)
            iy = ti.math.floor(x[1], dtype=ti.i32)
            iz = ti.math.floor(x[2], dtype=ti.i32)
            self.codes[i] = left_3_shift(ix) | (left_3_shift(iy) << 1) | (left_3_shift(iz) << 2)

    @ti.func
    def _clz32(self, v: ti.i32):
        """Return the number of leading zero bits of a non-negative 32-bit value (32 for 0).

        Exact integer bit scan; replaces a float ``log`` estimate that was wrong for
        values just below a power of two.
        """
        x = v
        msb = 0  # index of the highest set bit of v (for v > 0)
        if x >= (1 << 16):
            x >>= 16
            msb += 16
        if x >= (1 << 8):
            x >>= 8
            msb += 8
        if x >= (1 << 4):
            x >>= 4
            msb += 4
        if x >= (1 << 2):
            x >>= 2
            msb += 2
        if x >= (1 << 1):
            msb += 1
        result = 31 - msb
        if v == 0:
            result = 32
        return result

    @ti.func
    def delta(self, n1: ti.i32, n2: ti.i32):
        """Karras' delta(i, j): length of the common prefix of sorted codes n1 and n2.

        Returns -1 when ``n2`` is outside ``[0, primitive_count)``. Equal codes are
        disambiguated by their indices, giving ``32 + clz(n1 ^ n2)`` so that duplicates
        always share a longer prefix than any pair of distinct codes.
        """
        rv = -1
        if 0 <= n2 < self.primitive_count:
            c1 = self.codes[n1]
            c2 = self.codes[n2]
            if c1 == c2:
                rv = 32 + self._clz32(n1 ^ n2)
            else:
                rv = self._clz32(c1 ^ c2)
        return rv

    @ti.kernel
    def build_bvh(self):
        """Build the radix tree: one parallel iteration per internal node.

        From https://research.nvidia.com/sites/default/files/publications/karras2012hpg_paper.pdf
        (Figure 4). Writes ``left``/``right``/``primitive`` of internal nodes and the
        ``parent`` of every child.
        """
        for i in range(self.internal_node_count):
            # Direction of the key range covered by node i; with distinct (tie-broken)
            # keys the two neighbour prefixes never compare equal.
            d = 1
            if self.delta(i, i + 1) < self.delta(i, i - 1):
                d = -1

            # Upper bound for the range length, then binary search for the other end j.
            delta_min = self.delta(i, i - d)
            l_max = 2
            while self.delta(i, i + l_max * d) > delta_min:
                l_max *= 2
            l = 0
            t = l_max // 2
            while t >= 1:
                if self.delta(i, i + (l + t) * d) > delta_min:
                    l += t
                t //= 2
            j = i + l * d

            # Binary search for the split position, with steps ceil(l/2), ceil(l/4), ..., 1.
            delta_node = self.delta(i, j)
            s = 0
            t = l
            div = 2
            while t > 1:
                t = (l + div - 1) // div  # integer ceil(l / div)
                if self.delta(i, i + (s + t) * d) > delta_node:
                    s += t
                div *= 2
            gama = i + s * d + ti.min(d, 0)

            # A child that covers a single key is a leaf (stored after the internal nodes).
            left = gama
            if ti.min(i, j) == gama:
                left += self.internal_node_count
            right = gama + 1
            if ti.max(i, j) == gama + 1:
                right += self.internal_node_count

            self.bvh[i].left = left
            self.bvh[i].right = right
            self.bvh[i].primitive = -1

            self.bvh[left].parent = i
            self.bvh[right].parent = i

    @ti.kernel
    def assign_bound(self):
        """Compute internal-node bounds bottom-up.

        As the original paper says:
        'Each thread starts from one leaf node and walks up the tree using parent
        pointers that we record during radix tree construction. We track how many
        threads have visited each internal node using atomic counters—the first
        thread terminates immediately while the second one gets to process the
        node. This way, each node is processed by exactly one thread, which leads
        to O(n) time complexity.'

        Requires ``self.atomic_counter`` to be zeroed beforehand.
        NOTE(review): the paper also relies on a memory fence between writing a
        node's bound and the atomic increment; none is issued here — confirm the
        write is visible to the sibling thread on GPU backends.
        """
        for i in range(self.primitive_count):
            idx = self.bvh[i + self.internal_node_count].parent
            while idx >= 0:
                # ti.atomic_add returns the counter value before the increment.
                visits = ti.atomic_add(self.atomic_counter[idx], 1)
                if visits == 0:
                    # First arrival: the sibling subtree may still be in progress,
                    # so the second thread to arrive finishes this node.
                    idx = -1
                else:
                    left = self.bvh[idx].left
                    right = self.bvh[idx].right
                    self.bvh[idx].bound = merge_bound(self.bvh[left].bound, self.bvh[right].bound)
                    idx = self.bvh[idx].parent

    @ti.kernel
    def set_primitive_idx(self, primitives: ti.template()):
        """Point leaf k at primitive k and copy that primitive's bound into the leaf."""
        for i in range(self.primitive_count):
            leaf = i + self.internal_node_count
            self.bvh[leaf].primitive = i
            self.bvh[leaf].bound = primitives[i].bound

    def sort(self):
        """Reorder the primitives via the configured sorter and keep the result.

        NOTE(review): assumes BiSorterGpu writes the primitives, ordered by
        ``self.codes``, into ``output_primitive`` and also sorts ``self.codes``
        (``delta`` requires sorted codes) — confirm against sort.py.

        Raises:
            ValueError: if ``self.sorter`` is not a supported backend.
        """
        if self.sorter != "bi_gpu":
            raise ValueError(f"unsupported sorter: {self.sorter!r}")
        output_primitive = Primitive.field(shape=[self.primitive_count])
        BiSorterGpu(self.primitives, output_primitive, self.primitive_count, self.codes)
        self.primitives = output_primitive

    def cook_bvh_gpu(self):
        """Run the whole construction pipeline; ``self.codes`` is released afterwards."""
        self.generate_morton_code(self.primitives)
        self.sort()
        # Node 0 (the root) is never assigned as a child, so mark it explicitly.
        self.bvh[0].parent = -1
        self.build_bvh()
        self.set_primitive_idx(self.primitives)
        # The visit-counting protocol in assign_bound needs counters starting at zero.
        self.atomic_counter.fill(0)
        self.assign_bound()

        self.codes = None

``````
