我们用 “抄近路保留原始信息” 的生活化比喻讲透核心逻辑,再逐步拆解它在Transformer中的具体处理流程、数学原理和关键作用。

一、小白版本:一句话懂残差连接

残差连接(Residual Connection)的本质就是 “把输入原封不动地加到输出上”,相当于给模型计算加了一条 “信息捷径”

“学做红烧肉” 的例子类比:

  • 假设你现在要做红烧肉(模型的输入x = 师傅给的“基础红烧肉配方”);
  • 你按照配方做了一遍,还加了自己的改良(模型的计算F(x) = 你的改良版红烧肉);
  • 最后你把 “基础配方”和“改良版”混在一起(残差连接:y=x+F(x)y = x + F(x)y=x+F(x)),这样就算改良翻车了,也不会丢掉基础配方的味道。

对应到Transformer中:

  • xxx = 每层的输入特征(比如词嵌入向量);
  • F(x)F(x)F(x) = 这一层的注意力或前馈网络计算结果;
  • yyy = 残差连接后的输出(原始信息+新学到的信息)。

二、稍微深入:Transformer中残差连接的具体处理步骤

Transformer的Encoder和Decoder每层都有两处残差连接,分别和注意力、前馈网络配合,再加上LayerNorm。我们分Post-LayerNorm(原版Transformer用的)和Pre-LayerNorm(深层模型常用)两种情况讲,核心都是“加和”,只是顺序不同。

1. 先明确一个关键前提:维度必须匹配

残差连接是向量的逐元素相加,要求 输入xxx的维度 和 计算输出F(x)F(x)F(x)的维度完全一致
在Transformer中,这个维度就是模型维度 dmodeld_{model}dmodel(比如512)。所以:

  • 多头注意力的输出维度是 dmodeld_{model}dmodel
  • 前馈网络的输入输出维度也是 dmodeld_{model}dmodel
    这样才能保证 x+F(x)x + F(x)x+F(x) 能顺利计算(不然维度不一样,加都加不了)。

2. Post-LayerNorm 模式(原版Transformer)

这是最直观的处理顺序,先计算,再残差,最后归一化,对应我们之前的“红烧肉”比喻:

Encoder每层的两处残差连接流程
# 第一步:多头自注意力 + 残差连接 + LayerNorm
输入x1 → 多头自注意力计算 F1(x1) → 残差加和:x1 + F1(x1) → LayerNorm → 输出x2

# 第二步:前馈网络 + 残差连接 + LayerNorm
输入x2 → 前馈网络计算 F2(x2) → 残差加和:x2 + F2(x2) → LayerNorm → 输出x3(传给下一层)

简单说:计算 → 加原始输入 → 标准化

3. Pre-LayerNorm 模式(深层模型常用)

顺序变了,先归一化,再计算,最后残差,适合深层模型(32层以上):

# 第一步:LayerNorm + 多头自注意力 + 残差连接
输入x1 → LayerNorm → 多头自注意力计算 F1(LN(x1)) → 残差加和:x1 + F1(LN(x1)) → 输出x2

# 第二步:LayerNorm + 前馈网络 + 残差连接
输入x2 → LayerNorm → 前馈网络计算 F2(LN(x2)) → 残差加和:x2 + F2(LN(x2)) → 输出x3

核心不变:不管顺序如何,残差连接都是把未经计算的原始输入加到计算后的输出上。

4. Decoder中的残差连接

Decoder每层有三处残差连接(比Encoder多一处“编码器-解码器注意力”),但处理逻辑完全一样:

输入x1 → 掩码自注意力 F1 → x1+F1 → LN → x2
x2 → 编码器-解码器注意力 F2 → x2+F2 → LN → x3
x3 → 前馈网络 F3 → x3+F3 → LN → x4

三、再深入:残差连接的核心作用(为什么没有它不行?)

残差连接是Transformer能堆叠深层模型(比如12层、24层)的关键技术,核心作用有两个,我们从“现象”讲到“数学原理”。

作用1:防止梯度消失,让深层模型能训练

这是残差连接最核心的作用。我们先搞懂什么是梯度消失

  • 模型训练的本质是“反向传播梯度,更新参数”;
  • 如果没有残差连接,梯度需要穿过每一层的计算函数 F(x)F(x)F(x),梯度是各层导数的乘积
  • 当模型层数很多时(比如24层),这些导数的乘积会趋近于 0 → 梯度消失,模型参数无法更新,训练直接“卡死”。

残差连接怎么解决这个问题?
从数学上看,残差连接的输出是 y=x+F(x)y = x + F(x)y=x+F(x)
反向传播时,梯度 ∂loss∂x\frac{\partial loss}{\partial x}xloss 可以分解为两部分:
∂loss∂x=∂loss∂y×(1+∂F(x)∂x)\frac{\partial loss}{\partial x} = \frac{\partial loss}{\partial y} \times \left(1 + \frac{\partial F(x)}{\partial x}\right)xloss=yloss×(1+xF(x))
重点看这个 1 —— 它意味着梯度可以通过**“捷径”直接传递**,不需要完全依赖 F(x)F(x)F(x) 的导数。

  • 即使 F(x)F(x)F(x) 的导数趋近于0,梯度至少还有一个“保底值” ∂loss∂y×1\frac{\partial loss}{\partial y} \times 1yloss×1
  • 这样梯度就不会消失,深层模型也能正常训练。

作用2:保留原始信息,让模型学“增量”而非“重构”

如果没有残差连接,模型每一层都需要从零开始学习所有特征,相当于“每次都要重新发明轮子”;
有了残差连接,模型只需要学习 “原始信息和目标信息的差异”(也就是残差 F(x)=y−xF(x) = y - xF(x)=yx),学习难度大大降低。

还是用红烧肉比喻:

  • 无残差:你需要完全靠自己创造红烧肉配方(难,容易跑偏);
  • 有残差:你只需要学习“怎么在基础配方上改良”(简单,不容易丢原味)。

对应到Transformer:

  • 底层网络学习基础的语法、语义特征;
  • 高层网络学习更抽象的特征(比如句子的逻辑关系);
    残差连接保证底层的基础特征不会被高层的计算“覆盖”。

作用3:提升模型的泛化能力,防止过拟合

残差连接相当于给模型加了一层“保护”:

  • 计算后的输出 F(x)F(x)F(x) 可能包含噪声(比如训练数据的干扰);
  • 加上原始输入 xxx 后,噪声会被“稀释”;
  • 模型学到的特征更鲁棒,在测试集上的表现更好。

四、关键细节:残差连接的“小坑”与解决

  1. 维度不匹配怎么办?
    如果某些层的计算输出维度和输入不一致(比如想改变模型维度),可以用一个线性投影层把输入维度转换成目标维度:
    y=Wx+F(x)y = Wx + F(x)y=Wx+F(x)
    其中 WWW 是投影矩阵,作用是把 xxx 的维度映射到 F(x)F(x)F(x) 的维度。

  2. 残差连接和LayerNorm的配合

    • Post-LN:残差加和后做LN,优点是直观,缺点是深层容易梯度消失;
    • Pre-LN:LN后做计算再残差,优点是深层训练稳定,缺点是归一化后的特征分布更激进;
      实际工程中,深层模型优先选Pre-LN + 残差。

五、终极总结

小白一句话总结

残差连接就是给模型加了一条“信息捷径”,把原始输入直接加到输出上,既防止深层训练崩溃,又不让模型丢了基础信息。

技术一句话总结

残差连接通过 y=x+F(x)y = x + F(x)y=x+F(x) 的逐元素加和,让梯度能直接反向传播(避免梯度消失),同时让模型学习特征增量,是Transformer实现深层堆叠的核心技术。


Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐