如何仅通过在 python 代码内设置 gpu 个数来选择 多卡训练呢?以及通过在代码内设置相应参数使得 jittor 在多卡训练时能够使用zero等的一些优化?

我浏览了计图的官方的 MPI 多卡分布式教程后,看到官方是说明直接通过命令行进行启动。

但是目前我们在进行一项类似于 pytorch lightning 的开发,希望仅通过在 python 命令行内设置相应的 flag 或者说参数,就能启动多卡分布式训练,伪代码如下所示:

import jittor as jt
# model part

# for example, could we use like this?
jt.mpi_training = 1

# train part
1 Like

感谢您的关注,这个是可以的,我写了一个example共你参考:jittor/test_mpi_in_py.py at master · Jittor/jittor · GitHub

1 Like

后续有问题欢迎随时联系 :grinning_face_with_smiling_eyes:

好的好的,非常感谢 :grinning: