Chapter 6: Advanced PyTorch Training Techniques (深入浅出PyTorch, Datawhale)


6.1 Custom Loss Functions (Two Styles)

6.1.1 Function style: simplest to write, but limited readability

import torch

def mse_loss_func(output: torch.Tensor,
                  target: torch.Tensor) -> torch.Tensor:
    """
    Note: output/target may have any shape, but their dtypes must support subtraction.
    """
    loss = ((output - target) ** 2).mean()
    return loss

Key points

  1. Prefer .mean() over torch.sum((x - y) ** 2) / n: mean() averages over all elements (the subtraction already broadcasts) and avoids a manual division;

  2. The function is not registered under model.parameters(), so it is normally called directly in the training loop as loss_fn(out, y) (see the sketch below).
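
A minimal sketch of calling the functional loss inside a training step (model, optimizer and loader are placeholders assumed to exist elsewhere):

for x, y in loader:
    optimizer.zero_grad()
    out  = model(x)
    loss = mse_loss_func(out, y)   # called like any ordinary function
    loss.backward()
    optimizer.step()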


6.1.2 Class style: recommended, integrates with the nn ecosystem

import torch, torch.nn as nn, torch.nn.functional as F

class DiceLoss(nn.Module):
    r"""Dice = 2·|A∩B| / (|A|+|B|),常用于二值分割"""
    def __init__(self, smooth: float = 1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self,
                logits : torch.Tensor,   # (N,C,H,W); in practice often a single channel
                target : torch.Tensor) -> torch.Tensor:
        """logits => sigmoid => flatten => Dice"""
        probs   = torch.sigmoid(logits)
        probs   = probs.view(-1)
        target  = target.view(-1).float()

        inter   = (probs * target).sum()
        dice    = (2. * inter + self.smooth) / \
                  (probs.sum() + target.sum() + self.smooth)
        return 1. - dice       # the loss must be minimized, hence 1 - Dice

Why inherit from nn.Module?
  • Managed automatically by .to(device), .half(), and state_dict() (illustrated below)

  • Can be stacked into composite losses (multi-task / weighted sums)
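
The benefit is concrete once a loss carries tensor state. A minimal illustrative sketch (the WeightedBCE class below is not part of the chapter; it only demonstrates the point): a weight registered as a buffer is moved by .to(device) and saved by state_dict() automatically.

class WeightedBCE(nn.Module):
    def __init__(self, pos_weight: float = 2.0):
        super().__init__()
        # registering as a buffer lets .to(device) / .half() / state_dict() manage it
        self.register_buffer('pos_weight', torch.tensor(pos_weight))

    def forward(self, logits, target):
        return F.binary_cross_entropy_with_logits(
            logits, target, pos_weight=self.pos_weight)

criterion = WeightedBCE()
# criterion.to('cuda') would move the pos_weight buffer to the GPU as well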


6.1.3 Composite loss example: BCE + Dice

class BCEDice(nn.Module):
    def __init__(self, alpha=0.5):
        """
        alpha=0.5 -> the two losses are weighted equally; tune the weight experimentally
        """
        super().__init__()
        self.alpha  = alpha
        self.bce    = nn.BCEWithLogitsLoss()
        self.dice   = DiceLoss()

    def forward(self, logits, target):
        return (1 - self.alpha) * self.bce(logits, target) \
               + self.alpha * self.dice(logits, target)
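
A short usage sketch (the shapes below are arbitrary): the composite loss is called exactly like a built-in criterion.

criterion = BCEDice(alpha=0.5)
logits = torch.randn(8, 1, 256, 256)                     # hypothetical batch of predictions
mask   = torch.randint(0, 2, (8, 1, 256, 256)).float()   # hypothetical binary masks
loss   = criterion(logits, mask)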

6.2 Dynamic Learning Rates: Schedulers

6.2.1 Using the built-in schedulers

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[30, 60, 90],    # trigger the decay at epochs 30/60/90
    gamma=0.1                   # multiply the lr by 0.1 each time
)

for epoch in range(100):
    train_one_epoch(...)    # optimizer.step() is called per batch inside here
    validate(...)

    scheduler.step()        # update the lr after the optimizer has stepped (ReduceLROnPlateau is the exception)
  • ReduceLROnPlateau is the exception: call scheduler.step(val_loss) at the end of each epoch, passing the monitored metric (see the sketch below)

  • When adjusting the lr, print scheduler.get_last_lr() each epoch so a forgotten step() call is easy to spot
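
A minimal sketch of the ReduceLROnPlateau pattern (train_one_epoch/validate are placeholders; validate is assumed to return a scalar validation loss):

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=5)

for epoch in range(100):
    train_one_epoch(...)
    val_loss = validate(...)                         # assumed to return a scalar metric
    scheduler.step(val_loss)                         # pass the monitored metric explicitly
    print(epoch, optimizer.param_groups[0]['lr'])    # current lr of the first param group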


6.2.2 Fully custom schedules

def adjust_lr(optimizer, epoch, base_lr=3e-4, drop_every=10):
    """每 `drop_every` epoch ×0.5"""
    lr = base_lr * (0.5 ** (epoch // drop_every))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

for epoch in range(100):
    train(...)
    adjust_lr(optimizer, epoch)
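
The same decay can also be expressed with the built-in LambdaLR, which multiplies the initial lr by whatever the lambda returns (a sketch, assuming the optimizer defined in 6.2.1):

scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda epoch: 0.5 ** (epoch // 10))   # same x0.5 every 10 epochs

for epoch in range(100):
    train(...)
    scheduler.step()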

6.3 Fine-Tuning torchvision Models

6.3.1 Replacing the classification head & freezing the backbone

import torchvision.models as models, torch.nn as nn

net = models.resnet34(weights='DEFAULT')   # weights= API (torchvision >= 0.13), replaces pretrained=True
for p in net.parameters():                 # freeze the entire backbone
    p.requires_grad_(False)

num_ftrs = net.fc.in_features              # 512 for ResNet-34
net.fc = nn.Sequential(
    nn.Linear(num_ftrs, 128),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(128, 4)                      # assuming 4 target classes
)

Feed the optimizer only the trainable parameters:

trainable = filter(lambda p: p.requires_grad, net.parameters())
optimizer = torch.optim.Adam(trainable, lr=1e-3)
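
A quick sanity check that the freezing worked (a sketch; only the new head should contribute trainable parameters):

n_total     = sum(p.numel() for p in net.parameters())
n_trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f'trainable: {n_trainable} / total: {n_total}')   # only net.fc parameters remain trainable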

6.4 Half-Precision Training (AMP)

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for img, label in loader:
    optimizer.zero_grad()
    with autocast():                       # automatic mixed precision for the forward pass
        out  = model(img.cuda())
        loss = criterion(out, label.cuda())

    scaler.scale(loss).backward()          # scale the gradients to prevent FP16 underflow
    scaler.step(optimizer)                 # unscale the gradients, then update the weights
    scaler.update()                        # adjust the scale factor dynamically
  • There is no need to call model.half(): AMP inserts the FP16/FP32 casts at the level of individual operations in the computation graph

  • When steps are skipped frequently (gradient overflow), scaler.update() automatically lowers the scale factor
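
Recent PyTorch releases deprecate the torch.cuda.amp entry points in favour of the device-agnostic torch.amp ones; a hedged sketch of the same loop (the version threshold, roughly 2.3+, is an assumption worth checking against your install):

scaler = torch.amp.GradScaler('cuda')          # replacement for torch.cuda.amp.GradScaler

for img, label in loader:
    optimizer.zero_grad()
    with torch.amp.autocast('cuda'):           # device type passed explicitly
        out  = model(img.cuda())
        loss = criterion(out, label.cuda())
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()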


6.5 Quick Start with imgaug

import imgaug.augmenters as iaa, imageio, numpy as np, torch

# (1) define the augmentation pipeline
aug = iaa.Sequential([
    iaa.Fliplr(0.5),                       # random horizontal flip
    iaa.Crop(percent=(0, 0.1)),            # random crop of 0-10% of each side
    iaa.LinearContrast((0.8, 1.2)),        # contrast jitter
])

# (2) call the pipeline inside a Dataset
class ImgDataset(torch.utils.data.Dataset):
    def __init__(self, paths, transform):
        self.paths, self.transform = paths, transform
    def __getitem__(self, idx):
        img = imageio.imread(self.paths[idx])
        img = self.transform(image=img)    # note: pass the image as the keyword argument image=
        img = torch.from_numpy(img).permute(2,0,1).float()/255.
        return img
    def __len__(self): return len(self.paths)

Multi-process safety: pass DataLoader(..., worker_init_fn=seed_fn) so every worker gets its own augmentation seed:

import imgaug as ia

def seed_fn(worker_id):
    ia.seed(np.random.randint(0, 10**6) + worker_id)   # per-worker imgaug seed
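
Besides the per-sample call inside the Dataset, the same pipeline can also augment a whole batch in one call (a sketch with a dummy uint8 batch; the sizes are arbitrary):

batch     = np.random.randint(0, 255, size=(16, 224, 224, 3), dtype=np.uint8)
batch_aug = aug(images=batch)              # keyword images= (plural) for an (N,H,W,C) batch
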

6.6 Managing Hyperparameters with argparse

# config.py
import argparse, yaml, json, pathlib

def parse_cfg():
    p = argparse.ArgumentParser()
    p.add_argument('--cfg', type=str, help='yaml/json file to override')
    p.add_argument('--lr',  type=float, default=3e-4)
    p.add_argument('--epoch', type=int, default=100)
    p.add_argument('--bs',  type=int, default=32)
    opt = p.parse_args()

    # ---------- allow an external config file to override the defaults ----------
    if opt.cfg:
        cfg_path = pathlib.Path(opt.cfg)
        with cfg_path.open() as f:
            ext = cfg_path.suffix
            extra = yaml.safe_load(f) if ext in ('.yaml', '.yml') else json.load(f)
        for k, v in extra.items():
            setattr(opt, k, v)

    return opt
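
In train.py the parsed options are then read like plain attributes (a sketch; model and data setup are omitted):

# train.py (sketch)
from config import parse_cfg

opt = parse_cfg()
print(opt.lr, opt.bs, opt.epoch)   # values come from the CLI or from the --cfg file
# optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
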

Run:

python train.py --lr 1e-3 --bs 16 \
                --cfg custom.yaml      # custom.yaml can contain e.g. {"epoch": 50}

Quick reference

Need                 Shortest pattern
Custom loss          class MyLoss(nn.Module): def forward(self, ...): ...
Scheduler            scheduler = torch.optim.lr_scheduler.*
Fine-tuning          for p in net.parameters(): p.requires_grad_(False)
Mixed precision      with autocast(): loss = ...; scaler.scale(loss).backward()
imgaug               aug = iaa.Sequential([...]); img = aug(image=img)
CLI arguments        argparse + --cfg external.yaml
