请问计图如何实现分组聚合求平均?

例如我有一个nx3的变量,第三列是分组列,我想分组聚合,得到第三列相同的数据行的平均值,请问这在计图中如何实现?

您好,能给一个例子吗?

比如我有一个张量

jt.float32([[1,2,3],[4,5,6],[7,2,3],[2,3,6],[4,6,6]])

我想得到

jt.float32([[4,2,3],[5,7,6]])

[4,2,3] 是 [(1+7)/2, (2+2)/2, 3]
[5,7,6] 是 [(4+2+4)/3, (5+3+6)/3, 6]
类似这样。

当然也可以数据列和索引列分开,类似这样:

jt.float32([[1,2],[4,5],[7,2],[2,3],[4,6]])
jt.int8([3,6,3,6,6])

然后得到

jt.float32([[4,2],[5,7]])

jt.int8([3,6])

您可以参考以下代码操作,其核心思想是利用 reindex_reduce 操作,借助索引实现数据的加法

>>> x = jt.randint(5, shape=(10,)) # 索引
>>> y = y = jt.rand((10, 2)) # 数据
>>> x
jt.Var([2 2 4 1 3 3 0 3 1 3], dtype=int32)
>>> x_max = x.max().item()
>>> x_max
4
>>> x_count = jt.ones(10).reindex_reduce("sum", [x_max], ["@e0(i0)"], extras=[x])
>>> x_count # 每个索引的数量
jt.Var([1. 2. 2. 4.], dtype=float32)
>>> y_sum = y.reindex_reduce("sum", [x_max, 2], ["@e0(i0)", "i1"], extras=[x]) # 使用 reindex_reduce 得到相同索引的和
>>> y_sum
jt.Var([[0.29100886 0.28906357]
 [1.2039766  1.1563721 ]
 [1.5233977  0.9310594 ]
 [1.4984614  2.2315395 ]], dtype=float32)
>>> y_mean = y_sum / x_count[:, None] # 得到数据的平均
>>> y_mean
jt.Var([[0.29100886 0.28906357]
 [0.6019883  0.57818604]
 [0.76169884 0.4655297 ]
 [0.37461534 0.5578849 ]], dtype=float32)
>>> unique_index = jt.where(x_count) # 去重后的索引
>>> unique_index
[jt.Var([0 1 2 3], dtype=int32)]

好的,我明白了,谢谢。