参考文档为:网格弹簧质点系统模拟(Spring-Mass System by Euler Integration) - 算法小丑 - 博客园
直接按照上面的公式实现的,但是直接就爆炸了
import taichi as ti
ti.init(arch=ti.cpu)
# ti.init(arch=ti.cpu, debug=True)
max_num_particles = 500
dt = 1e-3
num_particles = ti.var(ti.i32, shape=()) # 质点数目
spring_stiffness = ti.var(ti.f32, shape=()) # 弹簧硬度
paused = ti.var(ti.i32, shape=()) # 是否暂停
damping = ti.var(ti.f32, shape=()) # 阻尼
particle_mass = 1 # 质点质量
bottom_y = 0.05 # 最底部y
x = ti.Vector(2, dt=ti.f32, shape=max_num_particles) # 质点位置
v = ti.Vector(2, dt=ti.f32, shape=max_num_particles) # 质点速度
A = ti.Matrix(2, 2, dt=ti.f32, shape=(max_num_particles, max_num_particles)) # 不知道
b = ti.Vector(2, dt=ti.f32, shape=max_num_particles) # 不知道
# rest_length[i, j] = 0 means i and j are not connected
rest_length = ti.var(ti.f32, shape=(max_num_particles, max_num_particles)) # 两个质点之间的弹簧原长
connection_radius = 0.15 # 两个质点连接的最大距离
gravity = [0, -9.8] # 重力
# jacobi iteration
max_iterate = 10
A = ti.var(dt=ti.f32, shape=(2 * max_num_particles, 2 * max_num_particles))
deltaV = ti.var(dt=ti.f32, shape=2 * max_num_particles)
new_deltaV = ti.var(dt=ti.f32, shape=2 * max_num_particles)
b = ti.var(dt=ti.f32, shape=2 * max_num_particles)
dfx = ti.var(dt=ti.f32, shape=(2 * max_num_particles, 2 * max_num_particles))
dfv = ti.var(dt=ti.f32, shape=(2 * max_num_particles, 2 * max_num_particles))
# 鼠标单击添加一个新的质点
@ti.kernel
def new_particle(pos_x: ti.f32, pos_y: ti.f32): # Taichi doesn't support using Matrices as kernel arguments yet
new_particle_id = num_particles[None] # 必须用None访问标量
x[new_particle_id] = [pos_x, pos_y]
v[new_particle_id] = [0, 0]
num_particles[None] += 1
# Connect with existing particles
for i in range(new_particle_id):
dist = (x[new_particle_id] - x[i]).norm()
if dist < connection_radius:
rest_length[i, new_particle_id] = dist
rest_length[new_particle_id, i] = dist
@ti.func
def init_solver():
for i in range(2 * max_num_particles):
A[i, i] = particle_mass
b[i] = 0
for j in range(2 * max_num_particles):
dfx[i, j] = 0
dfv[i, j] = 0
@ti.func
def iterate():
n = num_particles[None] # 只迭代左上角的小矩阵
for i in range(2 * n):
r = b[i]
for j in range(2 * n):
if i != j:
r -= A[i, j] * deltaV[j]
new_deltaV[i] = r / A[i, i]
for i in range(2 * n):
deltaV[i] = new_deltaV[i]
@ti.kernel
def substep():
n = num_particles[None]
# init solver
init_solver()
# compute dfx and dfv
for i in range(n):
total_force = ti.Vector(gravity) * particle_mass
for j in range(n):
if rest_length[i, j] != 0:
x_ij = x[i] - x[j]
v_ij = v[i] - v[j]
total_force += -damping[None] * x_ij.normalized() * v_ij * x_ij.normalized() # damping
total_force += -spring_stiffness[None] * (x_ij.norm() - rest_length[i, j]) * x_ij.normalized() # spring
if j >= i: # 只需要计算上三角
x_ji = x[j] - x[i]
v_ji = v[j] - v[i]
dfx[2 * i, 2 * i] = spring_stiffness[None] * ((x_ji[0] * x_ji[0] - 1) / x_ji.norm() *
(x_ji.norm() - rest_length[i, j]) -
(x_ji[0] * x_ji[0]))
dfx[2 * i + 1, 2 * i] = spring_stiffness[None] * ((x_ji[0] * x_ji[1] - 0) / x_ji.norm() *
(x_ji.norm() - rest_length[i, j]) -
(x_ji[0] * x_ji[1]))
dfx[2 * i, 2 * i + 1] = spring_stiffness[None] * ((x_ji[1] * x_ji[0] - 0) / x_ji.norm() *
(x_ji.norm() - rest_length[i, j]) -
(x_ji[1] * x_ji[0]))
dfx[2 * i + 1, 2 * i + 1] = spring_stiffness[None] * ((x_ji[1] * x_ji[1] - 1) / x_ji.norm() *
(x_ji.norm() - rest_length[i, j]) -
(x_ji[1] * x_ji[1]))
dfx[2 * j, 2 * j] = dfx[2 * i, 2 * i]
dfx[2 * j + 1, 2 * j] = dfx[2 * i + 1, 2 * i]
dfx[2 * j, 2 * j + 1] = dfx[2 * i, 2 * i + 1]
dfx[2 * j + 1, 2 * j + 1] = dfx[2 * i + 1, 2 * i + 1]
dfx[2 * i, 2 * j] = -dfx[2 * i, 2 * i]
dfx[2 * i + 1, 2 * j] = -dfx[2 * i + 1, 2 * i]
dfx[2 * i, 2 * j + 1] = -dfx[2 * i, 2 * i + 1]
dfx[2 * i + 1, 2 * j + 1] = -dfx[2 * i + 1, 2 * i + 1]
dfx[2 * j, 2 * i] = -dfx[2 * i, 2 * i]
dfx[2 * j + 1, 2 * i] = -dfx[2 * i + 1, 2 * i]
dfx[2 * j, 2 * i + 1] = -dfx[2 * i, 2 * i + 1]
dfx[2 * j + 1, 2 * i + 1] = -dfx[2 * i + 1, 2 * i + 1]
dfx[2 * i, 2 * i] += -damping[None] * ((x_ji[0] * x_ji[0] - 1) / x_ji.norm() *
((x_ji[0] * v_ji[0] + x_ji[1] * v_ji[1]) * 1 +
x_ji[0] * v_ji[0]))
dfx[2 * i + 1, 2 * i] += -damping[None] * ((x_ji[0] * x_ji[1] - 0) / x_ji.norm() *
((x_ji[0] * v_ji[0] + x_ji[1] * v_ji[1]) * 0 +
x_ji[0] * v_ji[1]))
dfx[2 * i, 2 * i + 1] += -damping[None] * ((x_ji[1] * x_ji[0] - 0) / x_ji.norm() *
((x_ji[0] * v_ji[0] + x_ji[1] * v_ji[1]) * 0 +
x_ji[1] * v_ji[0]))
dfx[2 * i + 1, 2 * i + 1] += -damping[None] * ((x_ji[1] * x_ji[1] - 1) / x_ji.norm() *
((x_ji[0] * v_ji[0] + x_ji[1] * v_ji[1]) * 1 +
x_ji[1] * v_ji[1]))
dfv[2 * i, 2 * i] = damping[None] * x_ji[0] * x_ji[0]
dfv[2 * i + 1, 2 * i] = damping[None] * x_ji[0] * x_ji[1]
dfv[2 * i, 2 * i + 1] = damping[None] * x_ji[1] * x_ji[0]
dfv[2 * i + 1, 2 * i + 1] = damping[None] * x_ji[1] * x_ji[1]
b[2 * i] = total_force[0]
b[2 * i + 1] = total_force[1]
# compute A and b
for i in range(2 * n):
for j in range(2 * n):
A[i, j] -= (dt * dfv[i, j] + dt * dt * dfx[i, j])
if j % 2 == 0:
b[i] += dt * dfx[i, j] * v[j][0]
else:
b[i] += dt * dfx[i, j] * v[j][1]
# solve
for i in range(max_iterate):
iterate()
# Compute new velocity
for i in range(n):
v[i] += deltaV[i] * dt
# Compute new position
for i in range(n):
x[i] += v[i] * dt
# Collide with ground
for i in range(n):
if x[i].y < bottom_y:
x[i].y = bottom_y
v[i].y = 0
spring_stiffness[None] = 10000
# spring_stiffness[None] = 1000000
damping[None] = 20
new_particle(0.3, 0.3)
new_particle(0.3, 0.4)
new_particle(0.4, 0.4)
gui = ti.GUI('Mass Spring System', res=(512, 512), background_color=0xdddddd)
while True:
for e in gui.get_events(ti.GUI.PRESS):
if e.key in [ti.GUI.ESCAPE, ti.GUI.EXIT]:
exit()
elif e.key == gui.SPACE:
paused[None] = not paused[None]
elif e.key == ti.GUI.LMB:
new_particle(e.pos[0], e.pos[1])
elif e.key == 'c':
num_particles[None] = 0
rest_length.fill(0)
elif e.key == 's':
if gui.is_pressed('Shift'):
spring_stiffness[None] /= 1.1
else:
spring_stiffness[None] *= 1.1
elif e.key == 'd':
if gui.is_pressed('Shift'):
damping[None] /= 1.1
else:
damping[None] *= 1.1
if not paused[None]:
for step in range(10):
substep()
X = x.to_numpy()
gui.circles(X[:num_particles[None]], color=0xffaa77, radius=5)
gui.line(begin=(0.0, bottom_y), end=(1.0, bottom_y), color=0x0, radius=1)
for i in range(num_particles[None]):
for j in range(i + 1, num_particles[None]):
if rest_length[i, j] != 0:
gui.line(begin=X[i], end=X[j], radius=2, color=0x445566)
gui.text(content=f'C: clear all; Space: pause', pos=(0, 0.95), color=0x0)
gui.text(content=f'S: Spring stiffness {spring_stiffness[None]:.1f}', pos=(0, 0.9), color=0x0)
gui.text(content=f'D: damping {damping[None]:.2f}', pos=(0, 0.85), color=0x0)
gui.text(content=f'Number of particles {num_particles[None]:.0f}', pos=(0, 0.80), color=0x0)
gui.show()