pytorch 代码转 jittor 使用心得

前言

为了实现大作业中复现 jittor 代码的任务,这几天花了将近二十个小时在 jittor 环境的配置上,颇有一番折磨心得体会,在这里记录一下。

需要注意的是,jittor 与 torch 有一些关键的区别:

  1. jittor 的基础类型是 Var,对应 torch 里边的 Tensor
  2. jittor.nn.Module 里边,前向计算的函数名是 execute,但对应到 torch.nn.Module,函数名是 forward

如果你也有类似的将 pytorch 代码转 jittor 代码的需求,即使你要转换的并非是我要运行的 P-tuning-v2,也强烈建议按下面的探索流程手动操作一下,不需要太多时间,熟悉了这样的处置错误流程之后再动手会节省你非常非常多时间。

环境初步配置

下面的说法都是基于我自己大作业复现 P-tuning-v2 的指令,我会标记出来哪些是可以根据自己需求替换和更改的部分。

首先当然是新建一个 conda 环境用来给自己鼓捣了:

conda create -n new_jittor python=3.8.5
conda activate new_jittor

这里的 python 版本 3.8.5 和环境名称 new_jittor 都是根据自己需求来改的。

有几个与 pytorch 转 jittor 相关的库需要着重提到一下:

jtorch

这个库主要是用于支持将 pytorch 代码整个地替换为 jittor 代码。

需要注意的是,一定不要按照 README 里面所说的通过 pip install jtorch 来安装,而是把库 clone 下来手动安装,因为这个库里面是最新版的,pip 上的版本比较老……

而且必须要是 JitterRepos/jtorch,这里才是最新版的,其他库都是旧版……

指令如下:

git clone git@github.com:JittorRepos/jtorch.git
python -m pip install jtorch/.
python -m pip show jtorch # 显示 jtorch 库版本,应该 >= 0.2.0
python -m pip show jittor # 显示 jittor 库版本,应该 >= 1.3.9.14

注意安装 jtorch 的同时也会安装 jittor 库本身。

transformers_jittor

一个用处不大但是装一下也好的库。(另外这是一个 star 数为 0 的库)

transformers 库是个在学习领域广泛应用的库,会从 tensorflow, pytorch 等中挑一个依赖然后运行。

这里本质上是在 4.26.1 版本的 transformers 库上稍作改动,以部分地兼容 jittor,但是兼容了吗?如兼容。

不管咋说先装上吧:

git clone git@github.com:JittorRepos/transformers_jittor.git
python -m pip install transformers_jittor/.
python -m pip show transformers # 显示 transformers 库版本,应当为 4.26.1

常见报错合集

经过了上面的初步配置之后,就到了兵来将挡水来土掩的时间了……

这里给出一些常见报错的解决方法示范与原理解析。

探索问题的过程很长,也会附上省流总结版方便你直接修改对应地点。

运行 jtorch 示例代码

optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9)

经过了上面的初步配置之后尝试跑一下 jtorch 库给出的示例代码:

# -*- coding: utf-8 -*-
import random
import torch
import math


class DynamicNet(torch.nn.Module):
    def __init__(self):
        """
        In the constructor we instantiate five parameters and assign them as members.
        """
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(()))
        self.b = torch.nn.Parameter(torch.randn(()))
        self.c = torch.nn.Parameter(torch.randn(()))
        self.d = torch.nn.Parameter(torch.randn(()))
        self.e = torch.nn.Parameter(torch.randn(()))

    def forward(self, x):
        """
        For the forward pass of the model, we randomly choose either 4, 5
        and reuse the e parameter to compute the contribution of these orders.

        Since each forward pass builds a dynamic computation graph, we can use normal
        Python control-flow operators like loops or conditional statements when
        defining the forward pass of the model.

        Here we also see that it is perfectly safe to reuse the same parameter many
        times when defining a computational graph.
        """
        y = self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3
        for exp in range(4, random.randint(4, 6)):
            y = y + self.e * x ** exp
        return y

    def string(self):
        """
        Just like any class in Python, you can also define custom method on PyTorch modules
        """
        return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3 + {self.e.item()} x^4 ? + {self.e.item()} x^5 ?'


# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# Construct our model by instantiating the class defined above
model = DynamicNet()

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9)
for t in range(60000):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    if t % 2000 == 1999:
        print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # print(torch.liveness_info())

print(f'Result: {model.string()}')

把这份示例代码扔到 test_jtorch.py 里面然后运行,经过了漫长的 jittor 编译环节,你会遇到如下报错:

Traceback (most recent call last):
  File "test_jittor.py", line 53, in <module>
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9)
  File "/home/aiuser/anaconda3/envs/new_jittor/lib/python3.8/site-packages/jittor/optim.py", line 307, in __init__
    super().__init__(params, lr)
  File "/home/aiuser/anaconda3/envs/new_jittor/lib/python3.8/site-packages/jittor/optim.py", line 31, in __init__
    assert len(params) > 0, "Length of parameters should not be zero"
TypeError: object of type 'list_iterator' has no len()

也就是说,这里的报错是因为 model.parameters() 是一个 list_iterator 类型,不能使用 len(),于是报错。

你可能会觉得这里应该是 jittor 内部优化器实现的问题,本来按理来说 model.parameters() 的返回值就该是一个迭代器,怎么可以使用 len() 去提取长度呢?

我之前也是这样想的,于是使用 ctrl 左键点开上面报错信息里面的 optim.py,发现他的实现就是假定了传入的 params 参数是一个 list

经过了漫长的理解和修改,我终于发现,改这里的 optim.py 的实现是完全不可行的,jittor 内部本来就认为 model.parameters() 就该是个 list

对我来说这是个非常奇怪的事情,因为大模型参数应该很巨大,动辄几个 G 几十个 G,直接拿 list 往内存或者显存里面加载那还得了?但这里只能先按下不表,留作后面的问题

接下来就是另一个问题了:如果 jittor 认为 model.parameters() 理应是一个 list,那为什么这里的返回值偏偏又是个 list_iterator?这不是内部实现不一致吗?

这里我也不卖关子了,说实话找到这个问题的解答花了我非常长的时间。

将上面测试代码第七行中的

class DynamicNet(torch.nn.Module):

替换为:

import jittor
class DynamicNet(jittor.nn.Module):

然后再尝试在后面 print(model.parameters()),你会发现突然他就变成了一个正常的 list,而非改之前的 list_operator

按照 jtorch 的实现,torch.nn.Module 按理来说是和 jittor.nn.Module 等价的,万恶之源在 jtorch 里边。

找一下 jtorch 的源文件。如果你使用 vscode,先通过 ctrl+shift+p 选择对应的 conda 环境解释器,再使用 ctrl 左键点击 import torch 里面的 torch,就可以进入源码界面。

如果不使用 vscode,也可以手动去 conda 目录下找,比如我这里是

/home/aiuser/anaconda3/envs/new_jittor/lib/python3.8/site-packages/torch

你会发现他的 __init__.py 长这样:

import os
os.environ["FIX_TORCH_ERROR"] = "0"

import jittor as jt
import jtorch
from jtorch import *
__version__ = "2.0.0"

...
sys.modules["torch.nn"] = load_mod("jtorch.nn")
sys.modules["torch.nn.functional"] = load_mod("jtorch.nn")
sys.modules["torch.nn.parallel"] = load_mod("jtorch.distributed")
...

也就是说,jtorch 的原理是直接把 torch 相关的操作替换掉,比如说把 torch.nn 直接覆盖为 jtorch.nn 啥的,把整个 torch 库的操作全部覆盖掉了,然后把 torch 的版本强行设置为了 2.0.0。和 C++ 里面的 define 有异曲同工之妙。

那问题大概出在 jtorch.nn.Module 里边,切换到 jtorch/nn/__init__.py 里边看看:(方法类似,ctrl 左键点 jtorch,然后翻目录)

class Module(ModuleMisc, jt.Module):

    def __call__(self, *args, **kw):
        return self.execute(*args, **kw)

    def execute(self, *args, **kw):
        return self.forward(*args, **kw)

    def get_submodule(self, target: str):
        if target == "":
            return self

        atoms: List[str] = target.split(".")
        mod: jt.nn.Module = self

        for item in atoms:
            if not hasattr(mod, item):
                raise AttributeError(mod._get_name() + " has no "
                                     "attribute `" + item + "`")

            mod = getattr(mod, item)

            if not isinstance(mod, jt.nn.Module):
                raise AttributeError("`" + item + "` is not "
                                     "an nn.Module")
        return mod

也就是说,jtorch.nn.Module 确实和 jittor.nn.Module 有所不同,在继承的基础上改了点东西。

这里改动的好处在于,可以无缝承接 torch 实现里面对于 forward 的实现,可以理解为将 jittor 库中调用 model.execute(input) 或者 model(input) 覆盖为了对于 forward 的调用。

关注一下它另一个继承的 ModuleMisc,点开看看:

class ModuleMisc:
    def parameters(self):
        return iter(super().parameters())

    def load_state_dict(self, state_dict, strict=False):
        return super().load_state_dict(state_dict)

    def to(self, device=None,dtype=None):
        ''' do nothing but return its self'''
        return self
    def register_parameter(self,name,data):
        self.name = data

    def buffers(self):
        for _, buf in self.named_buffers():
            yield buf

看见第二行恍然大悟了。如果使用 torch.nn.Module,会自动替换为 jtorch.nn 里面的 Module,然后此处 ModuleMisc 的实现会让它在继承 jittor.nn.Module 的基础上将 parameters() 的实现覆盖为迭代器版本。

也许在某种时候这样替换是有用的,但是对我们来说没啥用。

这里直接将 ModuleMisc 里面对于 parameters() 的覆盖删除掉,确保返回值是一个 list

class ModuleMisc:
#    def parameters(self):
#        return iter(super().parameters())

修改完之后再回头运行前面的 test_jittor.py(注意把改掉的 jittor.nn.Module 换回 torch.nn.Module),你会发现它跑起来了!

省流总结

jtorch/__init__.py 中,class ModuleMisc 下对于 parameters() 的覆盖注释掉,确保 parameters() 的返回值是一个 list 而非 list_operator

运行 P-tuning-v2

这里是我在运行我的 P-tuning-v2 遇到的问题,很多都是 transformers 库与 jittor 不兼容带来的问题。

即使 jtorch 对其做了不少暴力替换,transformers_jittor 也处理了一小部分问题,依然还是存在大量需要去操作的地方。

from torch._C import NoopLogger

报错信息如下所示:

Traceback (most recent call last):
  File "run.py", line 101, in <module>
    from tasks.superglue.get_trainer import get_trainer
  File "/home/aiuser/pt5/tasks/superglue/get_trainer.py", line 11, in <module>
    from model.utils import get_model, TaskType
  File "/home/aiuser/pt5/model/utils.py", line 3, in <module>
    from model.sequence_classification import (
  File "/home/aiuser/pt5/model/sequence_classification.py", line 3, in <module>
    from torch._C import NoopLogger
ImportError: cannot import name 'NoopLogger' from 'jtorch.misc' (/home/aiuser/anaconda3/envs/new_jittor/lib/python3.8/site-packages/jtorch/misc.py)

原因是在 jtorch 库中覆盖了 torch._Cjtorch/misc,但是里面没有实现 NoopLogger 类。

直接在 jtorch/misc.py 中添加对于 NoopLogger 的实现:

class NoopLogger:
    def info(self, *args, **kwargs):
        pass  # 不执行任何操作

    def debug(self, *args, **kwargs):
        pass  # 不执行任何操作

    def warning(self, *args, **kwargs):
        pass  # 不执行任何操作

    def error(self, *args, **kwargs):
        pass  # 不执行任何操作

transformers 内部 get_parameter_device()

报错信息如下所示:

Traceback (most recent call last):
  File "run.py", line 128, in <module>
    train(trainer, training_args.resume_from_checkpoint, last_checkpoint)
  File "run.py", line 29, in train
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/home/aiuser/pt5/training/trainer_base_2.py", line 346, in train
    outputs = self.model(
  File "/home/aiuser/anaconda3/envs/new_jittor/lib/python3.8/site-packages/jtorch/nn/__init__.py", line 25, in __call__
    return self.execute(*args, **kw)
  File "/home/aiuser/pt5/model/sequence_classification.py", line 479, in execute
    past_key_values = self.get_prompt(batch_size=batch_size)
  File "/home/aiuser/pt5/model/sequence_classification.py", line 439, in get_prompt
    prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.glm.device)
  File "/home/aiuser/anaconda3/envs/new_jittor/lib/python3.8/site-packages/transformers/modeling_utils.py", line 736, in device
    return get_parameter_device(self)
  File "/home/aiuser/anaconda3/envs/new_jittor/lib/python3.8/site-packages/transformers/modeling_utils.py", line 150, in get_parameter_device
    return next(parameter.parameters()).device
TypeError: 'list' object is not an iterator

直接找到 transformers 库源码的实现:

def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
    try:
        return next(parameter.parameters()).device

根据我们之前所展示的,我们统一认为 parameters() 的返回值是 list,所以这里直接把 next(parameter.parameters()) 改为 parameters()[0] 即可。

修改后如下:

def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
    try:
        return (parameter.parameters())[0].device

self.weight = Parameter(torch.Tensor(self.num_embeddings, self.embedding_dim))

这里报错的原因在于,jtorch 会将其替换为 jtorch.Var(..., ...),但是 Var 并不支持这样初始化,要替换为 randn

解决方法可以是直接修改运行代码:

if torch.__version__ == '2.0.0':
	import jittor
    self.weight = Parameter(jittor.randn(self.num_embeddings, self.embedding_dim))
else:
    self.weight = Parameter(torch.Tensor(self.num_embeddings, self.embedding_dim))

这样可以在不破坏 torch 运行的前提下让 jtorch 环境也能跑。

from transformers import AdamW

不要从 transformers 里面调用它的优化器,直接调用 jittor 自带的。

同理地,运行代码里面修改为修改为:

if torch.__version__ != '2.0.0':
    from transformers import AdamW
else:
    from jittor.nn import AdamW

--fp16

报错如下:

Traceback (most recent call last):
  File "run.py", line 71, in <module>
    args = get_args()
  File "/home/aiuser/pt5/arguments.py", line 197, in get_args
    args = parser.parse_args_into_dataclasses()
  File "/home/aiuser/anaconda3/envs/new_jittor/lib/python3.8/site-packages/transformers/hf_argparser.py", line 332, in parse_args_into_dataclasses
    obj = dtype(**inputs)
  File "<string>", line 108, in __init__
  File "/home/aiuser/anaconda3/envs/new_jittor/lib/python3.8/site-packages/transformers/training_args.py", line 1176, in __post_init__
    raise ValueError(
ValueError: FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation (`--fp16_full_eval`) can only be used on CUDA devices.

原因是 transformers 里面没检测到 cuda 的存在性。

直接进 training_args.py 的源码看看:

        if (
            self.framework == "pt"
            and is_torch_available()
            and (self.device.type != "cuda")
            and (get_xla_device_type(self.device) != "GPU")
            and (self.fp16 or self.fp16_full_eval)
        ):
            raise ValueError(
                "FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation"
                " (`--fp16_full_eval`) can only be used on CUDA devices."
            )

这里的 self.device.typecuda:0,所以稍微修改一下如下所示:

        if (
            self.framework == "pt"
            and is_torch_available()
            and (self.device.type != "cuda" and self.device.type != "cuda:0")
            and (get_xla_device_type(self.device) != "GPU")
            and (self.fp16 or self.fp16_full_eval)
        ):
            raise ValueError(
                "FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation"
                " (`--fp16_full_eval`) can only be used on CUDA devices."
            )

然后就能跑通了。

1 个赞