The following code includes several ti.var. The code works fine when `batch_size, input_feature, hidden_feature, out_feature’ are all small. However, if I use some numbers like 64, which is normal in neural networks, the program keeps running and does not give a result.
import torch
from torch.autograd import gradcheck
import torch.nn.functional as F
from torch.autograd import gradcheck
import math
import torch.nn as nn
import taichi as ti
ti.get_runtime().set_default_fp(ti.f32)
real = ti.f32
class LinearFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input_data, weight_0, bias_0, weight_1, bias_1,
ti_data, ti_weight_0, ti_bias_0, ti_weight_1, ti_bias_1, ti_output_0, ti_output_1, ti_kernel):
ctx.ti_output_1 = ti_output_1
ctx.ti_kernel = ti_kernel
ctx.ti_data = ti_data
ctx.ti_weight_0 = ti_weight_0
ctx.ti_bias_0 = ti_bias_0
ctx.ti_weight_1 = ti_weight_1
ctx.ti_bias_1 = ti_bias_1
ti_data.from_torch(input_data)
ti_weight_0.from_torch(weight_0)
ti_bias_0.from_torch(bias_0)
ti_weight_1.from_torch(weight_1)
ti_bias_1.from_torch(bias_1)
ti_kernel()
return ti_output_1.to_torch()
@staticmethod
def backward(ctx, grad_output_1):
ti.clear_all_gradients()
grad_input_data = grad_weight_0 = grad_bias_0 = grad_weight_1 = grad_bias_1 = None
ctx.ti_output_1.grad.from_torch(grad_output_1)
ctx.ti_kernel(__gradient=True)
if ctx.needs_input_grad[0]:
grad_input_data = ctx.ti_data.grad.to_torch()
if ctx.needs_input_grad[1]:
grad_weight_0 = ctx.ti_weight_0.grad.to_torch()
if ctx.needs_input_grad[2]:
grad_bias_0 = ctx.ti_bias_0.grad.to_torch()
if ctx.needs_input_grad[3]:
grad_weight_1 = ctx.ti_weight_1.grad.to_torch()
if ctx.needs_input_grad[4]:
grad_bias_1 = ctx.ti_bias_1.grad.to_torch()
return grad_input_data, grad_weight_0, grad_bias_0, grad_weight_1, grad_bias_1,\
None, None, None, None, None, None, None, None
class Linear(nn.Module):
def __init__(self, input_feature, hidden_feature, output_feature):
super(Linear, self).__init__()
# taichi parameter holders
self.ti_data = ti.var(dt=real, shape=(batch_size, input_feature), needs_grad=True)
self.ti_weight_0 = ti.var(dt=real, shape=(input_feature, hidden_feature), needs_grad=True)
self.ti_bias_0 = ti.var(dt=real, shape=hidden_feature, needs_grad=True)
self.ti_output_0 = ti.var(dt=real, shape=(batch_size, hidden_feature), needs_grad=True)
self.ti_weight_1 = ti.var(dt=real, shape=(hidden_feature, out_feature), needs_grad=True)
self.ti_bias_1 = ti.var(dt=real, shape=out_feature, needs_grad=True)
self.ti_output_1 = ti.var(dt=real, shape=(batch_size, out_feature), needs_grad=True)
# torch parameters
self.weight_0 = nn.Parameter(torch.Tensor(input_feature, hidden_feature))
self.bias_0 = nn.Parameter(torch.Tensor(hidden_feature))
self.weight_1 = nn.Parameter(torch.Tensor(hidden_feature, out_feature))
self.bias_1 = nn.Parameter(torch.Tensor(out_feature))
self.weight_0.data.normal_(0, math.sqrt(2. / hidden_feature / input_feature))
self.weight_1.data.normal_(0, math.sqrt(2. / hidden_feature / output_feature))
@ti.classkernel
def linear_kernel(self):
for i in range(batch_size):
for j in ti.static(range(hidden_feature)):
dummy = 0.0
for k in ti.static(range(input_feature)):
dummy += self.ti_data[i, k] * self.ti_weight_0[k, j]
dummy += self.ti_bias_0[j]
self.ti_output_0[i, j] = ti.max(dummy, 0)
for j in ti.static(range(out_feature)):
dummy = 0.0
for k in ti.static(range(hidden_feature)):
dummy += self.ti_output_0[i, k] * self.ti_weight_1[k, j]
dummy += self.ti_bias_1[j]
self.ti_output_1[i, j] = dummy
def forward(self, input_data):
return LinearFunction.apply(input_data, self.weight_0, self.bias_0, self.weight_1, self.bias_1,
self.ti_data, self.ti_weight_0, self.ti_bias_0, self.ti_weight_1, self.ti_bias_1,
self.ti_output_0, self.ti_output_1, self.linear_kernel)
if __name__ == '__main__':
batch_size = 32
input_feature = 64
hidden_feature = 128
out_feature = 64
data = torch.rand(batch_size, input_feature, dtype=torch.float32, requires_grad=True)
linear = Linear(input_feature, hidden_feature, out_feature)
test = gradcheck(linear, data, eps=1e-3, atol=1e-4)
print(test)