以下代码计算1000亿以内所有素数之和,我用含义基本一样的python代码计算得到结果为201467077743744681014,taichi代码得到结果为-1447107067060386762,请问可能是什么原因导致的?
(注:用pypy跑python版本代码应该在几秒内就可以得到结果)
以下为taichi代码
import taichi as ti
ti.init(arch=ti.cpu,default_ip=ti.i64)
N = 1000010
prime=ti.field(dtype=int, shape=(N,))
id1=ti.field(dtype=int, shape=(N,))
id2=ti.field(dtype=int, shape=(N,))
flag=ti.field(dtype=bool, shape=(N,))
ncnt=ti.field(dtype=ti.i64, shape=())
m=ti.field(dtype=int, shape=())
g=ti.field(dtype=ti.i64, shape=(N,))
tsum=ti.field(dtype=int, shape=(N,))
a=ti.field(dtype=int, shape=(N,))
@ti.func
def IDcalc(x:ti.i64,n:ti.i64,T:ti.i64)->ti.i64:
r=0
if x<=T:
r=id1[x]
else:
r=id2[n//x]
return r
@ti.kernel
def test(n:ti.i64)->ti.i64:
T=int(ti.sqrt(n))+1
for _ in range(1):
for i in range(2,T):
if not flag[i]:
ncnt[None]+=1
prime[ncnt[None]]=i
tsum[ncnt[None]]=tsum[ncnt[None]-1]+i
j=1
while j<=ncnt[None] and i*prime[j]<=T:
flag[i*prime[j]]=True
if i%prime[j]==0:
break
j+=1
l=1
for _ in range(1):
while l<=n:
a[m[None]+1]=n//l
if a[m[None]+1]<=T:
id1[a[m[None]+1]]=m[None]+1
else:
id2[n//a[m[None]+1]]=m[None]+1
g[m[None]+1]=a[m[None] + 1] * (a[m[None] + 1] + 1) // 2 - 1
m[None]+=1
l=n//(n//l)+1
for _ in range(1):
for i in range(1,ncnt[None]+1):
for j in range(1,m[None]+1):
if prime[i]*prime[i]>a[j]:
break
g[j]=g[j]-prime[i]*(g[IDcalc(a[j] // prime[i],n,T)]-tsum[i-1])
return g[IDcalc(n,n,T)]
limit=int(1e11)
import time
start_time = time.perf_counter()
print(f"Total prime numbers found: {test(limit)}", )
end_time = time.perf_counter()
elapsed_time = end_time - start_time
print("Total time: ", elapsed_time, " seconds")
以下为python代码
import time
def solve(n):
N = 1000010
prime = [0] * N
id1 = [0] * N
id2 = [0] * N
flag = [0] * N
ncnt = 0
m = 0
g = [0] * N
sum = [0] * N
a = [0] * N
T = int(n ** 0.5) + 1
def ID(x):
return id1[x] if x <= T else id2[n // x]
def calc(x):
return x * (x + 1) // 2 - 1
def f(x):
return x
def init():
nonlocal T, ncnt, m
T = int(n ** 0.5) + 1
for i in range(2, T):
if not flag[i]:
ncnt += 1
prime[ncnt] = i
sum[ncnt] = sum[ncnt - 1] + i
j = 1
while j <= ncnt and i * prime[j] <= T:
flag[i * prime[j]] = 1
if i % prime[j] == 0:
break
j += 1
l = 1
while l <= n:
a[m + 1] = n // l
if a[m + 1] <= T:
id1[a[m + 1]] = m + 1
else:
id2[n // a[m + 1]] = m + 1
g[m + 1] = calc(a[m + 1])
m += 1
l = n // (n // l) + 1
for i in range(1, ncnt + 1):
for j in range(1, m + 1):
if prime[i] * prime[i] > a[j]:
break
g[j] = g[j] - prime[i] * (g[ID(a[j] // prime[i])] - sum[i - 1])
def solve(x):
nonlocal n
if x <= 1:
return x
n = x
init()
return g[ID(n)]
return solve(n)
print("begain...")
t1 = time.perf_counter()
n = int(1e11)
print(solve(n))
t2 = time.perf_counter()
print("run times:", t2 - t1, "s")
print("end...")