前言

元旦假期又被DeepSeek的论文刷屏了,号称DeepSeek提出了一种划时代的大模型网络结构,各大公众号争相报道,但又很难看懂其细节。想要真正了解这里面的细节,我们有以下几个问题需要回答。回答好了这几个问题也就了解了DeepSeek提出的工作的价值和意义:

  • 残差连接的恒等映射是什么,有什么优缺点?
  • 超连接(HC)是什么,解决了残差连接的什么问题?
  • 流形约束HC(mHC)是什么,解决了超连接什么问题?
  • mHC的价值和意义是什么?

本文不会过多涉及公式推导,尽量用浅显易懂的文字进行表达。
本文部分内容借助Gemini3-pro、Kimi-v2、豆包大模型完成。

残差连接的恒等映射是什么

残差连接(Residual Connections)最开始由大神何恺明在2016年提出: Deep Residual Learning for Image Recognition,今年刚好是该工作被提出的第十年,引用次数快达到30w,已经成为神经网络架构的基石,哪怕在如今的大模型时代,还是一直在沿用该方式。我们先简单了解一下残差连接:

  • 残差连接是什么?
  • 残差连接的提出解决了什么问题?
  • 残差连接有什么缺点?

残差连接是什么

在这里插入图片描述
深度神经网路是由一层一层的卷积block堆积而成,而残差连接就是在每个卷积block里面采用一个跳跃连接(skip connection)直接将输入和经过卷积之后的结果进行加和,公式表达如下:
y = F ( x ) + x y = F(x) + x y=F(x)+x
这里有两个关键点需要注意:

  1. 恒等映射(Identity Mapping):跳跃连接对输入的x不做任何变换,直接将输入与卷积之后的结果进行加和。尽管多层block累积,但残差连接的设计可以保证始终有一条“畅通无阻的公路”让起始输入直达最终输出。这种特性就叫恒等映射,也就是这个特性让模型可以堆叠更深的同时保持稳定的训练(梯度不消失)。
  2. 恒等权重(Identity weight):加和的权重不变,即输入的x的权重为1

残差连接的提出解决了什么问题

一句话总结残差连接被提出的目的:随着神经网络层数增加,网络无法被“轻易训练”(存在梯度消失等问题)。残差连接的提出使得更大、更深的神经网络可以被训练。当前的大模型可以做到几十亿,甚至上百亿的参数量,离不开残差连接的“功劳”。
具体来说,残差连接最初被提出主要是为了解决深度神经网络中的 梯度消失(Gradient Vanishing) 和 网络退化(Network Degradation) 问题。

  • 梯度消失: 随着网络层数加深,反向传播时的梯度会逐层相乘。如果梯度小于1,传到浅层时几乎为0,导致浅层参数无法更新。
  • 网络退化: 理论上深层网络不应比浅层网络差(至少可以恒等映射),但在实践中,随着层数增加,训练误差反而上升。残差连接通过引入“捷径”(shortcut),允许信息直接流向更深层,使得深层网络更容易训练。

残差连接有什么缺点

在现在的大模型中,残差连接并不是按照上面的方式被直接应用,它往往是结合着“归一化层”一起使用,按照归一化层的位置,当前主要有两种范式:Pre-Norm(归一化层在卷积(也可以是attention/FFN)之前) 和 Post-Norm(归一化层在卷积之后)。

Pre-Norm Post-Norm
y = x + F ( L a y e r N o r m ( x ) ) y = x + F(LayerNorm(x)) y=x+F(LayerNorm(x)) y = L a y e r N o r m ( x + F ( x ) ) y = LayerNorm(x + F(x)) y=LayerNorm(x+F(x))
优点:1. 训练非常稳定: 归一化位于卷积通路(attention/FFN)上,而残差连接(x)保持畅通。梯度可以直接沿着残差连接无损地传回第一层,极大地缓解了梯度消失问题。2. 无需复杂 Warm-up: 允许使用更大的学习率,且对超参数不那么敏感,适合训练超大规模深层网络。 优点:1. 潜力更高: 如果能训练成功,Post-Norm 模型通常能达到比 Pre-Norm 略好的最终性能(更低的 Loss),因为它保留了较大的梯度方差,有助于模型跳出局部最优。2. 缓解表示坍塌: Post-Norm 不容易出现深层特征趋同(Representation Collapse)的问题,每一层的特征差异性保持较好。
缺点:1. 表示坍塌 (Representation Collapse): 由于主干通路一直保留原始信息,随着层数加深,深层的输出和输入变得越来越像(余弦相似度接近 1)。2. 层效能降低: 深层网络的参数实际上贡献很小,导致模型虽然层数多,但有效容量可能受限。 缺点:1. 训练极不稳定: 这是最大的痛点。由于归一化在残差相加之后,梯度在反向传播时难以直接通过“捷径”传回浅层,容易导致梯度消失或梯度爆炸。2. 需要 Warm-up: 训练初期必须使用非常小心设计的 Learning Rate Warm-up 策略,否则模型很难收敛。
目前主流大模型(如 GPT-3, LLaMA, PaLM, 以及 OLMo)普遍采用的结构 最早在 Google 的原始 Transformer 论文 (“Attention Is All You Need”) 和 BERT 中使用
伪代码:x = Attn(Norm1(x)) + x; x = FFN(Norm2(x)) + x 伪代码:x = Norm1(Attn(x) + x); x = Norm2(FFN(x) + x)

这里就引出了超连接paper中提到的残差连接的缺点:

such as the seesaw effect between gradient vanishing and representation collapse.

即残差连接始终面临着梯度消失(gradient vanishing)与表示坍缩(representation collapse)之间的“跷跷板效应”。Pre-Norm训练稳定但会有特征表示坍缩问题;Post-Norm特征表示差异性更好但训练又不稳定。
超连接(Hyper Connections)的提出就是为了来解决这个问题!

超连接(HC)是什么

理论解释

原理

超连接(Hyper Connections) 是字节seed基础模型团队提出来的工作,该工作最早在2024年9月挂在arxiv上,后来中了 ICLR 2025
正如前面提到的那样,超连接的提出是为了解决残差连接中梯度消失和表征坍缩的“跷跷板效应”。本文根据该问题引出一个疑问:

Can neural networks autonomously learn the optimal strength of connections to improve performance? (神经网络能否自主学习最优的连接强度,以提升性能?)

答案当然是可以的,核心思想就是引入可学习的深度连接(depth-connections)和宽度连接(width-connections)。这里的关键词就是“可学习的”,通过引入一些可学习的的参数(论文中的 α \alpha α β \beta β)来提升模型的表征能力。可以通过公式对比来直接意识到HC的“特殊地方”:
残差连接:
y = F ( x ) + x y = F(x) + x y=F(x)+x
Hyper-Connection(简易公式):
y = β ∗ F ( α ∗ x ) + α ∗ x y = \beta * F(\alpha * x) + \alpha * x y=βF(αx)+αx
上面的HC公式并不是最终的公式,而是一个为了方便理解的化简版。实际论文中 α \alpha α β \beta β 不是一个可学习的向量(vector),而是一个矩阵(matrix)。为了实现矩阵的运算,文中对输入也做了对应的处理,即对输入 x x x n n n 份复制(即文中说的副本),实现宽度上的扩展。

理论优势

  • 深度上可以更灵活地融入特征
    • 在残差连接中经过Attention或者FFN之后的结果,一直以权重为1的方式与输入进行加和
    • HC中通过引用可学习的 β \beta β作为加和权重,实现了模型“自主”决策融合比例,比传统残差连接具有更强的表达能力
  • 能够同时建模多种深度连接方式,此外使得同一层内的隐藏状态可以相互交换信息
    • 通过对输入进行扩展,实现了同一层中存在不同的隐藏状态,增加了隐藏层的表征能力
    • 通过可学习的 α \alpha α 实现同一层不同隐藏状态之间的融合和交互

效果体现

  • 语言模型预训练(1B、7B、MoE):
    • 在相同训练 token 数量下,收敛速度提升 1.8 倍,下游任务准确率提升 6 个百分点(如 ARC-Challenge)。
    • 模型更稳定,训练过程中无 loss spike。
  • 视觉任务(ImageNet 分类与生成):
    • ViT 模型中,Top-1 准确率提升最多 2.69%。
    • DiT 图像生成中,性能接近参数量大 50% 的模型。

具体实现

n=2的情况

在这里插入图片描述
上图(b)实现了基于Pre-Norm的HC(n=2)的具体实现。我们也可以从公式上实现HC的表示。此外,当我们搞清楚了n=2的计算方式,也就能进一步搞明白论文中针对任意n的表达公式。
为了跟图中的参数对齐,后续将使用h来代替x,两者在本文中的含义是一致的,即都表示某一隐藏层的输入。
HC n=2 的具体实现:

  1. 在进入第一层之前对 h h h 进行扩展(即复制),图中扩展为两份,即 h 1 h_1 h1 h 2 h_2 h2
  2. 对每个隐藏状态进行Attention和FFN(这里用 F F F表示)计算:
    h 1 = β 1 ∗ F ( α 1 , 0 ∗ h 1 + α 2 , 0 ∗ h 2 ) + ( α 1 , 1 ∗ h 1 + α 2 , 1 ∗ h 2 ) h 2 = β 2 ∗ F ( α 1 , 0 ∗ h 1 + α 2 , 0 ∗ h 2 ) + ( α 1 , 2 ∗ h 1 + α 2 , 2 ∗ h 2 ) h_1 = \beta_1 * F(\alpha_{1,0} * h_1 + \alpha_{2,0} * h_2) + (\alpha_{1,1} * h_1 + \alpha_{2,1} * h_2) \\ h_2 = \beta_2 * F(\alpha_{1,0} * h_1 + \alpha_{2,0} * h_2) + (\alpha_{1,2} * h_1 + \alpha_{2,2} * h_2) h1=β1F(α1,0h1+α2,0h2)+(α1,1h1+α2,1h2)h2=β2F(α1,0h1+α2,0h2)+(α1,2h1+α2,2h2)
  3. 在经过多层block(一个block由一个attention和一个FFN组成)计算后,再将两个隐藏状态进行求和,得到最终的输出:
    h = h 1 + h 2 h = h_1 + h_2 h=h1+h2
    这里也可以写出n=2的HC矩阵表达:
    H C = ( 0 β 1 β 2 α 1 , 0 α 1 , 1 α 1 , 2 α 2 , 0 α 2 , 1 α 2 , 2 ) HC = \left( \begin{array}{ccc} 0 & \beta_1 & \beta_2 \\ \alpha_{1,0} & \alpha_{1,1} & \alpha_{1,2} \\ \alpha_{2,0} & \alpha_{2,1} & \alpha_{2,2} \\ \end{array} \right) HC= 0α1,0α2,0β1α1,1α2,1β2α1,2α2,2

通用表达

静态HC(static hyper-connections,SHC)

在这里插入图片描述
特别地, B ∈ R 1 × n ,   A m ∈ R n × 1 ,   A r ∈ R n × n B \in \mathbb{R}^{1 \times n}, \space A_m \in \mathbb{R}^{n \times 1}, \space A_r \in \mathbb{R}^{n \times n} BR1×n, AmRn×1, ArRn×n,公式中的 H H H 是多个隐藏状态 h h h 的组成。一般 h h h 的维度为 [ b a t c h ,   s e q ,   d i m ] [batch,\ seq,\ dim] [batch, seq, dim],经过扩展后 H H H 的维度变为 [ b a t c h ,   s e q ,   n ,   d i m ] [batch,\ seq,\ n,\ dim] [batch, seq, n, dim]
从上面也可以看出来,需要学习的就只有三个向量矩阵: B ∈ R 1 × n ,   A m ∈ R n × 1 ,   A r ∈ R n × n B \in \mathbb{R}^{1 \times n}, \space A_m \in \mathbb{R}^{n \times 1}, \space A_r \in \mathbb{R}^{n \times n} BR1×n, AmRn×1, ArRn×n,所以一个block中新增的参数量就只有:parm = 2 * (n + n + n*n) = 2n(n+2)。这里乘以2是因为一个block中包含一个attention和一个FFN。论文中推荐的n为4,所以一层block新增的参数量仅有48个。

动态HC(dynamic hyper-connections,DHC)

一句话总结动态HC与静态HC的差别:静态HC中每一层中的A/B参数与输入的H是完全无关的,即训练一旦完成,每层的A/B参数就会固定。
但这样就会存在一个缺点:网络无法根据每层输入的H情况来“动态”调整多个隐藏状态h之间的融合权重。 DHC 引入了输入依赖性 (Input-Dependent)。对于某些 Token,模型可能觉得“这一层的计算很重要”,于是加大深度连接的权重;对于另一些 Token,模型可能觉得“保持原样更好”,于是减小权重。这类似于给每个 Token 配备了一个智能的交通指挥员。
在这里插入图片描述
注意:

  • 通常取 H 的第一行(或某种聚合形式),经过 LayerNorm 后作为当前层的输入向量(公式10)
  • Wβ, Wm, Wr 是可学习的线性变换矩阵,即一个全连接层(Linear Layer)
  • sα, sβ 是初始值很小的缩放因子(如 0.01)
  • Tanh 的作用:论文特别强调(在 Visualization & Analysis 部分),Tanh 激活函数对于 DHC 的训练稳定性至关重要。如果去掉 Tanh,模型性能会显著下降甚至无法收敛。 它将权重限制在 (−1,1) 之间,防止信号爆炸。

论文中,对动态HC的参数计算也写的很详细,这里就直接贴一下:
在这里插入图片描述

总结

这里的数据都是以文中的OLMo-1B(n=4)为代表进行说明的

  • 参数量上:不管是静态HC还是动态HC,增加参数量均可忽略不计:针对一个1B的模型,静态HC增加的参数量约为0.00007%,而动态HC也仅为0.03%左右。
  • 计算复杂度上:无论静态还是动态超连接,引入的额外 FLOPs 均可忽略不计:针对一个1B的模型,静态HC增加的计算量约为0.127%,而动态HC也仅为0.2%左右。
  • 显存消耗上:由于HC的引入,显存上还是有明显的增加,这一点不像参数量和计算量:针对一个1B的模型,静态HC和动态HC增加的显存均约为26%左右。 关于显存明显增加的问题,这一点受到了ICLR审稿人的质疑,针对这一点seed该团队进一步提出了Frac-Connections(FC),来减少计算量和显存开销。

拓展阅读:Frac-Connections(FC)

原理

为了减少HC的计算量和显存开销,字节seed团队在25年3月份进一步提出了HC的优化版本: Frac-Connections: Fractional Extension of Hyper-Connections
一句话总结其特点:不像HC中对输入x进行n份复制来进行扩展,而是通过对输入x进行split:拆分成 m = 1 / n m = 1/ n m=1/n 份,这里的 n n n 为“扩展率”。可以认为FC是HC的一种拓展,HC中要求 n > 1 n > 1 n>1,把隐藏状态h复制n份。而FC定义了 0 < n < 1 0 < n < 1 0<n<1 的时候,这个时候是把隐藏状态h split成m份。
值得注意的是:FC是表征能力和计算量之间的一种权衡方法,FC虽然减少了计算量,但也降低了模型的表征能力。论文原话:

The similarity between adjacent hidden states in FC lies between that of HC and baseline (Pre-Norm), indicating that their representational capacity follows the order: HC > FC > Pre-Norm.

实现

在这里插入图片描述
从图(c)中可以看出,具体操作步骤为:

  1. 对输入h进行split操作,分为m份子向量
  2. 宽度上进行融合,执行cat后进行attention/FFN计算
  3. 将经过attention/FFN计算的结果进行对应的split,然后进行深度融合
  4. 最后将所有隐藏状态h进行拼接(cat)

优势

这里的数据都是以文中的OLMo-1B(n=4)为代表进行说明的

  • 参数量:动态FC增加的参数量为0.014%,比动态HC减少一半
  • 计算量:动态FC增加的计算量为0.044%,比动态HC减少约5倍
  • 显存消耗:文中没有明确统计。论文摘要中明确提到FC是为了reducing memory consumption,但这关键的一点却没有通过实验数据来证明,这可能也说明了FC也并没有达到显存显著节省的效果。

流形约束HC是什么

很多公众号都在讲DeepSeek提出来的流形约束超连接(mHC: Manifold-Constrained Hyper-Connections)的特性:

  • 方法:将残差连接矩阵变换为“双随机矩阵”,映射到一个Birkhoff 多胞形(Birkhoff polytope)上。
  • 目的:维持了“恒等变换特性”、节省了内存开销。
    要理解这里面的目的和做法,我们需要先深入理解一下字节seed提出的HC存在什么问题?

HC存在什么问题

图解

在这里插入图片描述
引用HC论文中的一张图,左边是传统的残差连接示意图,右边是HC示意图。通过对比可以发现:

  • 残差连接:不管神经网络block堆积多深,始终有一条畅通无阻的“高速公路”让输入h直达输出。这种特性就是“恒等变换”特性。这个特性解决了当网络变深后导致的梯度消失问题
  • 超连接(HC):这条“高速公路”上出现很多“阻碍”(即一些需要学习的参数 α \alpha α),这些阻碍在网络变深后会导致两种可能的问题:
    • 如果参数 α < 1 \alpha < 1 α<1,一旦累积多层后,这条“公路”就像断了一样,把输入逐渐置为了0,到网络深层就接收不到浅层的信息了。这就像移除了残差连接一样,会导致梯度消失的问题出现。
    • 如果参数 α > 1 \alpha > 1 α>1,一旦累积多层后,这条“公路”就像一个放大器,把输入逐渐放大了很多倍(deepseek论文里面有提到说峰值存在放大3000倍的情况),到网路深层接收到的浅层信息就“不保真”,数值巨大,会导致梯度爆炸的问题出现。

所以这里就引出了 mHC 论文的核心观点和做法:在HC里面,继续维持这条“公路”通畅,即参数 α \alpha α 一直保持在数值1 附近。在实际操作中,参数 α \alpha α(即HC论文中的 A r A_r Ar)是一个 n × n n \times n n×n 的矩阵,让矩阵维持这个特性的做法就是将 A r A_r Ar 通过某种变换,变成列和为1、行和为1的矩阵,这个矩阵就叫“双随机矩阵”(行和为1、列和为1是双随机矩阵的最大特性)。实现方法就是 Sinkhorn-Knopp 算法。

理论证明

在这里插入图片描述

注意mHC中的符号跟HC论文中是对不上的,对应关系是: H r e s = A r ,   H p r e = A m ,   H p o s t = B H^{res} = A_r, \ H^{pre} = A_m, \ H^{post} = B Hres=Ar, Hpre=Am, Hpost=B

mHC也用过消融实验证明, H r e s H^{res} Hres 矩阵对最终的模型效果影响最大,所以要优先处理 H r e s H^{res} Hres 矩阵。
这是mHC论文中对多层HC堆积的公式表达,红框中就是残差流上多个 H r e s H^{res} Hres 矩阵的累积结果。这个结果在网络多层累积后,会变得非常小或者非常大,导致梯度消失或者爆炸的问题出现。
在这里插入图片描述
在这里插入图片描述
mHC中也通过实验具体证明了这一点:

  1. 在一个27B模型的训练上,训练12k步后loss就飞了。
  2. 在网络深层(l=50), 累积的 H r e s H^{res} Hres 矩阵数值最高达到了3000,远远大于理想的1。mHC就是为了让这个累积结果一直维持在1附近。

mHC的具体实现

搞懂原理后,再来看mHC的具体实现就非常简单了。
论文中的公式表达:
在这里插入图片描述
解释:

  • 公式(7)就是HC论文中动态HC的实现,唯一不同的是输入RMSNorm的x,进行了flatten处理,让输入信息全部参与参数计算,而不是原文中选择第一个维度。
  • 公式(8)就是mHC的实现,在公式(7)的基础上仅仅加了一个矩阵变换,其中 σ \sigma σ 代表是sigmoid函数,对于相对不重要的参数矩阵,仅进行简单的sigmoid变换。而对于重要的 H r e s H^{res} Hres 矩阵,做了 Sinkhorn-Knopp 变换 。

计算效率优化

DeepSeek的工作除了对 H r e s H^{res} Hres 矩阵的变换外,他们还通过一些列工程的手段来提升了模型的计算效率,同时解决了HC和FC中存在的显存开销问题,这一点也是该工作的一个重大意义。
由于工程优化不是本文的重点,这里就不再进行具体的展开。这里引用字节提出的“Hyper-Connections”,被DeepSeek救活了?公众号比较简洁的总结:
对于 Transformer 架构,HC 带来的额外 I/O 开销约为 ( 5 n + 1 ) C (5n+1)C (5n+1)C。为了解决这个问题,DeepSeek 实施了三项系统优化:

  • Kernel Fusion:使用 TileLang 手写算子,将 RMSNorm、MatMul 和加法操作深度融合,减少显存读写次数。
  • Recomputing(重计算):在前向传播中丢弃中间激活值,反向传播时重新计算。通过优化 Block Size ( L r ≈ n L / ( n + 2 ) L_r \approx \sqrt{nL/(n+2)} LrnL/(n+2) ),降低显存占用。
  • DualPipe Overlap:在流水线并行中,将 mHC 带来的额外计算与通信时间重叠掩盖。

最终,在 n=4 的配置下,mHC 在大规模训练中的额外时间开销被压缩至 6.7%,使其具备了工程可用性。

思考

在这里插入图片描述
虽然mHC通过对 H r e s H^{res} Hres 矩阵进行约束,维持了恒等变换的特性,但这似乎也会弱化模型的表征能力,即弱化不同隐藏状态之间交互的作用。HC的论文中也指出如果 H r e s H^{res} Hres 矩阵变为了一个单位矩阵,但不同的隐藏状态h之间都没有任何交互了,这就违背了HC的设计初衷。

从mHC论文中的图8中,可以发现在深层的时候(l=30/60)的时候, H r e s H^{res} Hres 矩阵已经非常像一个单位矩阵了,所以在深层的时候,不同h之间的交互“似乎”已经变得非常的弱了。

所以,个人觉得:应该在限制 H r e s H^{res} Hres 矩阵满足双随机矩阵特性的同时,也要让其不要直接堕化为一个简单的单位矩阵。 也就是在满足恒等变换的同时也不要弱化不同隐藏状态h的交互。这里可以通过正则化等操作来对这些可学习的参数进行一定的限制。

mHC的价值和意义是什么

价值

  1. 理论意义:mHC的提出使得HC可以应用在更深、更大的模型上,这在当前的大模型时代是一个非常刚需的要求。简单来说,残差连接的提出使得卷积block可以堆叠的更深,而mHC的提出使得带HC的block可以堆叠的更深。
  2. 工程意义:从工程的角度解决了HC/FC一直被诟病的显存开销的问题,可以推动HC的研究。
  3. HC变革了模型架构设计:之前的大模型为了增强表征能力,只能通过加大网络深度,或者每个token使用更大的维度(即dim),但这两个方法都会显著的导致模型的计算复杂度被提升。HC的提出,引出了一个新的维度,即残差流宽度(n),它的优势在于扩大n,不会带来显著的参数量和计算复杂度的提升。
  4. mHC会不会引发新的一轮模型架构革新?这是mHC论文在摘要中有明确提到的一个期望。目前的mHC只是偏理论的一些证明,可能很快DeepSeek就会放出mHC在实际应用上的价值,可能就会应用在他们的下一版本发布的模型上。

未来的研究

  1. 字节seed对HC的探索仅在7B的模型上,deepseek提出的mHC在27B的模型上验证了可能,那在更大的模型上,mHC是否还同样有效呢?
  2. 残差流的宽度探索(即n参数),目前都是使用默认的设置4,那更大的n会有什么问题呢?效果会更上一层么?
  3. HC架构应用在更多的场景模型上,当前还是主要在语言基础模型上,图像理解、生成等大模型上表现会如何?

参考

  1. 2026 第一枪:字节提出的超连接,被 DeepSeek 救活了
  2. 字节提出的“Hyper-Connections”,被DeepSeek救活了?
  3. 刚刚,梁文锋署名,DeepSeek元旦新论文要开启架构新篇章
  4. 详细解读DeepSeek新年的第一篇论文,他们就是这个时代的真神。
  5. 如何理解 DeepSeek 最新提出的 mHC 架构?
  6. Deep Residual Learning for Image Recognition
  7. Hyper-Connections
  8. Frac-Connections: Fractional Extension of Hyper-Connections
  9. mHC: Manifold-Constrained Hyper-Connections

附录

Pre-Norm Block 实现

class PreNormBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.attn = MultiHeadAttention(d_model)
        self.ffn = FeedForward(d_model)
        
        # 定义两个归一化层,分别用于 Attn 和 FF 之前
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # 1. Self-Attention 子层
        # 先 Norm,再进子层,最后加残差
        residual = x
        x = self.norm1(x)
        x = self.attn(x)
        x = x + residual  # 残差连接

        # 2. Feed-Forward 子层
        # 同样:先 Norm,再进子层,最后加残差
        residual = x
        x = self.norm2(x)
        x = self.ffn(x)
        x = x + residual  # 残差连接
        
        return x

Post-Norm Block 实现

class PostNormBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.attn = MultiHeadAttention(d_model)
        self.ffn = FeedForward(d_model)
        
        # 定义两个归一化层,分别用于 Attn 和 FF 之后
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # 1. Self-Attention 子层
        # 先进子层,加残差,最后 Norm
        residual = x
        x = self.attn(x)
        x = residual + x  # 残差连接
        x = self.norm1(x) # Post-Norm

        # 2. Feed-Forward 子层
        # 同样:先进子层,加残差,最后 Norm
        residual = x
        x = self.ffn(x)
        x = residual + x  # 残差连接
        x = self.norm2(x) # Post-Norm
        
        return x

一个完整的Hyper-Connection网络demo实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        # 简化版,实际应用通常使用 nn.MultiheadAttention
        self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=4, batch_first=True)

    def forward(self, x):
        # x shape: [batch_size, seq_len, d_model]
        output, _ = self.attn(x, x, x)
        return output

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff=None):
        super().__init__()
        d_ff = d_ff or 4 * d_model
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        return self.net(x)

class HyperConnection(nn.Module):
    def __init__(self, dim, rate, layer_id, dynamic, device=None):
        super(HyperConnection, self).__init__()

        self.rate = rate            # 对应论文中的扩展率n
        self.layer_id = layer_id    # 网络第几层
        self.dynamic = dynamic      # 是否使用动态HC

        self.static_beta = nn.Parameter(torch.ones((rate,), device=device))    # 对应论文中的矩阵B [n]

        init_alpha0 = torch.zeros((rate, 1), device=device)                    # 对应论文中的矩阵Am [n, 1]
        init_alpha0[layer_id % rate, 0] = 1.0                                  # 对矩阵Am初始化
        self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye((rate), device=device)], dim=1))  # 对应论文中的矩阵Am+Ar [n, n+1]

        if self.dynamic:  # 动态HC的初始化,对应论文中的公式(10)-(13),注意这里的alpha是对Am和Ar矩阵拼接之后的表示
            self.dynamic_alpha_fn = nn.Parameter(torch.zeros((dim, rate+1), device=device))
            self.dynamic_alpha_scale = nn.Parameter(torch.ones(1, device=device) * 0.01)
            self.dynamic_beta_fn = nn.Parameter(torch.zeros((dim, ), device=device))
            self.dynamic_beta_scale = nn.Parameter(torch.ones(1, device=device) * 0.01)
            self.layer_norm = nn.LayerNorm(dim)

    # 宽度连接
    def width_connection(self, h):
        if self.dynamic:
            norm_h = self.layer_norm(h)
        
        if self.dynamic:
            wc_weight = norm_h @ self.dynamic_alpha_fn
            wc_weight = F.tanh(wc_weight)
            dynamic_alpha = wc_weight * self.dynamic_alpha_scale
            alpha = dynamic_alpha + self.static_alpha[None, None, ...]
        else:
            alpha = self.static_alpha[None, None, ...]
        
        if self.dynamic:
            dc_weight = norm_h @ self.dynamic_beta_fn
            dc_weight = F.tanh(dc_weight)
            dynamic_beta = dc_weight * self.dynamic_beta_scale
            beta = dynamic_beta + self.static_beta[None, None, ...]
        else:
            beta = self.static_beta[None, None, ...]
        
        mix_h = alpha.transpose(-1, -2) @ h

        return mix_h, beta
    
    # 深度连接
    def depth_connection(self, mix_h, h_o, beta):
        h = torch.einsum("blh,bln->blnh", h_o, beta) + mix_h[..., 1:, :]
        return h

class HyperConnectionBlock(nn.Module):
    def __init__(self, dim, rate, layer_id, dynamic, device=None):
        super(HyperConnectionBlock, self).__init__()

        self.atten_hyper_connection = HyperConnection(dim=dim, rate=rate, layer_id=layer_id, dynamic=dynamic, device=device)
        self.ffn_hyper_connection = HyperConnection(dim=dim, rate=rate, layer_id=layer_id, dynamic=dynamic, device=device)

        # 定义两个归一化层,分别用于 Attn 和 FF 之前
        self.attn_norm = nn.LayerNorm(dim)
        self.ffn_norm = nn.LayerNorm(dim)

        # 定义两个子模块,分别用于 Attention 和 FFN
        self.self_attention = MultiHeadAttention(d_model=dim)
        self.feed_forward = FeedForward(d_model=dim)

        # 定义一个 Dropout 层
        self.dropout = nn.Dropout(0.1)

    def forward(self, h):
        # Attention Block
        mix_h, beta = self.atten_hyper_connection.width_connection(h)
        h = self.attn_norm(mix_h[...,0,:])
        h = self.self_attention(h)
        h = self.atten_hyper_connection.depth_connection(mix_h, self.dropout(h), beta)

        # FFN Block
        mix_h, beta = self.ffn_hyper_connection.width_connection(h)
        h = self.ffn_norm(mix_h[...,0,:])
        h = self.feed_forward(h)
        h = self.ffn_hyper_connection.depth_connection(mix_h, self.dropout(h), beta)

        return h

class HyperConnectionModel(nn.Module):
    def __init__(self, vocab_size, d_model, n_layers, n_hyper=4, dynamic=True, device='cpu'):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.n_hyper = n_hyper
        self.d_model = d_model
        
        # 堆叠 HC Blocks
        self.layers = nn.ModuleList([
            HyperConnectionBlock(d_model, n_hyper, i, dynamic=dynamic, device=device) for i in range(n_layers)
        ])
        
        self.final_norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids):
        # 1. Embedding
        x = self.embedding(input_ids) # [batch, seq, d_model]
        
        # 2. 初始化 Hyper Hidden States (Algorithm 1: Initialization)
        # 将 x 复制 n 次形成初始矩阵 H
        batch_size, seq_len, _ = x.shape
        H = x.unsqueeze(2).expand(-1, -1, self.n_hyper, -1).clone() # [b, s, n, d]
        
        # 3. 通过所有层
        for layer in self.layers:
            H = layer(H)
            
        # 4. 最终输出聚合 (Algorithm 1: Final Output)
        # 论文提到: Sum rows of H
        H_final = H.sum(dim=2) # Sum over n dimension -> [b, s, d]
        
        # 5. 最终分类
        out = self.final_norm(H_final)
        logits = self.head(out)
        
        return logits

# --- 测试代码 ---
def test_hc_implementation():
    batch_size = 2
    seq_len = 10
    vocab_size = 100
    d_model = 32
    n_layers = 2 # 网络层数
    n_hyper = 4 # 论文推荐值
    dynamic = True # 是否使用动态HC
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    model = HyperConnectionModel(vocab_size, d_model, n_layers=n_layers, n_hyper=n_hyper, dynamic=dynamic, device=device).to(device)
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len)).to(device)
    
    output = model(input_ids)
    
    print(f"Input shape: {input_ids.shape}")
    print(f"Hyper-Connection Expansion Rate (n): {n_hyper}")
    print(f"Dynamic HC: {dynamic}")
    print(f"device: {device}")
    print(f"Output Logits shape: {output.shape}")
    
    # 验证是否包含 NaN (检查 tanh 稳定性)
    if torch.isnan(output).any():
        print("Error: Output contains NaNs!")
    else:
        print("Pass: Forward pass successful.")

if __name__ == "__main__":
    test_hc_implementation()

Sinkhorn-Knopp 算法

豆包回答:https://www.doubao.com/thread/wc33396a0bed92aa9

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐