Homework1 共轭梯度迭代的实现,有点问题,求大神解答

代码改编自老师的雅克比迭代

import taichi as ti
import random

ti.init()

n = 20

A = ti.var(dt=ti.f32, shape=(n, n))
x = ti.var(dt=ti.f32, shape=n)
new_x = ti.var(dt=ti.f32, shape=n)
b = ti.var(dt=ti.f32, shape=n)

r = ti.var(dt=ti.f32, shape=n)
new_r = ti.var(dt=ti.f32, shape=n)
p = ti.var(dt=ti.f32, shape=n)

Ap = ti.var(dt=ti.f32, shape=n)


def init():
    r.from_numpy(b.to_numpy() - A.to_numpy().dot(x.to_numpy()))
    new_r.from_numpy(r.to_numpy())
    p.from_numpy(r.to_numpy())


@ti.kernel
def iterate():
    # Ap
    for i in range(n):
        Ap[i] = 0

    for i in range(n):
        for j in range(n):
            Ap[i] += A[i, j] * p[j]

    a_1 = 0  # 分子
    for i in range(n):
        a_1 += r[i] * r[i]

    a_2 = 0  # 分母
    for i in range(n):
        a_2 += p[i] * Ap[i]

    a = a_1 / a_2

    for i in range(n):
        new_x[i] = x[i] + a * p[i]

    # TODO r is small

    for i in range(n):
        new_r[i] = r[i] - a * Ap[i]

    b_1 = 0  # 分子
    for i in range(n):
        b_1 += new_r[i] * new_r[i]

    b_2 = 0  # 分母
    for i in range(n):
        b_2 += r[i] * r[i]

    b = b_1 / b_2

    for i in range(n):
        p[i] = new_r[i] + b * p[i]

    for i in range(n):
        x[i] = new_x[i]

    for i in range(n):
        r[i] = new_r[i]


@ti.kernel
def residual() -> ti.f32:
    res = 0.0

    for i in range(n):
        r = b[i] * 1.0
        for j in range(n):
            r -= A[i, j] * x[j]
        res += r * r

    return res


for i in range(n):
    for j in range(n):
        A[i, j] = random.random() - 0.5

    A[i, i] += n * 0.1

    b[i] = random.random() * 100


for i in range(100):
    iterate()
    print(f'iter {i}, residual={residual():0.10f}')

for i in range(n):
    lhs = 0.0
    for j in range(n):
        lhs += A[i, j] * x[j]
    assert abs(lhs - b[i]) < 1e-4

太极里面的向量乘法和矩阵向量的乘法使了好几次都不行,报看不懂的bug,还望大神们解答

CG法是有限制条件的,要求A为对称正定阵。

感谢,那对于不是对称正定的矩阵,有什么转化的方法可解吗,我查一下

如果A不满足条件,那么可以方程两边先左乘A的转置,变成A’Ax=A’b
当然还有很多其他的预条件算法,可以去了解一下

感谢

太极里面的向量乘法和矩阵向量的乘法使了好几次都不行,报看不懂的bug,还望大神们解答

What bug? Could you paste it out?

乘法这块问题已经解决了,感谢大佬们