自己用 Jittor 实现了 variance,计时显示比用 PyTorch 实现的方法快了约 500 倍——不过注意 Jittor 是惰性求值,循环内若不调用 sync() 强制执行,计时结果会严重偏小。
import torch, jittor, numpy
import time
def my_var_torch(input, dim=None, keepdim=False, unbiased=True):
    """Compute the variance of `input` over `dim` using basic torch ops.

    Args:
        input: input tensor.
        dim: int or list of ints giving the reduction dimensions;
            None (default) reduces over all dimensions.
        keepdim: if True, retain the reduced dimensions with size 1.
        unbiased: if True, apply Bessel's correction (divide by n - 1).

    Returns:
        Tensor holding the variance over the requested dimensions.
    """
    shape = input.shape
    new_shape = list(shape)
    if dim is None:
        dim = list(range(len(shape)))
    elif isinstance(dim, int):
        # Accept a single int like torch.var does (original required a list).
        dim = [dim]
    mean = torch.mean(input, dim=dim, keepdim=True)
    # n = number of elements reduced per output entry; also mark the
    # reduced dims with size 1 for the keepdim reshape below.
    n = 1
    for d in dim:
        n *= shape[d]
        new_shape[d] = 1
    # Broadcasting handles (input - mean) directly; no broadcast_to needed.
    sqr = torch.sum((input - mean) ** 2, dim=dim, keepdim=False)
    if unbiased:
        n -= 1  # Bessel's correction
    sqr = sqr / n
    if keepdim:
        # reshape (unlike view) also works on non-contiguous results.
        sqr = sqr.reshape(new_shape)
    return sqr
def my_var_jittor(input, dims=None, keepdims=False, unbiased=True):
    """Compute the variance of `input` over `dims` using basic jittor ops.

    Args:
        input: input jittor Var.
        dims: list of ints giving the reduction dimensions;
            None (default) reduces over all dimensions.
        keepdims: if True, retain the reduced dimensions with size 1.
        unbiased: if True, apply Bessel's correction (divide by n - 1).

    Returns:
        Var holding the variance over the requested dimensions.
    """
    shape = input.shape
    new_shape = list(shape)
    if dims is None:
        dims = list(range(len(shape)))
    mean = jittor.mean(input, dims, keepdims=True)
    mean = jittor.broadcast(mean, shape)
    # n = number of elements reduced per output entry; also mark the
    # reduced dims with size 1 for the keepdims reshape below.
    n = 1
    for d in dims:
        n *= shape[d]
        new_shape[d] = 1
    sqr = jittor.sum((input - mean) ** 2, dims=dims, keepdims=False)
    if unbiased:
        n -= 1  # Bessel's correction
    sqr /= n
    # BUG FIX: originally tested the module-level global `keepdim` instead
    # of the `keepdims` parameter, silently ignoring the argument.
    if keepdims:
        sqr = sqr.view(new_shape)
    return sqr
# Benchmark: hand-rolled variance in torch vs jittor vs torch.var itself.
x = numpy.random.rand(7, 9, 11, 13, 15, 17, 19)
x1 = torch.from_numpy(x)
x2 = jittor.Var(x)
keepdim = False

t0 = time.process_time()
for _ in range(100):
    a = my_var_torch(x1, dim=[1, 3, 4], keepdim=keepdim)
print("ans from my_var_torch =", torch.sum(a))

t1 = time.process_time()
for _ in range(100):
    a = my_var_jittor(x2, dims=[1, 3, 4], keepdims=keepdim)
    # BUG FIX: Jittor evaluates lazily — without sync() the loop only
    # builds computation graphs and the timing below measures almost
    # nothing (hence the implausible ~500x "speedup").
    a.sync()
print("ans from my_var_jittor1 =", jittor.sum(a))

t2 = time.process_time()
for _ in range(100):
    a = torch.var(x1, dim=[1, 3, 4], keepdim=keepdim)
print("ans from torch_var =", torch.sum(a))

t3 = time.process_time()
print("time myvar_torch: {}".format(t1 - t0))
print("time myvar_jittor: {}".format(t2 - t1))
print("time pytorch var: {}".format(t3 - t2))
输出如下:
ans from my_var_torch = tensor(2072.7649, dtype=torch.float64)
ans from my_var_jittor1 = jt.Var([2072.7637], dtype=float32)
ans from torch_var = tensor(2072.7649, dtype=torch.float64)
time myvar_torch: 94.25290709000001
time myvar_jittor: 0.08337130399999637
time pytorch var: 45.863336922