Jittor 自定义 forward 和 backward

请问 Jittor 是否有类似于 torch.autograd.Function 的方法让用户自定义前向和反向函数?如果没有,怎么把下述代码迁移到 Jittor 的实现上?

class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with jt.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with jt.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads

jittor是没有autograd模块的,前向传播函数叫另一个名字execute,是nn模块的Midule类的子类,另外jittor是不存在反向传播函数接口的,它把反向函数整合进了优化器类optimizer里,但具体怎么重载backward函数,是否需要重载backward函数,我也不太了解。

这是我了解到关于jittor的一些片面的东西,希望能帮到你

使用jittor.Function(jittor — Jittor 1.3.9.2 文档) 即可