PyTorch error: AssertionError: If capturable=False, state_steps should not be CUDA tensors.
The full error output is as follows:
File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/optim/adam.py", line 213, in adam File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/optim/lr_scheduler.py", line 65, in wrapper return wrapped(*args, **kwargs) File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/optim/optimizer.py", line 109, in wrapper func(params, File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/optim/adam.py", line 255, in _single_tensor_adam return func(*args, **kwargs) File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context return func(*args, **kwargs) File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/optim/adam.py", line 157, in step assert not step_t.is_cuda, "If capturable=False, state_steps should not be CUDA tensors." AssertionError: If capturable=False, state_steps should not be CUDA tensors. adam(params_with_grad, File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/optim/adam.py", line 213, in adam func(params, File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/optim/adam.py", line 255, in _single_tensor_adam assert not step_t.is_cuda, "If capturable=False, state_steps should not be CUDA tensors." AssertionError: If capturable=False, state_steps should not be CUDA tensors. ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 4681) of binary: /opt/anaconda3/envs/xxx/bin/python Traceback (most recent call last): File "/opt/anaconda3/envs/xxx/lib/python3.10/runpy.py", line 196, in _run_module_as_main return _run_code(code, main_globals, None, File "/opt/anaconda3/envs/xxx/lib/python3.10/runpy.py", line 86, in _run_code exec(code, run_globals) File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/distributed/launch.py", line 193, in <module> main() File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/distributed/launch.py", line 189, in main launch(args) File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/distributed/launch.py", line 174, in launch run(args) File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/distributed/run.py", line 752, in run elastic_launch( File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 131, in __call__ return launch_agent(self._config, self._entrypoint, list(args)) File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 245, in launch_agent raise ChildFailedError( torch.distributed.elastic.multiprocessing.errors.ChildFailedError: -
Possible cause
The error is related to the `capturable` parameter newly introduced in the Adam optimizer in torch 1.12.0 (see [1]).
Package versions in the runtime environment:
PyTorch: 1.12.0+cu116, TorchVision: 0.13.0+cu116
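On 1.12.0 the assertion typically appears when training is resumed from a checkpoint: loading the optimizer `state_dict` puts the saved `step` counter on the GPU together with the rest of the state, and the next `optimizer.step()` then trips the check. Below is a minimal sketch of such a resume flow; the toy model and training step are illustrative and not taken from the original post:

```python
import torch

# Toy model and Adam optimizer on the GPU (PyTorch 1.12.0).
model = torch.nn.Linear(10, 1).cuda()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# One training step so that the optimizer state (including the 'step' tensor) exists.
loss = model(torch.randn(4, 10, device="cuda")).sum()
loss.backward()
optim.step()

# Simulate saving a checkpoint and resuming from it.
ckpt = {"model": model.state_dict(), "optim": optim.state_dict()}
model.load_state_dict(ckpt["model"])
optim.load_state_dict(ckpt["optim"])  # on 1.12.0 the 'step' tensor ends up on CUDA here

# The next optimizer step raises:
# AssertionError: If capturable=False, state_steps should not be CUDA tensors.
optim.zero_grad()
loss = model(torch.randn(4, 10, device="cuda")).sum()
loss.backward()
optim.step()
```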
Solutions
- After loading the checkpoint, set the `capturable` parameter to `True`: `optim.param_groups[0]['capturable'] = True` (see the sketch after this list).
- Install torch 1.12.1 and the matching torchvision by running: `pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116`
  Note: the command above installs torch and torchvision built against CUDA 11.6; other torch versions, such as 1.10 or 1.13, can be installed instead.
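A minimal sketch of the first workaround, applied right after the optimizer state is restored; the checkpoint path and dictionary keys are illustrative, not from the original post. With several parameter groups the flag should be set on every group, not only on `param_groups[0]`:

```python
import torch

model = torch.nn.Linear(10, 1).cuda()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# Restore model and optimizer state from a checkpoint (illustrative path and keys).
ckpt = torch.load("checkpoint.pth", map_location="cuda")
model.load_state_dict(ckpt["model"])
optim.load_state_dict(ckpt["optim"])

# Workaround for the 1.12.0 assertion: enable `capturable` so that Adam
# accepts the 'step' state tensors that now live on the GPU.
for group in optim.param_groups:
    group["capturable"] = True

# Training can then continue; optim.step() no longer asserts.
```

Setting `capturable=True` routes the update through Adam's capturable code path, which can be somewhat slower, so upgrading to 1.12.1 (the second option) is usually the cleaner long-term fix.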
Ref
[1] https://github.com/pytorch/pytorch/issues/80809#issuecomment-1173481031