【动手学UNet】(7)主程序
【动手学UNet】系列教程提供完整的UNet图像分割实现指南,包含7个核心模块:项目创建、数据加载、模型实现、训练、保存/加载、评估和主程序。主程序(main.py)作为统一入口,支持训练、评估和单图推理三种模式,可通过命令行参数灵活配置。训练支持断点续训和烟雾测试;评估可计算Dice/IoU指标并保存结果;单图推理支持原图/真值/预测的三连图展示。
欢迎关注『youcans动手学 AI』系列
【动手学UNet】(1)创建 UNet项目
【动手学UNet】(2)数据加载
【动手学UNet】(3)UNet 模型的实现
【动手学UNet】(4)UNet 模型的训练
【动手学UNet】(5)保存,加载与可视化
【动手学UNet】(6)模型推理与评估
【动手学UNet】(7)主程序
【动手学UNet】(7)主程序
欢迎来到【动手学UNet】系列教程!
本系列将带你从零开始,一步步深入理解和实现经典的UNet图像分割模型。无论你是深度学习初学者,还是有一定经验的开发者,这个系列都将为你提供全面而实用的UNet知识。
12. 主程序(main.py)
main.py 的定位是实现统一调度训练 / 验证 / 单张推理的入口,也可以统一管理参数(如epochs、max_batches等)。
通过命令行参数,可以调度各个功能模块,例如:
- python main.py train → 调用 core/train_unet.py 里的 train(),进行训练(可指定 epochs / max_batches);
- python main.py eval → 调用 core/test_unet.py 的评估逻辑,计算 Dice / IoU,并在 runs/test/ 下保存若干三连图;
- python main.py infer --img xxx --mask yyy → 对一张指定图片做推理,弹出三连图窗口(原图 / 真值 / 预测),并可选择是否保存 png。
main.py的完整代码如下。
# main.py
"""
Unet_retina 项目统一入口脚本
用法示例(在项目根目录):
1. 训练(从头开始,使用 config.py 中的 cfg.epochs):
python main.py train
2. 训练(指定 epoch 数,快速烟雾测试,只跑前 2 个 batch):
python main.py train --epochs 1 --max-batches 2
3. 从已有 checkpoint (weights/last_model.pth) 断点续训:
python main.py train --epochs 100 --resume
4. 在验证集上评估模型,并保存前 5 张样本的三连图:
python main.py eval --max-batches 10 --save-first-n 5
5. 对单张图像做推理,并弹出三连图窗口(可选 mask):
python main.py infer --img dataset/valid/images/example.png --mask dataset/valid/masks/example.png
# 如没有真值掩膜,可省略 --mask:
python main.py infer --img some_image.png
6. 单张推理时只保存、不显示:
python main.py infer --img xxx.png --no-show
7. 单张推理时只显示、不保存:
python main.py infer --img xxx.png --no-save
"""
import argparse
from pathlib import Path
import numpy as np
from PIL import Image
import torch
from core.config import cfg
from core.train_unet import train as train_unet
from core.test_unet import evaluate_and_visualize
from model.unet import UNet
from core.checkpoint import load_model
from utils.visualization import save_prediction_triplet
# === 单张推理相关的辅助函数(与 test_single_image.py 的逻辑一致) ===
def preprocess_image(img_path: Path) -> torch.Tensor:
"""
按与 RetinaDataset 一致的方式对单张图像做预处理:
- 读取 PIL Image
- 转为灰度或 RGB(取决于 cfg.in_channels)
- resize 到 cfg.img_size
- 归一化到 [0,1]
- 转为 [1, C, H, W] 的 tensor(带 batch 维)
"""
img = Image.open(img_path)
if cfg.in_channels == 1:
img = img.convert("L")
elif cfg.in_channels == 3:
img = img.convert("RGB")
else:
raise ValueError(f"Unsupported in_channels in cfg: {cfg.in_channels}")
H, W = cfg.img_size
img = img.resize((W, H), resample=Image.BILINEAR)
img_np = np.array(img, dtype=np.float32) / 255.0
if cfg.in_channels == 1:
# [H, W] -> [1, H, W]
img_np = np.expand_dims(img_np, axis=0)
else:
# [H, W, 3] -> [3, H, W]
img_np = np.transpose(img_np, (2, 0, 1))
img_tensor = torch.from_numpy(img_np).unsqueeze(0) # [1, C, H, W]
return img_tensor
def preprocess_mask(mask_path: Path) -> torch.Tensor:
"""
读取单张掩膜并预处理为 [1, 1, H, W] 的 float tensor,值为 0/1。
若没有真值掩膜,可不调用该函数。
"""
mask = Image.open(mask_path).convert("L")
H, W = cfg.img_size
mask = mask.resize((W, H), resample=Image.NEAREST)
mask_np = np.array(mask, dtype=np.float32)
mask_np = (mask_np > 127.5).astype(np.float32) # 二值化到 0/1
mask_np = np.expand_dims(mask_np, axis=0) # [1, H, W]
mask_tensor = torch.from_numpy(mask_np).unsqueeze(0) # [1, 1, H, W]
return mask_tensor
def load_trained_unet() -> UNet:
"""
加载训练好的 UNet 权重。
优先尝试 best_model.pth,如不存在则使用 last_model.pth。
"""
device = cfg.device
model = UNet(
in_channels=cfg.in_channels,
num_classes=cfg.num_classes,
base_channels=64,
bilinear=True,
).to(device)
if cfg.best_model_path.exists():
ckpt_path = cfg.best_model_path
elif cfg.last_model_path.exists():
ckpt_path = cfg.last_model_path
else:
raise FileNotFoundError(
f"No checkpoint found. Expected one of:\n"
f" best_model: {cfg.best_model_path}\n"
f" last_model: {cfg.last_model_path}"
)
model, _, epoch = load_model(ckpt_path, model, optimizer=None)
print(f"[Main-Infer] Loaded checkpoint from {ckpt_path}")
if epoch is not None:
print(f"[Main-Infer] Checkpoint epoch: {epoch}")
model.eval()
return model
# ========== 子命令对应的执行函数 ==========
def run_train(args):
"""
执行训练流程:
- epochs: 训练轮数(为 None 时使用 cfg.epochs)
- max_batches: 每个 epoch 最多训练多少个 batch(用于烟雾测试)
- resume: 是否从已有 checkpoint (last_model.pth) 断点续训
"""
print(">>> Running TRAIN ...")
train_unet(
num_epochs=args.epochs,
max_batches=args.max_batches,
resume=args.resume,
)
def run_eval(args):
"""
执行验证/评估流程:
- max_batches: 最多评估多少个 batch
- save_first_n: 保存前多少张样本的三连图
"""
print(">>> Running EVAL ...")
evaluate_and_visualize(
max_batches=args.max_batches,
save_first_n=args.save_first_n,
)
def run_single_infer(args):
"""
单张图像推理流程:
- img: 必选,图片路径
- mask: 可选,真值掩膜路径(若无,可不提供)
- no_save: 仅显示,不保存 png
- no_show: 仅保存,不弹出窗口
"""
device = cfg.device
img_path = Path(args.img)
mask_path = Path(args.mask) if args.mask is not None else None
if not img_path.exists():
raise FileNotFoundError(f"Image not found: {img_path}")
if mask_path is not None and not mask_path.exists():
print(f"[Warn] Mask path provided but file not found: {mask_path}")
mask_path = None
print(">>> Running SINGLE INFER ...")
print(f"Using device: {device}")
print(f"Image path : {img_path}")
print(f"Mask path : {mask_path if mask_path is not None else 'None'}")
# 1. 预处理图像
img_tensor = preprocess_image(img_path).to(device) # [1, C, H, W]
# 2. 预处理掩膜(如果有),否则用全零掩膜占位
if mask_path is not None:
mask_tensor = preprocess_mask(mask_path).to(device) # [1, 1, H, W]
else:
mask_tensor = torch.zeros(1, 1, cfg.img_size[0], cfg.img_size[1], device=device)
print("[Main-Infer] No ground truth mask provided, using all-zero dummy mask.")
# 3. 加载模型
model = load_trained_unet()
# 4. 前向推理
with torch.no_grad():
logits = model(img_tensor) # [1, 1, H, W]
probs = torch.sigmoid(logits)
preds = (probs > 0.5).float() # 二值预测
# 5. 三连图可视化(根据参数决定是否保存/显示)
img_show = img_tensor[0].cpu()
mask_show = mask_tensor[0].cpu()
pred_show = preds[0].cpu()
cfg.test_runs_dir.mkdir(parents=True, exist_ok=True)
default_save_path = cfg.test_runs_dir / "single_from_main.png"
save_path = None if args.no_save else default_save_path
show = not args.no_show
save_prediction_triplet(
image=img_show,
mask=mask_show,
pred_mask=pred_show,
save_path=save_path,
show=show,
titles=("Image", "GT Mask", "Pred Mask"),
)
print("[Main-Infer] Done.")
# ========== 命令行解析入口 ==========
def build_parser():
parser = argparse.ArgumentParser(description="Unet_retina 项目统一入口")
subparsers = parser.add_subparsers(dest="command", help="子命令(train / eval / infer)")
# ---- train 子命令 ----
train_parser = subparsers.add_parser("train", help="训练 U-Net 模型")
train_parser.add_argument(
"--epochs", type=int, default=None,
help="训练轮数(默认使用 config.py 中的 cfg.epochs)"
)
train_parser.add_argument(
"--max-batches", type=int, default=None,
help="每个 epoch 最多训练的 batch 数(用于烟雾测试)"
)
train_parser.add_argument(
"--resume", action="store_true",
help="是否从已有 checkpoint (last_model.pth) 恢复训练(断点续训)"
)
train_parser.set_defaults(func=run_train)
# ---- eval 子命令 ----
eval_parser = subparsers.add_parser("eval", help="在验证集上评估模型")
eval_parser.add_argument(
"--max-batches", type=int, default=None,
help="最多评估多少个 batch(不指定则评估完整验证集)"
)
eval_parser.add_argument(
"--save-first-n", type=int, default=5,
help="保存前多少张样本的三连图(默认 5)"
)
eval_parser.set_defaults(func=run_eval)
# ---- infer 子命令 ----
infer_parser = subparsers.add_parser("infer", help="对单张图像做推理并可视化")
infer_parser.add_argument(
"--img", type=str, required=True,
help="输入图像路径"
)
infer_parser.add_argument(
"--mask", type=str, default=None,
help="真值掩膜路径(可选)"
)
infer_parser.add_argument(
"--no-save", action="store_true",
help="只显示,不保存可视化 png"
)
infer_parser.add_argument(
"--no-show", action="store_true",
help="只保存,不弹出窗口显示"
)
infer_parser.set_defaults(func=run_single_infer)
return parser
def main():
parser = build_parser()
args = parser.parse_args()
if not hasattr(args, "func"):
# 没有提供子命令时,打印帮助
parser.print_help()
return
args.func(args)
if __name__ == "__main__":
main()
main.py 使用 argparse 实现一个简单的命令行接口,包括三个子命令:train、eval和infer;每个子命令对应一个函数:run_train(args)、run_eval(args)和run_single_infer(args)。
在PyCharm 中切换到终端,检查虚拟环境和项目路径正确后,在命令行中输入如下指令即可运行:
(unet) C:\Python\Projects2025\Unet_retina> python main.py train --epochs 1 --max-batches 2

更方便地,在 PyCharm菜单栏选择 运行→编辑配置 新增参数配置,就可以切换不同运行模式。例如:
Train:
Script:main.py
Parameters:train --epochs 10 --max-batches 4
Train:
Script:main.py
Parameters:train --epochs 500 --resume
Eval:
Script:main.py
Parameters:eval --max-batches 5 --save-first-n 3
SingleInfer:
Script:main.py
Parameters:infer --img dataset/valid/images/example.png --mask dataset/valid/masks/example.png

配置好运行参数,点击工具栏的运行按钮 “run”,就可以选择不同运行模式,进行模型训练或推理。
- 模型训练:可以从头训练或选择断点续训。
注意选择断点续训,需要确保 weights/last_model.pth 已经存在(之前训练过并保存过)。

- 模型推理:


【本节完】
版权声明:
欢迎关注『youcans动手学 AI』系列
转发请注明原文链接:
【动手学UNet】(7)主程序
Copyright 2025 youcans
Crated:2025-12

更多推荐

所有评论(0)