常见 Torch 与 Jittor 接口的对应关系

由于设计理念和底层逻辑的区别,Jittor 中有部分接口与 Torch 是不一样的。在这里整理如下,以供参考。

Torch Jittor 解释
torch.autograd.variable() jt.Var 支持求导自动,无需单独包装
torch.cuda.empty_cache() jt.gc()
torch.cuda.is_available() jt.flags.has_cuda Jittor 里判单是否可用 cuda 的方式是使用标志变量。此外,启用 cuda 的方式是设置变量 jt.flags.use_cuda = 1
torch.cuda.synchronize() jt.sync_all(True)
torch.Module.register_buffer() jittor中在模型里直接定义 buffer,然后设置其不需要梯度 jt.Var.stop_grad()
torch.index_select() jt.reindex() 参考此贴
torch.sort() jt.argsort() jt.argsort 同时返回索引和排序结果
torch.Tensor.contiguous() Jittor 的 Var 任何时候都是保持连续的,因此 Jittor 代码里不需要这个函数
torch.Tensor.to(device) Jittor 里通过全局标志 jt.flags.use_cuda 来设置 var 是否在 gpu 上运算,不需要对单独的模型或者 tensor 设置
3 个赞