PyTorch error: AssertionError: If capturable=False, state_steps should not be CUDA tensors.

  1. The full error output is as follows:

      File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/optim/lr_scheduler.py", line 65, in wrapper
        return wrapped(*args, **kwargs)
      File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/optim/optimizer.py", line 109, in wrapper
        return func(*args, **kwargs)
      File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
        return func(*args, **kwargs)
      File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/optim/adam.py", line 157, in step
        adam(params_with_grad,
      File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/optim/adam.py", line 213, in adam
        func(params,
      File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/optim/adam.py", line 255, in _single_tensor_adam
        assert not step_t.is_cuda, "If capturable=False, state_steps should not be CUDA tensors."
    AssertionError: If capturable=False, state_steps should not be CUDA tensors.
    ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 4681) of binary: /opt/anaconda3/envs/xxx/bin/python
    Traceback (most recent call last):
      File "/opt/anaconda3/envs/xxx/lib/python3.10/runpy.py", line 196, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/opt/anaconda3/envs/xxx/lib/python3.10/runpy.py", line 86, in _run_code
        exec(code, run_globals)
      File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/distributed/launch.py", line 193, in <module>
        main()
      File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/distributed/launch.py", line 189, in main
        launch(args)
      File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/distributed/launch.py", line 174, in launch
        run(args)
      File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/distributed/run.py", line 752, in run
        elastic_launch(
      File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
        return launch_agent(self._config, self._entrypoint, list(args))
      File "/opt/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 245, in launch_agent
        raise ChildFailedError(
    torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
    
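  2. Likely cause
    The error is related to the capturable parameter newly introduced in torch 1.12.0 [1]. It typically shows up when training resumes from a checkpoint: after the optimizer state is loaded, the Adam step counters end up as CUDA tensors while capturable is still False, which trips the assertion in _single_tensor_adam.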
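
    To confirm that this is what is happening, you can inspect the optimizer state right after loading the checkpoint. The helper below is only a sketch: cuda_step_entries is a hypothetical name, not part of torch, and it assumes the optimizer exposes the usual state mapping whose per-parameter entries may hold a 'step' value.

```python
def cuda_step_entries(optimizer):
    """Return the per-parameter state dicts whose 'step' counter lives on a CUDA device.

    Hypothetical diagnostic helper (not a torch API). It relies only on
    optimizer.state being a mapping from parameters to state dicts.
    """
    flagged = []
    for state in optimizer.state.values():
        step = state.get("step")
        # Older torch versions store 'step' as a plain int, which has no
        # is_cuda attribute, so getattr falls back to False for those.
        if getattr(step, "is_cuda", False):
            flagged.append(state)
    return flagged
```

    A non-empty result while capturable is False matches the condition checked by the assertion in the traceback.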

  3. Package versions in the runtime environment:

    PyTorch: 1.12.0+cu116
    TorchVision: 0.13.0+cu116
    
  4. Solutions

    • After loading the checkpoint, set the capturable parameter to True:
      optim.param_groups[0]['capturable'] = True
      
    • Or run the following command to install torch 1.12.1 and the matching torchvision:
      pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
      

    Note: the command above installs the torch and torchvision builds for CUDA 11.6. Other torch versions, such as 1.10 or 1.13, can be installed instead.
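
    The first workaround only touches param_groups[0]. If the optimizer was built with several parameter groups, it is probably safer to set the flag on every group. A minimal sketch, assuming the optimizer exposes the standard param_groups list of dicts; set_capturable is a hypothetical helper name, and the checkpoint path and the "optimizer" key in the comments are placeholders:

```python
def set_capturable(optimizer, value=True):
    """Set the 'capturable' flag on every param group of an optimizer.

    Hypothetical helper (not a torch API); it only relies on
    optimizer.param_groups being a list of dicts.
    """
    for group in optimizer.param_groups:
        group["capturable"] = value
    return optimizer


# Typical use when resuming training (path and dict key are placeholders):
#   checkpoint = torch.load("checkpoint.pth")
#   optimizer.load_state_dict(checkpoint["optimizer"])
#   set_capturable(optimizer, True)
```

    Call it after load_state_dict, since the loaded state replaces the optimizer's param group settings.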


Ref

[1] https://github.com/pytorch/pytorch/issues/80809#issuecomment-1173481031
