残差连接(Residual Connection)是Transformer中的一个关键设计,用于解决深层网络训练时的梯度消失问题,同时帮助模型保留原始输入信息。它的操作非常简单,但效果显著。以下是通俗易懂的解释:


1. 残差连接的核心操作

一句话总结
把当前层的输入直接加到当前层的输出上,形成“输入 + 输出”的短路路径。

数学公式
输出=输入 x + 当前层的变换(x) \text{输出} = \text{输入} \ x \ + \ \text{当前层的变换}(x) 输出=输入 x + 当前层的变换(x)
(其中“当前层的变换”可能是自注意力、交叉注意力或前馈网络)


2. 具体步骤(以解码器的自注意力层为例)

假设输入是一个向量 ( x )(已包含词嵌入和位置编码),经过自注意力层后的输出为 SelfAttn(x)\text{SelfAttn}(x)SelfAttn(x)

  1. 保留原始输入:将输入 ( x ) 复制一份。
  2. 叠加变换结果:将自注意力的输出 SelfAttn(x)\text{SelfAttn}(x)SelfAttn(x)与原始输入 ( x ) 逐元素相加。
    残差输出=x+SelfAttn(x) \text{残差输出} = x + \text{SelfAttn}(x) 残差输出=x+SelfAttn(x)
  3. 层归一化:对相加后的结果做归一化(LayerNorm)。
    最终输出=LayerNorm(x+SelfAttn(x)) \text{最终输出} = \text{LayerNorm}(x + \text{SelfAttn}(x)) 最终输出=LayerNorm(x+SelfAttn(x))

3. 直观类比

想象你正在修改一篇文章:

  • 原始输入(x):初稿的文本。
  • 当前层的变换(SelfAttn(x)):你写的修改建议(比如添加一些描述)。
  • 残差连接:把修改建议直接“贴”到初稿上(初稿 + 修改),而不是完全重写。
  • 层归一化:调整合并后的格式,使其更规范。

关键点:无论你修改多少遍,初稿的内容始终保留,避免彻底丢失原始信息。


4. 为什么需要残差连接?

解决的问题
  • 梯度消失:深层网络中,反向传播时梯度可能逐层衰减,导致浅层参数无法更新。残差连接提供了直通路径,让梯度能直接回传。
  • 信息丢失:传统网络可能过度修改输入,残差连接强制模型只学习“需要补充或调整的部分”(即残差)。
对比实验
  • 不带残差连接:Transformer在6层以上时,训练损失难以收敛。
  • 带残差连接:即使堆叠100层,模型仍能稳定训练(如GPT-3)。

5. 代码示例(PyTorch风格)

import torch
import torch.nn as nn

class DecoderLayer(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model)  # 自注意力层
        self.norm = nn.LayerNorm(d_model)           # 层归一化

    def forward(self, x):
        # 残差连接:输入x + 自注意力输出
        residual = x
        x = self.self_attn(x)
        x = self.norm(residual + x)  # 先相加,再归一化
        return x

6. 残差连接的变体

  • 经典残差输出 = 输入 + 变换(输入)(Transformer采用)。
  • 预激活残差:先归一化再变换(如ResNet v2)。
  • 自适应残差:动态调整残差权重(如门控机制)。

7. 总结

  • 操作:输入与输出直接相加。
  • 目的:保留原始信息,缓解梯度消失,稳定深层训练。
  • 效果:让Transformer可以堆叠数十层甚至上百层,仍能高效学习。

类比记忆
残差连接就像“写论文时保留初稿,每次修改只添加批注”——既避免推倒重来,又能逐步完善。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐