Suggestion: add jittor.variance to match torch.var

Hi maintainers, while porting a PyTorch project to Jittor I found that torch.var (variance) has no corresponding API in Jittor. I wrote a crude version of var as a stopgap; I hope the Jittor team can provide an official function in a future release.

import torch

def my_var(input, dim=None, keepdim=False, unbiased=True):
    # dim is expected to be a list of ints; None means reduce over all dims
    shape = input.shape
    new_shape = list(input.shape)
    if dim is None:
        dim = list(range(len(shape)))
    # with keepdim=True the mean has size 1 along every reduced dim
    mean = torch.mean(input=input, dim=dim, keepdim=True)
    n = 1
    for d in dim:
        n *= shape[d]
        # expand the mean back to the input shape along dim d
        # (torch.cat copies; broadcasting would be much cheaper)
        mean = torch.cat([mean for _ in range(shape[d])], dim=d)
        new_shape[d] = 1
    sqr = (input - mean) ** 2
    sqr = torch.sum(sqr, dim=dim, keepdim=False)
    if unbiased:
        n -= 1  # Bessel's correction: divide by n-1
    sqr /= n
    if keepdim:
        sqr = sqr.view(new_shape)
    return sqr
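
As a quick sanity check against torch.var (continuing from the snippet above; the shape and dims here are arbitrary, chosen just for illustration):

x = torch.rand(4, 5, 6)
# reduced dims keep size 1 when keepdim=True, matching torch.var
assert torch.allclose(my_var(x, dim=[0, 2], keepdim=True),
                      torch.var(x, dim=[0, 2], keepdim=True))
# dim=None reduces over all dimensions
assert torch.allclose(my_var(x), torch.var(x))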

Thanks for the feedback; we will add support for this operator as soon as possible.

I implemented variance with Jittor myself and it runs 500x faster than the PyTorch version, whoa

import torch, jittor, numpy
import time

def my_var_torch(input, dim=None, keepdim=False, unbiased=True):
    # dim is expected to be a list of ints; None means reduce over all dims
    shape = input.shape
    new_shape = list(input.shape)
    if dim is None:
        dim = list(range(len(shape)))
    mean = torch.mean(input=input, dim=dim, keepdim=True)
    # broadcast the mean back to the input shape (no copy, unlike torch.cat)
    mean = torch.broadcast_to(mean, shape)
    n = 1
    for d in dim:
        n *= shape[d]
        new_shape[d] = 1
    sqr = (input - mean) ** 2
    sqr = torch.sum(sqr, dim=dim, keepdim=False)
    if unbiased:
        n -= 1  # Bessel's correction
    sqr /= n
    if keepdim:
        sqr = sqr.view(new_shape)
    return sqr

def my_var_jittor(input, dims=None, keepdims=False, unbiased=True):
    shape = input.shape
    new_shape = list(input.shape)
    if dims is None:
        dims = list(range(len(shape)))
    mean = jittor.mean(input, dims, keepdims=True)
    mean = jittor.broadcast(mean, shape)
    n = 1
    for d in dims:
        n *= shape[d]
        new_shape[d] = 1
    sqr = (input - mean) ** 2
    sqr = jittor.sum(sqr, dims=dims, keepdims=False)
    if unbiased:
        n -= 1
    sqr /= n
    if keepdims:
        sqr = sqr.view(new_shape)
    return sqr

x = numpy.random.rand(7, 9, 11, 13, 15, 17, 19)
x1 = torch.from_numpy(x)
x2 = jittor.Var(x)
keepdim = False

t0 = time.process_time()
for _ in range(100):
    a = my_var_torch(x1, dim=[1, 3, 4], keepdim=keepdim)
print("ans from my_var_torch =", torch.sum(a))
t1 = time.process_time()
# note: Jittor evaluates lazily, so this loop may mostly queue graphs;
# the jittor.sum() in the print below is what forces actual computation
for _ in range(100):
    a = my_var_jittor(x2, dims=[1, 3, 4], keepdims=keepdim)
print("ans from my_var_jittor1 =", jittor.sum(a))
t2 = time.process_time()
for _ in range(100):
    a = torch.var(x1, dim=[1, 3, 4], keepdim=keepdim)
print("ans from torch_var =", torch.sum(a))
t3 = time.process_time()
print("time myvar_torch: {}".format(t1 - t0))
print("time myvar_jittor: {}".format(t2 - t1))
print("time pytorch var: {}".format(t3 - t2))

The output is as follows (note that the Jittor result is float32: jittor.Var(x) stored the float64 NumPy array as float32, hence the tiny difference):

ans from my_var_torch = tensor(2072.7649, dtype=torch.float64)
ans from my_var_jittor1 = jt.Var([2072.7637], dtype=float32)
ans from torch_var = tensor(2072.7649, dtype=torch.float64)
time myvar_torch: 94.25290709000001
time myvar_jittor: 0.08337130399999637
time pytorch var: 45.863336922

Thank you very much for your contribution! You can open a PR at https://github.com/Jittor/jittor and we will merge it as soon as possible.
You are also welcome to join the Jittor developer QQ group (761222083), and we can send you some Jittor souvenirs.

Thanks, I think I found the problem: with CUDA the time is about the same as PyTorch :rofl:
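
A side note for anyone reproducing the benchmark: Jittor executes lazily, so the timed loop above may mostly be queueing computation graphs rather than running them. Below is a minimal sketch of a timing loop that forces execution on every iteration; it assumes my_var_jittor as defined above and uses jittor.flags.use_cuda, Var.sync(), and jittor.sync_all(), none of which appear in the original thread:

import time
import jittor, numpy

jittor.flags.use_cuda = 1  # set to 0 to stay on CPU

x2 = jittor.array(numpy.random.rand(7, 9, 11, 13, 15, 17, 19))
t0 = time.time()  # wall-clock: GPU work does not show up in process_time()
for _ in range(100):
    a = my_var_jittor(x2, dims=[1, 3, 4])
    a.sync()  # force this iteration's graph to execute now
jittor.sync_all()  # drain any remaining queued work before reading the clock
print("time myvar_jittor (synced): {}".format(time.time() - t0))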

@fangtiancheng ,

Hello, thank you for your suggestion and code.

I have opened a pull request on the Jittor repository based on your code. Thank you for your contribution to Jittor.