请问 differential programing 不支持 data_oriented 吗?

我把文档中的例子改为 data_oriented 版本后无法运行

import taichi as ti
ti.init(arch=ti.cuda)

@ti.data_oriented
class Simulation():
    def __init__(self) -> None:
        self.N = 8
        self.dt = 1e-5
        self.x = ti.Vector.field(2, dtype=ti.f32, shape=self.N, needs_grad=True)  # particle positions
        self.v = ti.Vector.field(2, dtype=ti.f32, shape=self.N)  # particle velocities
        self.U = ti.field(dtype=ti.f32, shape=(), needs_grad=True)  # potential energy
        self.init()

    @ti.kernel
    def compute_U(self):
        N = self.N
        for i, j in ti.ndrange(N, N):
            r = self.x[i] - self.x[j]
            # r.norm(1e-3) is equivalent to ti.sqrt(r.norm()**2 + 1e-3)
            # This is to prevent 1/0 error which can cause wrong derivative
            self.U[None] += -1 / r.norm(1e-3)  # U += -1 / |r|


    @ti.kernel
    def advance(self):
        dt = self.dt
        for i in self.x:
            self.v[i] += dt * -self.x.grad[i]  # dv/dt = -dU/dx
        for i in self.x:
            self.x[i] += dt * self.v[i]  # dx/dt = v


    def substep(self):
        with ti.Tape(loss=self.U):
            # Kernel invocations in this scope will later contribute to partial derivatives of
            # U with respect to input variables such as x.
            self.compute_U(
            )  # The tape will automatically compute dU/dx and save the results in x.grad
        self.advance()


    @ti.kernel
    def init(self):
        for i in self.x:
            self.x[i] = [ti.random(), ti.random()]


simu = Simulation()
gui = ti.GUI('Autodiff gravity')
while gui.running:
    for i in range(50):
        simu.substep()
    gui.circles(simu.x.to_numpy(), radius=3)
    gui.show()

报错信息:

[Taichi] version 0.8.4, llvm 10.0.0, commit 895881b5, win, python 3.8.11
[I 02/24/22 10:59:46.078 6104] [shell.py:_shell_pop_print@34] Graphical python shell detected, using wrapped sys.stdout
[Taichi] Starting on arch=cuda
[E 02/24/22 10:59:54.428 6104] [reverse_segments.cpp:taichi::lang::irpass::reverse_segments@72] Invalid program input for autodiff. Please check the documentation for the "Kernel Simplicity Rule":
https://docs.taichi.graphics/lang/articles/advanced/differentiable_programming#kernel-simplicity-rule


***********************************
* Taichi Compiler Stack Traceback *
***********************************
0x7ff8d328076a: taichi::print_traceback in taichi_core.pyd
0x7ff8d314f319: PyInit_taichi_core in taichi_core.pyd
0x7ff8d32d94f1: taichi::print_traceback in taichi_core.pyd
0x7ff8d3295b3f: taichi::print_traceback in taichi_core.pyd
0x7ff8d32958c7: taichi::print_traceback in taichi_core.pyd
0x7ff8d3205e7a: PyInit_taichi_core in taichi_core.pyd
0x7ff8d31e4a91: PyInit_taichi_core in taichi_core.pyd
0x7ff8d321acb8: PyInit_taichi_core in taichi_core.pyd
0x7ff8d3205055: PyInit_taichi_core in taichi_core.pyd
0x7ff8d3204a48: PyInit_taichi_core in taichi_core.pyd
0x7ff8d3097aae: PyInit_taichi_core in taichi_core.pyd
0x7ff8d300ec26: PyInit_taichi_core in taichi_core.pyd
0x7ff8d2fd6d9b: pybind11::error_already_set::discard_as_unraisable in taichi_core.pyd
0x7ff8e3026fe0: PyMethodDef_RawFastCallKeywords in python38.dll
0x7ff8e3025fa6: PyObject_MakeTpCall in python38.dll
0x7ff8e30293ea: PyMethod_Self in python38.dll
0x7ff8e30260ee: PyVectorcall_Call in python38.dll
0x7ff8e3093d75: PyType_Ready in python38.dll
0x7ff8e3025fa6: PyObject_MakeTpCall in python38.dll
0x7ff8e31035b8: PyEval_GetFuncDesc in python38.dll
0x7ff8e30fdf3c: PyEval_EvalFrameDefault in python38.dll
0x7ff8e31021d9: PyEval_EvalCodeWithName in python38.dll
0x7ff8e302676f: PyFunction_Vectorcall in python38.dll
0x7ff8e30260ee: PyVectorcall_Call in python38.dll
0x7ff8e3103851: PyEval_GetFuncDesc in python38.dll
0x7ff8e3100271: PyEval_EvalFrameDefault in python38.dll
0x7ff8e31021d9: PyEval_EvalCodeWithName in python38.dll
0x7ff8e302676f: PyFunction_Vectorcall in python38.dll
0x7ff8e30260ee: PyVectorcall_Call in python38.dll
0x7ff8e3103851: PyEval_GetFuncDesc in python38.dll
0x7ff8e3100271: PyEval_EvalFrameDefault in python38.dll
0x7ff8e31021d9: PyEval_EvalCodeWithName in python38.dll
0x7ff8e302676f: PyFunction_Vectorcall in python38.dll
0x7ff8e3025da3: PyObject_FastCallDict in python38.dll
0x7ff8e30273d0: PyObject_Call_Prepend in python38.dll
0x7ff8e3093d2a: PyType_Ready in python38.dll
0x7ff8e302623a: PyObject_Call in python38.dll
0x7ff8e3103851: PyEval_GetFuncDesc in python38.dll
0x7ff8e3100271: PyEval_EvalFrameDefault in python38.dll
0x7ff8e3026315: PyObject_Call in python38.dll
0x7ff8e3026698: PyFunction_Vectorcall in python38.dll
0x7ff8e3103593: PyEval_GetFuncDesc in python38.dll
0x7ff8e310014a: PyEval_EvalFrameDefault in python38.dll
0x7ff8e3026315: PyObject_Call in python38.dll
0x7ff8e3026698: PyFunction_Vectorcall in python38.dll
0x7ff8e3029129: PyCell_Set in python38.dll
0x7ff8e30293ea: PyMethod_Self in python38.dll
0x7ff8e30fb95c: PyOS_URandomNonblock in python38.dll
0x7ff8e30fdc52: PyEval_EvalFrameDefault in python38.dll
0x7ff8e3026315: PyObject_Call in python38.dll
0x7ff8e3026698: PyFunction_Vectorcall in python38.dll
0x7ff8e30292e9: PyMethod_Self in python38.dll
0x7ff8e3103593: PyEval_GetFuncDesc in python38.dll
0x7ff8e3100131: PyEval_EvalFrameDefault in python38.dll
0x7ff8e31021d9: PyEval_EvalCodeWithName in python38.dll
0x7ff8e3170b9f: PyRun_FileExFlags in python38.dll
0x7ff8e3170c91: PyRun_FileExFlags in python38.dll
0x7ff8e3170778: PyRun_StringFlags in python38.dll
0x7ff8e316e9f1: PyRun_InteractiveOneFlags in python38.dll
0x7ff8e316ebc0: PyRun_SimpleFileExFlags in python38.dll
0x7ff8e2f972dd: Py_hashtable_copy in python38.dll
0x7ff8e2f97f50: Py_hashtable_copy in python38.dll
0x7ff8e2f98e42: Py_RunMain in python38.dll
0x7ff8e2f98eb6: Py_Main in python38.dll
0x7ff6128114f8: OPENSSL_Applink in python.exe
0x7ff94dce7034: BaseThreadInitThunk in KERNEL32.DLL
0x7ff94eb22651: RtlUserThreadStart in ntdll.dll

Internal error occurred. Check out this page for possible solutions:
https://docs.taichi.graphics/lang/articles/misc/install
Traceback (most recent call last):
  File ".\tmp\test_autodiff.py", line 52, in <module>
    simu.substep()
  File ".\tmp\test_autodiff.py", line 37, in substep
    self.compute_U(
  File "C:\Users\Nangu\miniconda3\envs\misc\lib\site-packages\taichi\lang\tape.py", line 18, in __exit__
    self.grad()
  File "C:\Users\Nangu\miniconda3\envs\misc\lib\site-packages\taichi\lang\tape.py", line 27, in grad
    func.grad(*args)
  File "C:\Users\Nangu\miniconda3\envs\misc\lib\site-packages\taichi\lang\shell.py", line 39, in new_call
    ret = old_call(*args, **kwargs)
  File "C:\Users\Nangu\miniconda3\envs\misc\lib\site-packages\taichi\lang\kernel_impl.py", line 724, in __call__
    return self.compiled_functions[key](*args)
  File "C:\Users\Nangu\miniconda3\envs\misc\lib\site-packages\taichi\lang\kernel_impl.py", line 682, in func__
    t_kernel(launch_ctx)
RuntimeError: [reverse_segments.cpp:taichi::lang::irpass::reverse_segments@72] Invalid program input for autodiff. Please check the documentation for the "Kernel Simplicity Rule":
https://docs.taichi.graphics/lang/articles/advanced/differentiable_programming#kernel-simplicity-rule

试了一下,把 compute_U 从类中抽出来就可以了。

@ti.kernel
def compute_U(x: ti.template(), U: ti.template(), N: ti.int32):
    for i, j in ti.ndrange(N, N):
        r = x[i] - x[j]
        # r.norm(1e-3) is equivalent to ti.sqrt(r.norm()**2 + 1e-3)
        # This is to prevent 1/0 error which can cause wrong derivative
        U[None] += -1 / r.norm(1e-3)  # U += -1 / |r|

1 个赞