目的:防止模型过拟合。

模型在训练集上的性能通常会随着训练的进行而提升,但在验证集上可能会出现先上升后下降的情况,这意味着模型开始过拟合。

而早停的操作在于通过监控一个验证指标(如验证集损失、准确率等),当该指标在一定数量的训练周期内不再改善时,提前终止训练。

内容:

1. 设置patience参数:代表指标没有改善的最大训练周期数。

2. 设置delta参数:如果指标的改善在delta阈值以内,可以认为没有改善。

3. 监控指标:判断指标有没有朝着目标改善至少delta。

4. 保存最佳模型:每次有明显改善的时候,就记录model的参数,确保早停时模型能够重新载入最新的最佳的模型参数。

代码:

EarlyStopping类:

import torch

class EarlyStopping:
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        初始化

        参数:
        patience (int): 容忍验证指标无改善的最大周期数
        verbose (bool): 是否打印早停信息
        delta (float): 验证指标改善的最小阈值
        path (str): 保存最佳模型的路径
        trace_func (function): 打印信息的函数
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0          # 记录无明显改变的周期数
        self.best_score = None
        self.early_stop = False  
        self.val_loss_min = float('inf')
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """
        更新状态并判断是否早停

        参数:
        val_loss (float): 当前的验证集损失
        model (torch.nn.Module): 模型
        """
        score = -val_loss                    #为了让score最大化

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:        #无明显改善
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:                                            #有明显改善
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """
        保存最佳模型

        参数:
        val_loss (float): 当前的验证集损失
        model (torch.nn.Module): 模型
        """
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss
    

调用:

import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__();
        self.fc = nn.Linear(10, 1);

    def forward(self, x):
        return self.fc(x);

# 初始化早停机制
model = SimpleModel()
early_stopping = EarlyStopping(patience=5, verbose=True)

…… #省略其他内容


for epoch in range(30):
    model.train()
    for batch in train_loader:
        optimizer.zero_grad()
        output = model(batch)
        loss = criterion(output, torch.ones_like(output))
        loss.backward()
        optimizer.step()

    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            output = model(batch)
            val_loss += criterion(output, torch.ones_like(output)).item()
    val_loss /= len(val_loader)  #计算验证集上的指标

    early_stopping(val_loss, model)    #call早停类进行验证
    if early_stopping.early_stop:
        print("Early stopping")
        break

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐