第六章:PyTorch进阶训练技巧
"""Note: output/target 任意形状,但 dtype 必须可做减法。"""关键别用mean()自带广播且免于手动除;该函数不会被注册到,因此一般在训练循环里直接。"""每 `drop_every` epoch ×0.5"""train(...)
·
第六章:PyTorch进阶训练技巧 — 深入浅出PyTorch datawhale AI共学
6.1 自定义损失函数(两种写法)
6.1.1 函数式 —— 最简,可读性有限
def mse_loss_func(output: torch.Tensor,
target : torch.Tensor) -> torch.Tensor:
"""
Note: output/target 任意形状,但 dtype 必须可做减法。
"""
loss = ((output - target) ** 2).mean()
return loss
关键:
别用
torch.sum((x-y)**2) / n
,mean()
自带广播且免于手动除;该函数 不会 被注册到
model.parameters()
,因此一般在训练循环里直接loss_fn(out, y)
。
6.1.2 类式 —— 推荐,与 nn
生态结合
import torch, torch.nn as nn, torch.nn.functional as F
class DiceLoss(nn.Module):
r"""Dice = 2·|A∩B| / (|A|+|B|),常用于二值分割"""
def __init__(self, smooth: float = 1.0):
super().__init__()
self.smooth = smooth
def forward(self,
logits : torch.Tensor, # (N,C,H,W) 但常见单通道
target : torch.Tensor) -> torch.Tensor:
"""logits => sigmoid => flatten => Dice"""
probs = torch.sigmoid(logits)
probs = probs.view(-1)
target = target.view(-1).float()
inter = (probs * target).sum()
dice = (2. * inter + self.smooth) / \
(probs.sum() + target.sum() + self.smooth)
return 1. - dice # 要 **最小化** 损失 → 1-Dice
为何继承 nn.Module
?
-
能被自动加入
.to(device)
、.half()
、state_dict()
管理 -
可叠加到复合 Loss(多任务 / 权重求和)
6.1.3 组合 Loss 示例:BCE + Dice
class BCEDice(nn.Module):
def __init__(self, alpha=0.5):
"""
alpha=0.5 -> 两个 loss 各占一半;可根据实验调权重
"""
super().__init__()
self.alpha = alpha
self.bce = nn.BCEWithLogitsLoss()
self.dice = DiceLoss()
def forward(self, logits, target):
return (1-self.alpha) * self.bce(logits, target) \
+ self.alpha * self.dice(logits, target)
6.2 动态学习率 —— Scheduler
6.2.1 官方 scheduler 使用
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[30, 60, 90], # 在 30/60/90 epoch 触发衰减
gamma=0.1 # 每次乘 0.1
)
for epoch in range(100):
train_one_epoch(...)
validate(...)
optimizer.step() # ← 必须先更新参数
scheduler.step() # ← 再更新 lr(除了 ReduceLROnPlateau)
ReduceLROnPlateau 例外:要把
scheduler.step(val_loss)
放在 epoch 最后调 lr 时最好
get_last_lr()
打印,防止忘记调用
6.2.2 完全自定义策略
def adjust_lr(optimizer, epoch, base_lr=3e-4, drop_every=10):
"""每 `drop_every` epoch ×0.5"""
lr = base_lr * (0.5 ** (epoch // drop_every))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
for epoch in range(100):
train(...)
adjust_lr(optimizer, epoch)
6.3 模型微调 torchvision
6.3.1 替换分类头 & 冻结 Backbone
import torchvision.models as models, torch.nn as nn
net = models.resnet34(weights='DEFAULT') # PyTorch ≥2.0 推荐写法
for p in net.parameters(): # 冻结全部
p.requires_grad_(False)
num_ftrs = net.fc.in_features # 512
net.fc = nn.Sequential(
nn.Linear(num_ftrs, 128),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(128, 4) # 假设4类
)
优化器只喂可训练层:
trainable = filter(lambda p: p.requires_grad, net.parameters())
optimizer = torch.optim.Adam(trainable, lr=1e-3)
6.4 半精度训练(AMP)
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
for img, label in loader:
optimizer.zero_grad()
with autocast(): # 自动混精度
out = model(img.cuda())
loss = criterion(out, label.cuda())
scaler.scale(loss).backward() # 梯度放大防 underflow
scaler.step(optimizer) # 更新权重
scaler.update() # 动态调整 scale
-
无需改
model.float16()
,AMP 会在计算图级别插混 FP16/FP32 -
scaler.update()
失败率高时(梯度溢出)会自动降低 scale
6.5 Imgaug 快速上手
import imgaug.augmenters as iaa, imageio, numpy as np, torch
# (1) 定义 pipeline
aug = iaa.Sequential([
iaa.Fliplr(0.5), # 随机水平翻转
iaa.Crop(percent=(0, 0.1)), # 随机裁剪 0~10%
iaa.LinearContrast((0.8, 1.2)), # 对比度
])
# (2) Dataset 内调用
class ImgDataset(torch.utils.data.Dataset):
def __init__(self, paths, transform):
self.paths, self.transform = paths, transform
def __getitem__(self, idx):
img = imageio.imread(self.paths[idx])
img = self.transform(image=img) # 👈 关键是写作 image=
img = torch.from_numpy(img).permute(2,0,1).float()/255.
return img
def __len__(self): return len(self.paths)
多进程安全:在
DataLoader(..., worker_init_fn=seed_fn)
里def seed_fn(worker_id): iaa.seed(np.random.randint(0, 1e6) + worker_id)
6.6 Argparse 管理超参数
# config.py
import argparse, yaml, json, pathlib
def parse_cfg():
p = argparse.ArgumentParser()
p.add_argument('--cfg', type=str, help='yaml/json file to override')
p.add_argument('--lr', type=float, default=3e-4)
p.add_argument('--epoch', type=int, default=100)
p.add_argument('--bs', type=int, default=32)
opt = p.parse_args()
# ---------- 允许外部文件覆盖 ----------
if opt.cfg:
cfg_path = pathlib.Path(opt.cfg)
with cfg_path.open() as f:
ext = cfg_path.suffix
extra = yaml.safe_load(f) if ext=='.yaml' else json.load(f)
for k, v in extra.items():
setattr(opt, k, v)
return opt
运行:
python train.py --lr 1e-3 --bs 16 \
--cfg custom.yaml # yaml 内可以写 {"epoch": 50}
记忆
需求 | 最短范式 |
---|---|
自定义 Loss | class MyLoss(nn.Module): def forward(self, …): … |
Scheduler | scheduler = torch.optim.lr_scheduler.* |
微调 | for p in net.parameters(): p.requires_grad_(False) |
混精度 | with autocast(): loss = …; scaler.scale(loss).backward() |
Imgaug | aug = iaa.Sequential([...]); img = aug(image=img) |
命令行参数 | argparse + --cfg external.yaml |
更多推荐
所有评论(0)