看了api文档和《 元算子:通过元算子实现自己的卷积层》那篇教程,还是不是很懂这两个算子如何使用,有没有更详细或者有更多例子的对这两个算子的使用说明?
您好!
Jittor 中的 reindex
和 reindex_reduce
是非常灵活的算子,可以用于实现多种数据重排列、聚合的操作,也可以实现 pytorch 的 scatter
等接口。
reindex 原理
reindex
的目的是实现变量x
的重索引,即将输入变量 x
中的数据对应到输出变量 y
中的一个或者多个位置上。
接下来我将解释 reindex
的最重要参数, indexes
。假设输入变量是 n
维,输出变量是 m
维。
indexes
是 m
个 C++ 表达式的 list,这与输出维度 m
一致。C++表达式里使用了预定义的变量 i0, i1, ..., in
, 代表了输出变量中的一个索引位置,即 y[i0, i1, ..., in]
。
第 k
个表达式计算了输入的第 k 维,indexes
告诉了我们 y[i0, i1, ..., in]
的来源是 x[indexes[0], indexes[1], ..., indexes[m]]
。
reindex
的使用方式有两种:jt.reindex(x, ...)
或者 x.reindex(...)
,两者是等价的。
reindex 的简单举例
我将举一个非常简单的例子,帮助您理解。假设我有一个一维变量 x
,形状为 [n]
,我想得到一个新的变量 y
,是 x
中每个元素复制 2 次得到的,即 y = [x[0], x[0], x[1], x[1], ..., x[n], x[n]]
。那么您可以借助 reindex
完成这件事:
>>> x = jt.arange(3)
>>> x
jt.Var([0 1 2], dtype=int32)
>>> y = x.reindex([6], ['i0 / 2'])
>>> y
jt.Var([0 0 1 1 2 2], dtype=int32)
这里的 'i0/2'
即用来表达 y[i0] = x[i0 / 2]
。
reindex 的文档样例解释
在 Jittor 的文档给了一个 reindex 参与卷积计算的例子参与卷积计算的例子。输入的图像特征维度为 [N,H,W,C]
,卷积核的维度为 [Kh, Kw, _C, Kc
]。对于x
的每个像素位置 [h, w]
,reindex
将其扩展成了一个与卷积核相同大小的模板,包含了来自 h ~ h+Kh, w ~ w + Kw
的一小块区域的特征。
逐行代码解释如下:
xx = x.reindex([N,H-Kh+1,W-Kw+1,Kh,Kw,C,Kc], [
# Nid: 即 batch 的 id,输出变量 xx[i0, ...] 来源的 batch id 也是 i0
'i0',
# xx 的第 [i1(h), i2(w)] 个模板中的第 [i3(hh), i4(ww)]个像素特征,
# 来源于输入变量的第 [i1+i3, i2+i4] 个像素
'i1+i3',
'i2+i4',
# Cid: 输出变量的特征第 [i5] 维,来源于输入变量特征的第 [i5] 维
'i5',
])
做完 reindex
就可以直接与卷积核做矩阵逐元素乘法,再求和,就得到了卷积的结果。
reindex_reduce 的原理
reindex
是一个 1对多 的映射,即同一个输入位置可以复制给多个输出位置。对应的,reindex_reduce
则是一个 多对1 的映射,即将多个输入位置复制给一个输出位置;reindex_reduce
对多个输入使用指定的 reduce
操作,得到最终的输出。
与 reindex
相反,reindex_reduce
的 indexes
的表达式个数与输出变量维度一致。预定义变量 i0, i1, ..., in
是输入变量的下标索引。注意,文档里的 x
是输出,y
是输入。
目前可用的 reduce
操作包括 add
, max
, multiply
等二元运算。注意,reindex_reduce
暂时不直接支持求 mean
。但您可以通过 reindex_reduce
分别计算 reduce 的次数与求和,再两者相除求均值。
reindex_reduce 简单举例
假设我有一个一维变量 x
,形状为 [n]
,我想得到一个新的变量 y
,是 x
中每个 2 个相邻元素的和,即 y = [x[0] + x[1], x[1] + x[2], ..., x[n-1] + x[n]]
。那么您可以借助 reindex_reduce
完成:
>>> x = jt.arange(7)
>>> x
jt.Var([0 1 2 3 4 5 6], dtype=int32)
>>> y = x.reindex_reduce('add', [4], ['i0 / 2'])
>>> y
jt.Var([1 5 9 6], dtype=int32)
这里的 'i0/2'
即用来表达 y[i0 / 2] += x[i0]
。
reindex 和 reindex_reduce 的进阶用法
1.借助额外的索引数组对输入重排列
reindex
和 reindex_reduce
可以传入多个额外的变量实现输入数据的重索引。一个简单的举例是,假如我希望一个输入的多维数据按照指定额外的排列方式重新排列某一个维度,那么代码为:
>>> x = jt.arange(8).reshape(4, 2)
>>> x
jt.Var([[0 1]
[2 3]
[4 5]
[6 7]], dtype=int32)
>>> # 指定输出 y[0] = x[1], y[1] = x[3], y[2] = x[0], y[3] = x[2]
>>> e = jt.array([1, 3, 0, 2])
>>> x.reindex([4, 2], ['@e0(i0)', 'i1'], extras=[e])
jt.Var([[2 3]
[6 7]
[0 1]
[4 5]], dtype=int32)
其中 @e0
为 extras
的第 0 个变量,如果 extras
中有多个,则可以使用 @e1, @e2, ...
。
上面这个例子与 x[e]
的结果相同。但是当 e
的维度大于 1 时,您无法通过下标索引实现重排列,但是 reindex 支持多维的 e
。
非常感谢,这段说明我想也许可以加到教程中。
妙啊,太妙了
你好,overflow_conditions这个参数我不太理解,可以麻烦您讲解一下吗?
overflow_condition 里是一组表达式,当表达式里任意一个成立时,被reindex的元素就被设置为 overflow_value。原理可以参考文档里的伪代码
jittor — Jittor 1.3.5.24 文档 (tsinghua.edu.cn)
举一个例子,要将所有下标为 3 的倍数的元素设为-1,可以采用以下代码:
>>> x = jt.arange(8)
>>> x.reindex([8], ['i0'], overflow_conditions=['i0%3==0'], overflow_value=-1)
jt.Var([-1 1 2 -1 4 5 -1 7], dtype=int32)