欢迎关注『youcans动手学 AI』系列
【动手学UNet】(1)创建 UNet项目
【动手学UNet】(2)数据加载
【动手学UNet】(3)UNet 模型的实现
【动手学UNet】(4)UNet 模型的训练
【动手学UNet】(5)保存,加载与可视化
【动手学UNet】(6)模型推理与评估
【动手学UNet】(7)主程序


【动手学UNet】(7)主程序


欢迎来到【动手学UNet】系列教程!
本系列将带你从零开始,一步步深入理解和实现经典的UNet图像分割模型。无论你是深度学习初学者,还是有一定经验的开发者,这个系列都将为你提供全面而实用的UNet知识。


12. 主程序(main.py)

main.py 的定位是实现统一调度训练 / 验证 / 单张推理的入口,也可以统一管理参数(如epochs、max_batches等)。
通过命令行参数,可以调度各个功能模块,例如:

  • python main.py train → 调用 core/train_unet.py 里的 train(),进行训练(可指定 epochs / max_batches);
  • python main.py eval → 调用 core/test_unet.py 的评估逻辑,计算 Dice / IoU,并在 runs/test/ 下保存若干三连图;
  • python main.py infer --img xxx --mask yyy → 对一张指定图片做推理,弹出三连图窗口(原图 / 真值 / 预测),并可选择是否保存 png。

main.py的完整代码如下。

# main.py
"""
Unet_retina 项目统一入口脚本
用法示例(在项目根目录):
1. 训练(从头开始,使用 config.py 中的 cfg.epochs):
   python main.py train
2. 训练(指定 epoch 数,快速烟雾测试,只跑前 2 个 batch):
   python main.py train --epochs 1 --max-batches 2
3. 从已有 checkpoint (weights/last_model.pth) 断点续训:
   python main.py train --epochs 100 --resume
4. 在验证集上评估模型,并保存前 5 张样本的三连图:
   python main.py eval --max-batches 10 --save-first-n 5
5. 对单张图像做推理,并弹出三连图窗口(可选 mask):
   python main.py infer --img dataset/valid/images/example.png --mask dataset/valid/masks/example.png
   # 如没有真值掩膜,可省略 --mask:
   python main.py infer --img some_image.png
6. 单张推理时只保存、不显示:
   python main.py infer --img xxx.png --no-show
7. 单张推理时只显示、不保存:
   python main.py infer --img xxx.png --no-save
"""

import argparse
from pathlib import Path
import numpy as np
from PIL import Image
import torch

from core.config import cfg
from core.train_unet import train as train_unet
from core.test_unet import evaluate_and_visualize
from model.unet import UNet
from core.checkpoint import load_model
from utils.visualization import save_prediction_triplet


# === 单张推理相关的辅助函数(与 test_single_image.py 的逻辑一致) ===

def preprocess_image(img_path: Path) -> torch.Tensor:
    """
    按与 RetinaDataset 一致的方式对单张图像做预处理:
    - 读取 PIL Image
    - 转为灰度或 RGB(取决于 cfg.in_channels)
    - resize 到 cfg.img_size
    - 归一化到 [0,1]
    - 转为 [1, C, H, W] 的 tensor(带 batch 维)
    """
    img = Image.open(img_path)

    if cfg.in_channels == 1:
        img = img.convert("L")
    elif cfg.in_channels == 3:
        img = img.convert("RGB")
    else:
        raise ValueError(f"Unsupported in_channels in cfg: {cfg.in_channels}")

    H, W = cfg.img_size
    img = img.resize((W, H), resample=Image.BILINEAR)

    img_np = np.array(img, dtype=np.float32) / 255.0

    if cfg.in_channels == 1:
        # [H, W] -> [1, H, W]
        img_np = np.expand_dims(img_np, axis=0)
    else:
        # [H, W, 3] -> [3, H, W]
        img_np = np.transpose(img_np, (2, 0, 1))

    img_tensor = torch.from_numpy(img_np).unsqueeze(0)  # [1, C, H, W]
    return img_tensor


def preprocess_mask(mask_path: Path) -> torch.Tensor:
    """
    读取单张掩膜并预处理为 [1, 1, H, W] 的 float tensor,值为 0/1。
    若没有真值掩膜,可不调用该函数。
    """
    mask = Image.open(mask_path).convert("L")

    H, W = cfg.img_size
    mask = mask.resize((W, H), resample=Image.NEAREST)

    mask_np = np.array(mask, dtype=np.float32)
    mask_np = (mask_np > 127.5).astype(np.float32)  # 二值化到 0/1

    mask_np = np.expand_dims(mask_np, axis=0)  # [1, H, W]
    mask_tensor = torch.from_numpy(mask_np).unsqueeze(0)  # [1, 1, H, W]
    return mask_tensor


def load_trained_unet() -> UNet:
    """
    加载训练好的 UNet 权重。
    优先尝试 best_model.pth,如不存在则使用 last_model.pth。
    """
    device = cfg.device

    model = UNet(
        in_channels=cfg.in_channels,
        num_classes=cfg.num_classes,
        base_channels=64,
        bilinear=True,
    ).to(device)

    if cfg.best_model_path.exists():
        ckpt_path = cfg.best_model_path
    elif cfg.last_model_path.exists():
        ckpt_path = cfg.last_model_path
    else:
        raise FileNotFoundError(
            f"No checkpoint found. Expected one of:\n"
            f"  best_model: {cfg.best_model_path}\n"
            f"  last_model: {cfg.last_model_path}"
        )

    model, _, epoch = load_model(ckpt_path, model, optimizer=None)
    print(f"[Main-Infer] Loaded checkpoint from {ckpt_path}")
    if epoch is not None:
        print(f"[Main-Infer] Checkpoint epoch: {epoch}")

    model.eval()
    return model


# ========== 子命令对应的执行函数 ==========

def run_train(args):
    """
    执行训练流程:
    - epochs: 训练轮数(为 None 时使用 cfg.epochs)
    - max_batches: 每个 epoch 最多训练多少个 batch(用于烟雾测试)
    - resume: 是否从已有 checkpoint (last_model.pth) 断点续训
    """
    print(">>> Running TRAIN ...")
    train_unet(
        num_epochs=args.epochs,
        max_batches=args.max_batches,
        resume=args.resume,
    )


def run_eval(args):
    """
    执行验证/评估流程:
    - max_batches: 最多评估多少个 batch
    - save_first_n: 保存前多少张样本的三连图
    """
    print(">>> Running EVAL ...")
    evaluate_and_visualize(
        max_batches=args.max_batches,
        save_first_n=args.save_first_n,
    )


def run_single_infer(args):
    """
    单张图像推理流程:
    - img: 必选,图片路径
    - mask: 可选,真值掩膜路径(若无,可不提供)
    - no_save: 仅显示,不保存 png
    - no_show: 仅保存,不弹出窗口
    """
    device = cfg.device
    img_path = Path(args.img)
    mask_path = Path(args.mask) if args.mask is not None else None

    if not img_path.exists():
        raise FileNotFoundError(f"Image not found: {img_path}")

    if mask_path is not None and not mask_path.exists():
        print(f"[Warn] Mask path provided but file not found: {mask_path}")
        mask_path = None

    print(">>> Running SINGLE INFER ...")
    print(f"Using device: {device}")
    print(f"Image path  : {img_path}")
    print(f"Mask path   : {mask_path if mask_path is not None else 'None'}")

    # 1. 预处理图像
    img_tensor = preprocess_image(img_path).to(device)  # [1, C, H, W]

    # 2. 预处理掩膜(如果有),否则用全零掩膜占位
    if mask_path is not None:
        mask_tensor = preprocess_mask(mask_path).to(device)  # [1, 1, H, W]
    else:
        mask_tensor = torch.zeros(1, 1, cfg.img_size[0], cfg.img_size[1], device=device)
        print("[Main-Infer] No ground truth mask provided, using all-zero dummy mask.")

    # 3. 加载模型
    model = load_trained_unet()

    # 4. 前向推理
    with torch.no_grad():
        logits = model(img_tensor)           # [1, 1, H, W]
        probs = torch.sigmoid(logits)
        preds = (probs > 0.5).float()        # 二值预测

    # 5. 三连图可视化(根据参数决定是否保存/显示)
    img_show = img_tensor[0].cpu()
    mask_show = mask_tensor[0].cpu()
    pred_show = preds[0].cpu()

    cfg.test_runs_dir.mkdir(parents=True, exist_ok=True)
    default_save_path = cfg.test_runs_dir / "single_from_main.png"

    save_path = None if args.no_save else default_save_path
    show = not args.no_show

    save_prediction_triplet(
        image=img_show,
        mask=mask_show,
        pred_mask=pred_show,
        save_path=save_path,
        show=show,
        titles=("Image", "GT Mask", "Pred Mask"),
    )

    print("[Main-Infer] Done.")

# ========== 命令行解析入口 ==========

def build_parser():
    parser = argparse.ArgumentParser(description="Unet_retina 项目统一入口")
    subparsers = parser.add_subparsers(dest="command", help="子命令(train / eval / infer)")

    # ---- train 子命令 ----
    train_parser = subparsers.add_parser("train", help="训练 U-Net 模型")
    train_parser.add_argument(
        "--epochs", type=int, default=None,
        help="训练轮数(默认使用 config.py 中的 cfg.epochs)"
    )
    train_parser.add_argument(
        "--max-batches", type=int, default=None,
        help="每个 epoch 最多训练的 batch 数(用于烟雾测试)"
    )
    train_parser.add_argument(
        "--resume", action="store_true",
        help="是否从已有 checkpoint (last_model.pth) 恢复训练(断点续训)"
    )
    train_parser.set_defaults(func=run_train)

    # ---- eval 子命令 ----
    eval_parser = subparsers.add_parser("eval", help="在验证集上评估模型")
    eval_parser.add_argument(
        "--max-batches", type=int, default=None,
        help="最多评估多少个 batch(不指定则评估完整验证集)"
    )
    eval_parser.add_argument(
        "--save-first-n", type=int, default=5,
        help="保存前多少张样本的三连图(默认 5)"
    )
    eval_parser.set_defaults(func=run_eval)

    # ---- infer 子命令 ----
    infer_parser = subparsers.add_parser("infer", help="对单张图像做推理并可视化")
    infer_parser.add_argument(
        "--img", type=str, required=True,
        help="输入图像路径"
    )
    infer_parser.add_argument(
        "--mask", type=str, default=None,
        help="真值掩膜路径(可选)"
    )
    infer_parser.add_argument(
        "--no-save", action="store_true",
        help="只显示,不保存可视化 png"
    )
    infer_parser.add_argument(
        "--no-show", action="store_true",
        help="只保存,不弹出窗口显示"
    )
    infer_parser.set_defaults(func=run_single_infer)

    return parser


def main():
    parser = build_parser()
    args = parser.parse_args()

    if not hasattr(args, "func"):
        # 没有提供子命令时,打印帮助
        parser.print_help()
        return

    args.func(args)


if __name__ == "__main__":
    main()

main.py 使用 argparse 实现一个简单的命令行接口,包括三个子命令:train、eval和infer;每个子命令对应一个函数:run_train(args)、run_eval(args)和run_single_infer(args)。

在PyCharm 中切换到终端,检查虚拟环境和项目路径正确后,在命令行中输入如下指令即可运行:

(unet) C:\Python\Projects2025\Unet_retina> python main.py train --epochs 1 --max-batches 2

在这里插入图片描述


更方便地,在 PyCharm菜单栏选择 运行→编辑配置 新增参数配置,就可以切换不同运行模式。例如:

Train:
Script:main.py
Parameters:train --epochs 10 --max-batches 4

Train:
Script:main.py
Parameters:train --epochs 500 --resume

Eval:
Script:main.py
Parameters:eval --max-batches 5 --save-first-n 3

SingleInfer:
Script:main.py
Parameters:infer --img dataset/valid/images/example.png --mask dataset/valid/masks/example.png

在这里插入图片描述


配置好运行参数,点击工具栏的运行按钮 “run”,就可以选择不同运行模式,进行模型训练或推理。

  1. 模型训练:可以从头训练或选择断点续训。
    注意选择断点续训,需要确保 weights/last_model.pth 已经存在(之前训练过并保存过)。

在这里插入图片描述


  1. 模型推理:

在这里插入图片描述

在这里插入图片描述

【本节完】


版权声明:
欢迎关注『youcans动手学 AI』系列
转发请注明原文链接:
【动手学UNet】(7)主程序

Copyright 2025 youcans
Crated:2025-12


在这里插入图片描述

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐