深度学习中的早停机制
而早停的操作在于通过监控一个验证指标(如验证集损失、准确率等),当该指标在一定数量的训练周期内不再改善时,提前终止训练。模型在训练集上的性能通常会随着训练的进行而提升,但在验证集上可能会出现先上升后下降的情况,这意味着模型开始过拟合。4. 保存最佳模型:每次有明显改善的时候,就记录model的参数,确保早停时模型能够重新载入最新的最佳的模型参数。2. 设置delta参数:如果指标的改善在delta
·
目的:防止模型过拟合。
模型在训练集上的性能通常会随着训练的进行而提升,但在验证集上可能会出现先上升后下降的情况,这意味着模型开始过拟合。
而早停的操作在于通过监控一个验证指标(如验证集损失、准确率等),当该指标在一定数量的训练周期内不再改善时,提前终止训练。
内容:
1. 设置patience参数:代表指标没有改善的最大训练周期数。
2. 设置delta参数:如果指标的改善在delta阈值以内,可以认为没有改善。
3. 监控指标:判断指标有没有朝着目标改善至少delta。
4. 保存最佳模型:每次有明显改善的时候,就记录model的参数,确保早停时模型能够重新载入最新的最佳的模型参数。
代码:
EarlyStopping类:
import torch
class EarlyStopping:
def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
"""
初始化
参数:
patience (int): 容忍验证指标无改善的最大周期数
verbose (bool): 是否打印早停信息
delta (float): 验证指标改善的最小阈值
path (str): 保存最佳模型的路径
trace_func (function): 打印信息的函数
"""
self.patience = patience
self.verbose = verbose
self.counter = 0 # 记录无明显改变的周期数
self.best_score = None
self.early_stop = False
self.val_loss_min = float('inf')
self.delta = delta
self.path = path
self.trace_func = trace_func
def __call__(self, val_loss, model):
"""
更新状态并判断是否早停
参数:
val_loss (float): 当前的验证集损失
model (torch.nn.Module): 模型
"""
score = -val_loss #为了让score最大化
if self.best_score is None:
self.best_score = score
self.save_checkpoint(val_loss, model)
elif score < self.best_score + self.delta: #无明显改善
self.counter += 1
self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else: #有明显改善
self.best_score = score
self.save_checkpoint(val_loss, model)
self.counter = 0
def save_checkpoint(self, val_loss, model):
"""
保存最佳模型
参数:
val_loss (float): 当前的验证集损失
model (torch.nn.Module): 模型
"""
if self.verbose:
self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
torch.save(model.state_dict(), self.path)
self.val_loss_min = val_loss
调用:
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__();
self.fc = nn.Linear(10, 1);
def forward(self, x):
return self.fc(x);
# 初始化早停机制
model = SimpleModel()
early_stopping = EarlyStopping(patience=5, verbose=True)
…… #省略其他内容
for epoch in range(30):
model.train()
for batch in train_loader:
optimizer.zero_grad()
output = model(batch)
loss = criterion(output, torch.ones_like(output))
loss.backward()
optimizer.step()
model.eval()
val_loss = 0
with torch.no_grad():
for batch in val_loader:
output = model(batch)
val_loss += criterion(output, torch.ones_like(output)).item()
val_loss /= len(val_loader) #计算验证集上的指标
early_stopping(val_loss, model) #call早停类进行验证
if early_stopping.early_stop:
print("Early stopping")
break
更多推荐



所有评论(0)