【AI课程领学】第五课 · 循环神经网络(课时4) 训练技巧:梯度截断、BPTT、Padding、Teacher Forcing 与工程实践

【AI课程领学】第五课 · 循环神经网络(课时4) 训练技巧:梯度截断、BPTT、Padding、Teacher Forcing 与工程实践



欢迎铁子们点赞、关注、收藏!
祝大家逢考必过!逢投必中!上岸上岸上岸!upupup

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文。详细信息可扫描博文下方二维码 “学术会议小灵通”或参考学术信息专栏:https://ais.cn/u/mmmiUz
详细免费的AI课程可在这里获取→www.lab4ai.cn


前言

RNN 真正难的其实不是“结构本身”,而是 训练:

  • 序列长 → 梯度消失/爆炸
  • 样本长短不一 → 需要 padding、mask
  • 文本生成类任务 → 训练和推理阶段不一致,需要 teacher forcing 等技巧
  • 工程上还要考虑 batch 组织、效率与 GPU 占用

本篇我们分模块讲:

  1. BPTT(通过时间的反向传播)与梯度截断
  2. 可变长序列的 padding 与 pack_padded_sequence
  3. Teacher Forcing 与 scheduled sampling
  4. RNN 训练时的一些工程实践建议(学习率、初始化等)

一、BPTT 与梯度截断

1.1 BPTT 是什么?

  • 对于 RNN,损失不仅依赖最后一个时间步,而是整个序列:
    在这里插入图片描述
    参数同时作用于所有时间步,因此反向传播需要沿时间方向展开:
  • 这就是 Backpropagation Through Time(BPTT)。

当序列很长时,梯度会沿着时间维不断被链式法则“乘来乘去”,非常容易:

  • 梯度消失(训练很慢甚至无效)
  • 梯度爆炸(参数更新不稳定)

1.2 梯度截断(Gradient Clipping)

  • 解决梯度爆炸的常用做法:在每一步更新前,限制梯度的范数不超过某个阈值。

PyTorch 中一行搞定:

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

  • 放在 loss.backward()optimizer.step() 之间即可。

二、可变长序列:Padding 与 pack_padded_sequence

真实数据中,序列长度通常不一样:

  • 文本:有的句子 5 个词,有的 50 个词
  • 时间序列:观测长度不一致
  • 遥感时序:有缺测/遮挡

2.1 Padding:统一长度

做法:

  • 找到 batch 中最长序列长度 max_len
  • 其它序列在结尾补 0,使形状一致

但补 0 会引入无效时间步。我们希望在计算时忽略这些 padding 部分,否则会干扰梯度与隐藏状态。

2.2 PyTorch 中的打包序列:pack_padded_sequence

PyTorch 提供了一套专门工具处理变长序列:

  • nn.utils.rnn.pack_padded_sequence
  • nn.utils.rnn.pad_packed_sequence

基本流程:

  1. 准备 padded_sequences(统一长度)以及真实长度列表 lengths
  2. pack_padded_sequence 打包
  3. 把 pack 结果喂给 RNN/LSTM/GRU
  4. pad_packed_sequence 恢复(如果需要)

示例:

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# 假设 x_padded: (batch, max_len, input_size), lengths: (batch,)
# 注意:pack_padded_sequence 要求序列按长度降序排序,或者设置 enforce_sorted=False
x_padded = ...
lengths = ...

x_packed = pack_padded_sequence(
    x_padded,
    lengths=lengths.cpu(),
    batch_first=True,
    enforce_sorted=False
)

lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
packed_out, (h_n, c_n) = lstm(x_packed)

out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

  • 优势:RNN 不会在 padding 部分多算无用步骤,隐藏状态不会被 padding“污染”,效率也更高。

三、Teacher Forcing:解决训练/推理阶段分布不一致

在 Seq2Seq 或生成式任务中(机器翻译、文本生成等),Decoder 在每一步的输入通常是:

  • 训练时:真实标签(ground truth)
  • 推理时:模型前一步预测的结果

这会导致一个问题:

  • 模型在训练时总是“看到正确上下文”,而推理时却要面对自己“错误的历史预测”,产生分布不匹配,导致错误不断累积。

为缓解这一点,引入 Teacher Forcing:

  • 训练时某些时间步用真实标签作为 Decoder 输入
  • 其他时间步用模型预测的结果
  • 使用一个概率 p_tf(teacher forcing ratio)控制两者之间的权衡

伪代码示意:

import random

def decode_with_teacher_forcing(decoder, encoder_out, y_targets, p_tf=0.5):
    # y_targets: 真实输出序列
    # 通常会有一个 <bos> 作为起始
    inputs = y_targets[:, 0]  # <bos>
    hidden = decoder.init_hidden_from_encoder(encoder_out)

    outputs = []
    for t in range(1, y_targets.size(1)):
        out, hidden = decoder(inputs, hidden)
        outputs.append(out)

        use_teacher = random.random() < p_tf
        if use_teacher:
            inputs = y_targets[:, t]          # 用真实 token
        else:
            inputs = out.argmax(dim=-1)       # 用模型预测值

    return torch.stack(outputs, dim=1)

  • 进阶:Scheduled Sampling 会随着训练进程逐步降低 p_tf,逐渐让模型更适应“自回归”场景。

四、RNN 训练的工程实践建议

4.1 输入归一化 & 初始隐藏状态

  • 对于连续特征(如时间序列),建议先做标准化(mean-std)
  • 初始隐藏状态通常设为 0;若任务需要,也可以学习一个全局初始向量

PyTorch 中如不指定,nn.LSTM 默认 h_0c_0 都是 0。

4.2 学习率与优化器

  • Adam/AdamW 是 RNN 的常用起点(如 lr=1e-3
  • 对长序列/梯度不稳定任务,多尝试 小一点学习率 + 梯度裁剪

4.3 Dropout 与正则化

  • nn.LSTMnn.GRU 内部支持层间 dropout(非最后一层之间)
  • 序列任务中适当增加输入层、输出层的 dropout 可以提升泛化

例子:

lstm = nn.LSTM(
    input_size=feature_dim,
    hidden_size=hidden_dim,
    num_layers=2,
    dropout=0.2,          # 仅在多层之间有效
    batch_first=True
)

4.4 序列截断训练(Truncated BPTT)

对于极长序列(比如上千步),完整 BPTT 计算与显存消耗都很大。

Truncated BPTT 思想:

  • 将长序列分成若干小片段(如 100 步一段)
  • 前一段的最终隐藏状态作为后一段的初始状态
  • 每一段单独做一次反向传播

简化示意:

hidden = None
for i in range(0, seq_len, chunk_len):
    x_chunk = x[:, i:i+chunk_len]    # (batch, chunk_len, input_size)
    out, hidden = lstm(x_chunk, hidden)
    loss = ...
    optimizer.zero_grad()
    loss.backward()
    # 截断反向传播的“图”:
    hidden = (hidden[0].detach(), hidden[1].detach())
    optimizer.step()

  • 这样就不会因为一次 BPTT 展开太久而导致显存过大。

五、综合示例:可变长序列 + LSTM + 梯度截断

  • 下面给一个稍完整的例子,对“可变长序列二分类”进行训练:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class VarLenLSTMClassifier(nn.Module):
    def __init__(self, input_size=10, hidden_size=64, num_layers=1, num_classes=2):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x_padded, lengths):
        # x_padded: (batch, max_len, input_size)
        # lengths: (batch,) 实际长度
        packed = pack_padded_sequence(
            x_padded, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_out, (h_n, c_n) = self.lstm(packed)
        # 这里我们直接用最后一层的最终 hidden 作为表示
        last_hidden = h_n[-1]        # (batch, hidden_size)
        logits = self.fc(last_hidden)
        return logits

# 假数据
batch = 4
max_len = 7
input_size = 10

# 长度不同的序列
lengths = torch.tensor([7, 5, 6, 3])
x_padded = torch.randn(batch, max_len, input_size)

model = VarLenLSTMClassifier(input_size=input_size)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
y = torch.randint(0, 2, (batch,))

logits = model(x_padded, lengths)
loss = F.cross_entropy(logits, y)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
optimizer.step()

print("loss:", loss.item())

这个例子把:

  • 可变长序列
  • pack_padded_sequence
  • LSTM 分类
  • 梯度截断

整合到了一起,是实际项目中相当常用的模式。

六、小结

本篇我们重点解决了“RNN 实战中最常碰到的坑”:

  • BPTT 与梯度消失/爆炸 → 用梯度截断、合适的结构(LSTM/GRU)
  • 可变长序列 → padding + pack/pad 工具
  • 生成任务的训练/推理不一致 → Teacher Forcing + Scheduled Sampling
  • 极长序列 → Truncated BPTT
  • 工程实践 → 归一化、优化器配置、Dropout、初始化等
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐