例如我有一个nx3的变量,第三列是分组列,我想分组聚合,得到第三列相同的数据行的平均值,请问这在计图中如何实现?
您好,能给一个例子吗?
比如我有一个张量
jt.float32([[1,2,3],[4,5,6],[7,2,3],[2,3,6],[4,6,6]])
我想得到
jt.float32([[4,2,3],[5,7,6]])
[4,2,3] 是 [(1+7)/2, (2+2)/2, 3]
[5,7,6] 是 [(4+2+4)/3, (5+3+6)/3, 6]
类似这样。
当然也可以数据列和索引列分开,类似这样:
jt.float32([[1,2],[4,5],[7,2],[2,3],[4,6]])
jt.int8([3,6,3,6,6])
然后得到
jt.float32([[4,2],[5,7]])
和
jt.int8([3,6])
您可以参考以下代码操作,其核心思想是利用 reindex_reduce
操作,借助索引实现数据的加法
>>> x = jt.randint(5, shape=(10,)) # 索引
>>> y = y = jt.rand((10, 2)) # 数据
>>> x
jt.Var([2 2 4 1 3 3 0 3 1 3], dtype=int32)
>>> x_max = x.max().item()
>>> x_max
4
>>> x_count = jt.ones(10).reindex_reduce("sum", [x_max], ["@e0(i0)"], extras=[x])
>>> x_count # 每个索引的数量
jt.Var([1. 2. 2. 4.], dtype=float32)
>>> y_sum = y.reindex_reduce("sum", [x_max, 2], ["@e0(i0)", "i1"], extras=[x]) # 使用 reindex_reduce 得到相同索引的和
>>> y_sum
jt.Var([[0.29100886 0.28906357]
[1.2039766 1.1563721 ]
[1.5233977 0.9310594 ]
[1.4984614 2.2315395 ]], dtype=float32)
>>> y_mean = y_sum / x_count[:, None] # 得到数据的平均
>>> y_mean
jt.Var([[0.29100886 0.28906357]
[0.6019883 0.57818604]
[0.76169884 0.4655297 ]
[0.37461534 0.5578849 ]], dtype=float32)
>>> unique_index = jt.where(x_count) # 去重后的索引
>>> unique_index
[jt.Var([0 1 2 3], dtype=int32)]
好的,我明白了,谢谢。