PyTorch 多数据集医学图像分割训练进程"卡死"问题排查与修复

如遇见其它问题始终不能解决,可用咸鱼软件搜索用户:代码跑通pytorch,私信我解答,欢迎关注我的咸鱼用户:代码跑通pytorch

一、问题现象

在使用 PyTorch + MONAI 进行多器官肿瘤分割模型训练时,使用 nohup 后台启动训练脚本,运行到 Epoch 4 的第 614/736 个 batch 时,日志不再更新,训练看起来"卡住"了:

>>>> [正在处理] Dataset:MSD_colon | Index:0 | File:/.../MSD_colon/ct/colon_022_0000.nii.gz
Epoch 4/1000 613/736 seg loss: 0.7505 cls loss: 0.0868 loss: 0.8373 time 1.24s

>>>> [正在处理] Dataset:KiTS23 | Index:94 | File:/.../KiTS23/ct/KiTS23_case_00113_0000.nii.gz
(之后再无输出)

训练前 4 个 epoch(约 1.5 小时)都正常运行,loss 也在下降。初步怀疑是某个特定数据文件导致的问题。


二、排查过程

2.1 查看进程状态

ps aux | grep Main_Ours | grep -v grep | wc -l
# 输出:36

发现有 36 个 Python 进程与训练脚本相关,但实际只启动了一个训练任务。

进一步查看父进程状态:

ps -p 14148 -o pid,stat,cmd
#  PID STAT CMD
# 14148 Zl   [python] <defunct>

主训练进程已变为僵尸进程(<defunct>,这说明进程并非"卡住",而是已经崩溃退出了。

2.2 查看 GPU 状态

nvidia-smi
| GPU 0 | 12194MiB / 24576MiB | 4% | (连接显示器)
| GPU 1 |    13MiB / 24576MiB | 0% | (完全空闲)

GPU 0 占用 12GB 显存但进程列表中没有任何 Python 训练进程——这是进程崩溃后 CUDA 显存泄漏的典型表现。

2.3 分析子进程来源

pstree -p 14148 | head -20
python(14148)-+-python(14191)---{python}(14192)
              |-python(14196)---{python}(14198)
              |-python(14202)---{python}(14205)
              ...(共 10 个子进程)

代码中有如下写法:

args.SHARE_LIST = [mp.Manager().list() for i in range(args.out_channels)]  # out_channels=10

这行代码创建了 10 个独立的 mp.Manager() 服务器进程,每个占 ~527MB 内存。主进程崩溃后,这些子进程变成了孤儿进程仍然存活。

2.4 验证"问题文件"是否真是元凶

搜索日志中 KiTS23_case_00113_0000.nii.gz 的所有出现:

grep -n "KiTS23_case_00113" train.log
3332:  Epoch 1, batch 362→363  ✅ 成功,耗时 2.62s
7326:  Epoch 3, batch 219→220  ✅ 成功,耗时 1.26s
7827:  Epoch 3               ✅ 成功
8622:  Epoch 3               ✅ 成功
...(共 7 次成功)
10719: Epoch 4, batch 614     ❌ 进程在此处被杀死

该文件在之前的 epoch 中成功处理了 7 次,并非文件本身有问题。

查看文件信息:

import SimpleITK as sitk
img = sitk.ReadImage('KiTS23_case_00113_0000.nii.gz')
print(img.GetSize())   # (143, 448, 449)
print(img.GetSpacing()) # (1.5, 0.8, 0.8)
# 总体素:28,764,736,float32 约 110MB,经过 3 通道变换后约 330MB

体积中等,不算特别大。

2.5 检查系统日志

journalctl --since "2026-02-27" | grep -i "oom\|killed"

发现系统 OOM Killer 有活动记录,训练日志中也没有任何 Python traceback——说明进程是被 OS 信号直接杀死(SIGKILL),而非 Python 异常退出。


三、根因分析

直接原因

进程在 Epoch 4 运行约 1.5 小时后,被操作系统杀死(SIGKILL/OOM),并非被某个特定文件"卡住"。

根本原因

因素 说明
mp.Manager() 重复创建 创建了 10 个独立管理器进程,白白占用 ~5GB 内存
GPU 0 显存不足 GPU 0 连接显示器,桌面环境占用 ~400MB,实际可用 < 24GB
显存累积 随机数据增强(RandZoomd 放大到 1.25x)偶尔产生大体积中间结果
无异常保护 训练循环没有 try/except,CUDA OOM 直接导致进程崩溃
数据加载无容错 __getitem__ 中 transform 出错会直接抛出异常终止训练

为什么恰好在 Epoch 4?

由于 __getitem__ 使用 random.uniform()np.random.randint() 进行完全随机采样(忽略传入的 index),每个 epoch 的数据顺序完全不同。这是一个概率事件——显存累积到临界点时,恰好碰到一个较大的样本触发了 OOM。


四、修复方案

4.1 修复 mp.Manager() 重复创建(Main_Ours.py)

修改前(创建 10 个 Manager):

args.SHARE_LIST = [mp.Manager().list() for i in range(args.out_channels)]

修改后(共用 1 个 Manager):

manager = mp.Manager()
args.SHARE_LIST = [manager.list() for i in range(args.out_channels)]

效果:从 10 个 Manager 服务器进程减少为 1 个,节省约 4.5GB 内存,降低孤儿进程风险。

4.2 添加数据加载异常保护(dataloader.py)

UniformDataset__getitem__ 添加重试机制,transform 出错时自动跳过换另一个样本:

def _get_random_sample(self):
    """根据权重随机选择数据集和样本索引"""
    random_num = random.uniform(0, self.weight_sum)
    if random_num == self.weight_sum:
        set_index = len(self.data_dic) - 1
    else:
        for i in range(len(self.points) - 1):
            if self.points[i] <= random_num < self.points[i + 1]:
                set_index = i
                break
    set_key = list(self.data_dic.keys())[set_index]
    data_index = np.random.randint(self.datasetnum[set_index], size=1)[0]
    return set_key, data_index

def __getitem__(self, index):
    MAX_RETRIES = 5
    for attempt in range(MAX_RETRIES):
        set_key, data_index = self._get_random_sample()
        try:
            current_file = self.data_dic[set_key][data_index].get('image', 'unknown')
            print(f"\n>>>> [正在处理] Dataset:{set_key} | Index:{data_index} | "
                  f"File:{current_file}", flush=True)
        except:
            current_file = 'unknown'

        try:
            result = self._transform(set_key, data_index)
            return result
        except Exception as e:
            print(f"\n>>>> [ERROR] Dataset:{set_key} | Index:{data_index} | "
                  f"File:{current_file} | Attempt {attempt+1}/{MAX_RETRIES} | "
                  f"Error: {e}", flush=True)
            if attempt == MAX_RETRIES - 1:
                raise RuntimeError(
                    f"Failed after {MAX_RETRIES} retries. "
                    f"Last error on {current_file}: {e}")
            continue

效果:单个文件的 transform 出错不会终止整个训练,最多重试 5 次后换其他样本。

4.3 添加 CUDA OOM 保护(Trainer_Ours.py)

train_epoch 的训练循环中包裹 try/except,捕获 CUDA 显存溢出错误:

for idx, batch_data in enumerate(loader):
    try:
        # ... 原有的前向传播、loss 计算、反向传播代码 ...

        if args.rank == 0:
            print(
                "Epoch {}/{} {}/{}".format(epoch, args.max_epochs, idx, len(loader)),
                "seg loss: {:.4f}".format(run_seg_loss.avg),
                "cls loss: {:.4f}".format(run_cls_loss.avg),
                "loss: {:.4f}".format(run_loss.avg),
                "time {:.2f}s".format(time.time() - start_time),
            )
    except RuntimeError as e:
        if "out of memory" in str(e):
            print(f"\n>>>> [CUDA OOM] Epoch {epoch} batch {idx}, "
                  f"skipping. Error: {e}", flush=True)
            torch.cuda.empty_cache()
            optimizer.zero_grad()
        else:
            raise e
    start_time = time.time()

效果:CUDA OOM 时跳过该 batch 并清理显存缓存,训练继续进行而不是直接崩溃。

4.4 使用空闲 GPU 启动训练

GPU 0 连接显示器,桌面环境占用约 400MB 显存。GPU 1 完全空闲,拥有完整的 24GB 显存。启动命令加上 CUDA_VISIBLE_DEVICES=1

conda activate work2
cd /home/data/LYL/CodeBase/work2/MultiTumorSeg/MultiTumorSeg/main
export PYTHONPATH=$PYTHONPATH:/home/data/LYL/CodeBase/work2/MultiTumorSeg

CUDA_VISIBLE_DEVICES=1 nohup python Main_Ours.py \
    --save_checkpoint \
    --use_normal_dataset \
    --json_list /path/to/KiTS19_MSD_dataset.json \
    --data_dir /path/to/KiTS19_MSD \
    --logdir My_Training_Final_Resume \
    --batch_size 1 \
    --workers 0 \
    --checkpoint /path/to/model_min_loss.pt \
    --use_checkpoint \
    > ./train.log 2>&1 &

五、修复后验证

启动训练后检查 GPU 状态:

nvidia-smi
| GPU 1 | 11614MiB / 24576MiB | 100% | python (PID 28058) |

GPU 1 使用 11.6GB,剩余约 13GB 空闲显存,训练正常运行。查看日志确认 loss 持续下降:

Epoch 0/1000 0/736 seg loss: 1.2585 cls loss: 0.1635 loss: 1.4220 time 11.47s
Epoch 0/1000 1/736 seg loss: 1.1731 cls loss: 0.1206 loss: 1.2937 time 3.06s
Epoch 0/1000 2/736 seg loss: 0.9971 cls loss: 0.0813 loss: 1.0784 time 1.20s
...

六、经验总结

教训 建议
日志无输出 ≠ 卡住 先用 ps 检查进程是否存活,可能已经崩溃
mp.Manager() 不要放在循环里 每次调用都会创建一个服务器进程
多 GPU 机器注意显示器占用 连接显示器的 GPU 可用显存更少,训练应选空闲 GPU
训练循环需要异常保护 try/except RuntimeError 捕获 CUDA OOM
数据加载需要容错 __getitem__ 中的 transform 可能因数据问题失败
修改代码前一定要备份 cp file.py file.py.bak_日期
僵尸进程的含义 <defunct> 表示进程已死,但父进程未回收其退出状态

七、修改文件清单

文件 修改内容 备份
Main_Ours.py mp.Manager() 10→1 Main_Ours.py.bak_20260228
dataloader.py __getitem__ 添加异常重试 dataloader.py.bak_20260228
Trainer_Ours.py 训练循环添加 CUDA OOM 保护 Trainer_Ours.py.bak_20260228
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐