深入理解流形约束超连接(Manifold-Constrained Hyper-Connections)
这里的数据都是以文中的OLMo-1B(n=4)为代表进行说明的参数量上:不管是静态HC还是动态HC,增加参数量均可忽略不计:针对一个1B的模型,静态HC增加的参数量约为0.00007%,而动态HC也仅为0.03%左右。计算复杂度上:无论静态还是动态超连接,引入的额外 FLOPs 均可忽略不计:针对一个1B的模型,静态HC增加的计算量约为0.127%,而动态HC也仅为0.2%左右。针对一个1B的模型
前言
元旦假期又被DeepSeek的论文刷屏了,号称DeepSeek提出了一种划时代的大模型网络结构,各大公众号争相报道,但又很难看懂其细节。想要真正了解这里面的细节,我们有以下几个问题需要回答。回答好了这几个问题也就了解了DeepSeek提出的工作的价值和意义:
- 残差连接的恒等映射是什么,有什么优缺点?
- 超连接(HC)是什么,解决了残差连接的什么问题?
- 流形约束HC(mHC)是什么,解决了超连接什么问题?
- mHC的价值和意义是什么?
本文不会过多涉及公式推导,尽量用浅显易懂的文字进行表达。
本文部分内容借助Gemini3-pro、Kimi-v2、豆包大模型完成。
残差连接的恒等映射是什么
残差连接(Residual Connections)最开始由大神何恺明在2016年提出: Deep Residual Learning for Image Recognition,今年刚好是该工作被提出的第十年,引用次数快达到30w,已经成为神经网络架构的基石,哪怕在如今的大模型时代,还是一直在沿用该方式。我们先简单了解一下残差连接:
- 残差连接是什么?
- 残差连接的提出解决了什么问题?
- 残差连接有什么缺点?
残差连接是什么

深度神经网路是由一层一层的卷积block堆积而成,而残差连接就是在每个卷积block里面采用一个跳跃连接(skip connection)直接将输入和经过卷积之后的结果进行加和,公式表达如下:
y = F ( x ) + x y = F(x) + x y=F(x)+x
这里有两个关键点需要注意:
- 恒等映射(Identity Mapping):跳跃连接对输入的x不做任何变换,直接将输入与卷积之后的结果进行加和。尽管多层block累积,但残差连接的设计可以保证始终有一条“畅通无阻的公路”让起始输入直达最终输出。这种特性就叫恒等映射,也就是这个特性让模型可以堆叠更深的同时保持稳定的训练(梯度不消失)。
- 恒等权重(Identity weight):加和的权重不变,即输入的x的权重为1
残差连接的提出解决了什么问题
一句话总结残差连接被提出的目的:随着神经网络层数增加,网络无法被“轻易训练”(存在梯度消失等问题)。残差连接的提出使得更大、更深的神经网络可以被训练。当前的大模型可以做到几十亿,甚至上百亿的参数量,离不开残差连接的“功劳”。
具体来说,残差连接最初被提出主要是为了解决深度神经网络中的 梯度消失(Gradient Vanishing) 和 网络退化(Network Degradation) 问题。
- 梯度消失: 随着网络层数加深,反向传播时的梯度会逐层相乘。如果梯度小于1,传到浅层时几乎为0,导致浅层参数无法更新。
- 网络退化: 理论上深层网络不应比浅层网络差(至少可以恒等映射),但在实践中,随着层数增加,训练误差反而上升。残差连接通过引入“捷径”(shortcut),允许信息直接流向更深层,使得深层网络更容易训练。
残差连接有什么缺点
在现在的大模型中,残差连接并不是按照上面的方式被直接应用,它往往是结合着“归一化层”一起使用,按照归一化层的位置,当前主要有两种范式:Pre-Norm(归一化层在卷积(也可以是attention/FFN)之前) 和 Post-Norm(归一化层在卷积之后)。
| Pre-Norm | Post-Norm |
|---|---|
| y = x + F ( L a y e r N o r m ( x ) ) y = x + F(LayerNorm(x)) y=x+F(LayerNorm(x)) | y = L a y e r N o r m ( x + F ( x ) ) y = LayerNorm(x + F(x)) y=LayerNorm(x+F(x)) |
| 优点:1. 训练非常稳定: 归一化位于卷积通路(attention/FFN)上,而残差连接(x)保持畅通。梯度可以直接沿着残差连接无损地传回第一层,极大地缓解了梯度消失问题。2. 无需复杂 Warm-up: 允许使用更大的学习率,且对超参数不那么敏感,适合训练超大规模深层网络。 | 优点:1. 潜力更高: 如果能训练成功,Post-Norm 模型通常能达到比 Pre-Norm 略好的最终性能(更低的 Loss),因为它保留了较大的梯度方差,有助于模型跳出局部最优。2. 缓解表示坍塌: Post-Norm 不容易出现深层特征趋同(Representation Collapse)的问题,每一层的特征差异性保持较好。 |
| 缺点:1. 表示坍塌 (Representation Collapse): 由于主干通路一直保留原始信息,随着层数加深,深层的输出和输入变得越来越像(余弦相似度接近 1)。2. 层效能降低: 深层网络的参数实际上贡献很小,导致模型虽然层数多,但有效容量可能受限。 | 缺点:1. 训练极不稳定: 这是最大的痛点。由于归一化在残差相加之后,梯度在反向传播时难以直接通过“捷径”传回浅层,容易导致梯度消失或梯度爆炸。2. 需要 Warm-up: 训练初期必须使用非常小心设计的 Learning Rate Warm-up 策略,否则模型很难收敛。 |
| 目前主流大模型(如 GPT-3, LLaMA, PaLM, 以及 OLMo)普遍采用的结构 | 最早在 Google 的原始 Transformer 论文 (“Attention Is All You Need”) 和 BERT 中使用 |
| 伪代码:x = Attn(Norm1(x)) + x; x = FFN(Norm2(x)) + x | 伪代码:x = Norm1(Attn(x) + x); x = Norm2(FFN(x) + x) |
这里就引出了超连接paper中提到的残差连接的缺点:
such as the seesaw effect between gradient vanishing and representation collapse.
即残差连接始终面临着梯度消失(gradient vanishing)与表示坍缩(representation collapse)之间的“跷跷板效应”。Pre-Norm训练稳定但会有特征表示坍缩问题;Post-Norm特征表示差异性更好但训练又不稳定。
超连接(Hyper Connections)的提出就是为了来解决这个问题!
超连接(HC)是什么
理论解释
原理
超连接(Hyper Connections) 是字节seed基础模型团队提出来的工作,该工作最早在2024年9月挂在arxiv上,后来中了 ICLR 2025。
正如前面提到的那样,超连接的提出是为了解决残差连接中梯度消失和表征坍缩的“跷跷板效应”。本文根据该问题引出一个疑问:
Can neural networks autonomously learn the optimal strength of connections to improve performance? (神经网络能否自主学习最优的连接强度,以提升性能?)
答案当然是可以的,核心思想就是引入可学习的深度连接(depth-connections)和宽度连接(width-connections)。这里的关键词就是“可学习的”,通过引入一些可学习的的参数(论文中的 α \alpha α和 β \beta β)来提升模型的表征能力。可以通过公式对比来直接意识到HC的“特殊地方”:
残差连接:
y = F ( x ) + x y = F(x) + x y=F(x)+x
Hyper-Connection(简易公式):
y = β ∗ F ( α ∗ x ) + α ∗ x y = \beta * F(\alpha * x) + \alpha * x y=β∗F(α∗x)+α∗x
上面的HC公式并不是最终的公式,而是一个为了方便理解的化简版。实际论文中 α \alpha α 和 β \beta β 不是一个可学习的向量(vector),而是一个矩阵(matrix)。为了实现矩阵的运算,文中对输入也做了对应的处理,即对输入 x x x 做 n n n 份复制(即文中说的副本),实现宽度上的扩展。
理论优势
- 深度上可以更灵活地融入特征
- 在残差连接中经过Attention或者FFN之后的结果,一直以权重为1的方式与输入进行加和
- HC中通过引用可学习的 β \beta β作为加和权重,实现了模型“自主”决策融合比例,比传统残差连接具有更强的表达能力
- 能够同时建模多种深度连接方式,此外使得同一层内的隐藏状态可以相互交换信息
- 通过对输入进行扩展,实现了同一层中存在不同的隐藏状态,增加了隐藏层的表征能力
- 通过可学习的 α \alpha α 实现同一层不同隐藏状态之间的融合和交互
效果体现
- 语言模型预训练(1B、7B、MoE):
- 在相同训练 token 数量下,收敛速度提升 1.8 倍,下游任务准确率提升 6 个百分点(如 ARC-Challenge)。
- 模型更稳定,训练过程中无 loss spike。
- 视觉任务(ImageNet 分类与生成):
- ViT 模型中,Top-1 准确率提升最多 2.69%。
- DiT 图像生成中,性能接近参数量大 50% 的模型。
具体实现
n=2的情况

上图(b)实现了基于Pre-Norm的HC(n=2)的具体实现。我们也可以从公式上实现HC的表示。此外,当我们搞清楚了n=2的计算方式,也就能进一步搞明白论文中针对任意n的表达公式。
为了跟图中的参数对齐,后续将使用h来代替x,两者在本文中的含义是一致的,即都表示某一隐藏层的输入。
HC n=2 的具体实现:
- 在进入第一层之前对 h h h 进行扩展(即复制),图中扩展为两份,即 h 1 h_1 h1 和 h 2 h_2 h2
- 对每个隐藏状态进行Attention和FFN(这里用 F F F表示)计算:
h 1 = β 1 ∗ F ( α 1 , 0 ∗ h 1 + α 2 , 0 ∗ h 2 ) + ( α 1 , 1 ∗ h 1 + α 2 , 1 ∗ h 2 ) h 2 = β 2 ∗ F ( α 1 , 0 ∗ h 1 + α 2 , 0 ∗ h 2 ) + ( α 1 , 2 ∗ h 1 + α 2 , 2 ∗ h 2 ) h_1 = \beta_1 * F(\alpha_{1,0} * h_1 + \alpha_{2,0} * h_2) + (\alpha_{1,1} * h_1 + \alpha_{2,1} * h_2) \\ h_2 = \beta_2 * F(\alpha_{1,0} * h_1 + \alpha_{2,0} * h_2) + (\alpha_{1,2} * h_1 + \alpha_{2,2} * h_2) h1=β1∗F(α1,0∗h1+α2,0∗h2)+(α1,1∗h1+α2,1∗h2)h2=β2∗F(α1,0∗h1+α2,0∗h2)+(α1,2∗h1+α2,2∗h2) - 在经过多层block(一个block由一个attention和一个FFN组成)计算后,再将两个隐藏状态进行求和,得到最终的输出:
h = h 1 + h 2 h = h_1 + h_2 h=h1+h2
这里也可以写出n=2的HC矩阵表达:
H C = ( 0 β 1 β 2 α 1 , 0 α 1 , 1 α 1 , 2 α 2 , 0 α 2 , 1 α 2 , 2 ) HC = \left( \begin{array}{ccc} 0 & \beta_1 & \beta_2 \\ \alpha_{1,0} & \alpha_{1,1} & \alpha_{1,2} \\ \alpha_{2,0} & \alpha_{2,1} & \alpha_{2,2} \\ \end{array} \right) HC= 0α1,0α2,0β1α1,1α2,1β2α1,2α2,2
通用表达
静态HC(static hyper-connections,SHC)

特别地, B ∈ R 1 × n , A m ∈ R n × 1 , A r ∈ R n × n B \in \mathbb{R}^{1 \times n}, \space A_m \in \mathbb{R}^{n \times 1}, \space A_r \in \mathbb{R}^{n \times n} B∈R1×n, Am∈Rn×1, Ar∈Rn×n,公式中的 H H H 是多个隐藏状态 h h h 的组成。一般 h h h 的维度为 [ b a t c h , s e q , d i m ] [batch,\ seq,\ dim] [batch, seq, dim],经过扩展后 H H H 的维度变为 [ b a t c h , s e q , n , d i m ] [batch,\ seq,\ n,\ dim] [batch, seq, n, dim]。
从上面也可以看出来,需要学习的就只有三个向量矩阵: B ∈ R 1 × n , A m ∈ R n × 1 , A r ∈ R n × n B \in \mathbb{R}^{1 \times n}, \space A_m \in \mathbb{R}^{n \times 1}, \space A_r \in \mathbb{R}^{n \times n} B∈R1×n, Am∈Rn×1, Ar∈Rn×n,所以一个block中新增的参数量就只有:parm = 2 * (n + n + n*n) = 2n(n+2)。这里乘以2是因为一个block中包含一个attention和一个FFN。论文中推荐的n为4,所以一层block新增的参数量仅有48个。
动态HC(dynamic hyper-connections,DHC)
一句话总结动态HC与静态HC的差别:静态HC中每一层中的A/B参数与输入的H是完全无关的,即训练一旦完成,每层的A/B参数就会固定。
但这样就会存在一个缺点:网络无法根据每层输入的H情况来“动态”调整多个隐藏状态h之间的融合权重。 DHC 引入了输入依赖性 (Input-Dependent)。对于某些 Token,模型可能觉得“这一层的计算很重要”,于是加大深度连接的权重;对于另一些 Token,模型可能觉得“保持原样更好”,于是减小权重。这类似于给每个 Token 配备了一个智能的交通指挥员。
注意:
- 通常取 H 的第一行(或某种聚合形式),经过 LayerNorm 后作为当前层的输入向量(公式10)
- Wβ, Wm, Wr 是可学习的线性变换矩阵,即一个全连接层(Linear Layer)
- sα, sβ 是初始值很小的缩放因子(如 0.01)
- Tanh 的作用:论文特别强调(在 Visualization & Analysis 部分),Tanh 激活函数对于 DHC 的训练稳定性至关重要。如果去掉 Tanh,模型性能会显著下降甚至无法收敛。 它将权重限制在 (−1,1) 之间,防止信号爆炸。
论文中,对动态HC的参数计算也写的很详细,这里就直接贴一下:
总结
这里的数据都是以文中的OLMo-1B(n=4)为代表进行说明的
- 参数量上:不管是静态HC还是动态HC,增加参数量均可忽略不计:针对一个1B的模型,静态HC增加的参数量约为0.00007%,而动态HC也仅为0.03%左右。
- 计算复杂度上:无论静态还是动态超连接,引入的额外 FLOPs 均可忽略不计:针对一个1B的模型,静态HC增加的计算量约为0.127%,而动态HC也仅为0.2%左右。
- 显存消耗上:由于HC的引入,显存上还是有明显的增加,这一点不像参数量和计算量:针对一个1B的模型,静态HC和动态HC增加的显存均约为26%左右。 关于显存明显增加的问题,这一点受到了ICLR审稿人的质疑,针对这一点seed该团队进一步提出了Frac-Connections(FC),来减少计算量和显存开销。
拓展阅读:Frac-Connections(FC)
原理
为了减少HC的计算量和显存开销,字节seed团队在25年3月份进一步提出了HC的优化版本: Frac-Connections: Fractional Extension of Hyper-Connections
一句话总结其特点:不像HC中对输入x进行n份复制来进行扩展,而是通过对输入x进行split:拆分成 m = 1 / n m = 1/ n m=1/n 份,这里的 n n n 为“扩展率”。可以认为FC是HC的一种拓展,HC中要求 n > 1 n > 1 n>1,把隐藏状态h复制n份。而FC定义了 0 < n < 1 0 < n < 1 0<n<1 的时候,这个时候是把隐藏状态h split成m份。
值得注意的是:FC是表征能力和计算量之间的一种权衡方法,FC虽然减少了计算量,但也降低了模型的表征能力。论文原话:
The similarity between adjacent hidden states in FC lies between that of HC and baseline (Pre-Norm), indicating that their representational capacity follows the order: HC > FC > Pre-Norm.
实现

从图(c)中可以看出,具体操作步骤为:
- 对输入h进行split操作,分为m份子向量
- 宽度上进行融合,执行cat后进行attention/FFN计算
- 将经过attention/FFN计算的结果进行对应的split,然后进行深度融合
- 最后将所有隐藏状态h进行拼接(cat)
优势
这里的数据都是以文中的OLMo-1B(n=4)为代表进行说明的
- 参数量:动态FC增加的参数量为0.014%,比动态HC减少一半
- 计算量:动态FC增加的计算量为0.044%,比动态HC减少约5倍
- 显存消耗:文中没有明确统计。论文摘要中明确提到FC是为了reducing memory consumption,但这关键的一点却没有通过实验数据来证明,这可能也说明了FC也并没有达到显存显著节省的效果。
流形约束HC是什么
很多公众号都在讲DeepSeek提出来的流形约束超连接(mHC: Manifold-Constrained Hyper-Connections)的特性:
- 方法:将残差连接矩阵变换为“双随机矩阵”,映射到一个Birkhoff 多胞形(Birkhoff polytope)上。
- 目的:维持了“恒等变换特性”、节省了内存开销。
要理解这里面的目的和做法,我们需要先深入理解一下字节seed提出的HC存在什么问题?
HC存在什么问题
图解

引用HC论文中的一张图,左边是传统的残差连接示意图,右边是HC示意图。通过对比可以发现:
- 残差连接:不管神经网络block堆积多深,始终有一条畅通无阻的“高速公路”让输入h直达输出。这种特性就是“恒等变换”特性。这个特性解决了当网络变深后导致的梯度消失问题
- 超连接(HC):这条“高速公路”上出现很多“阻碍”(即一些需要学习的参数 α \alpha α),这些阻碍在网络变深后会导致两种可能的问题:
- 如果参数 α < 1 \alpha < 1 α<1,一旦累积多层后,这条“公路”就像断了一样,把输入逐渐置为了0,到网络深层就接收不到浅层的信息了。这就像移除了残差连接一样,会导致梯度消失的问题出现。
- 如果参数 α > 1 \alpha > 1 α>1,一旦累积多层后,这条“公路”就像一个放大器,把输入逐渐放大了很多倍(deepseek论文里面有提到说峰值存在放大3000倍的情况),到网路深层接收到的浅层信息就“不保真”,数值巨大,会导致梯度爆炸的问题出现。
所以这里就引出了 mHC 论文的核心观点和做法:在HC里面,继续维持这条“公路”通畅,即参数 α \alpha α 一直保持在数值1 附近。在实际操作中,参数 α \alpha α(即HC论文中的 A r A_r Ar)是一个 n × n n \times n n×n 的矩阵,让矩阵维持这个特性的做法就是将 A r A_r Ar 通过某种变换,变成列和为1、行和为1的矩阵,这个矩阵就叫“双随机矩阵”(行和为1、列和为1是双随机矩阵的最大特性)。实现方法就是 Sinkhorn-Knopp 算法。
理论证明

注意mHC中的符号跟HC论文中是对不上的,对应关系是: H r e s = A r , H p r e = A m , H p o s t = B H^{res} = A_r, \ H^{pre} = A_m, \ H^{post} = B Hres=Ar, Hpre=Am, Hpost=B
mHC也用过消融实验证明, H r e s H^{res} Hres 矩阵对最终的模型效果影响最大,所以要优先处理 H r e s H^{res} Hres 矩阵。
这是mHC论文中对多层HC堆积的公式表达,红框中就是残差流上多个 H r e s H^{res} Hres 矩阵的累积结果。这个结果在网络多层累积后,会变得非常小或者非常大,导致梯度消失或者爆炸的问题出现。

mHC中也通过实验具体证明了这一点:
- 在一个27B模型的训练上,训练12k步后loss就飞了。
- 在网络深层(l=50), 累积的 H r e s H^{res} Hres 矩阵数值最高达到了3000,远远大于理想的1。mHC就是为了让这个累积结果一直维持在1附近。
mHC的具体实现
搞懂原理后,再来看mHC的具体实现就非常简单了。
论文中的公式表达:
解释:
- 公式(7)就是HC论文中动态HC的实现,唯一不同的是输入RMSNorm的x,进行了flatten处理,让输入信息全部参与参数计算,而不是原文中选择第一个维度。
- 公式(8)就是mHC的实现,在公式(7)的基础上仅仅加了一个矩阵变换,其中 σ \sigma σ 代表是sigmoid函数,对于相对不重要的参数矩阵,仅进行简单的sigmoid变换。而对于重要的 H r e s H^{res} Hres 矩阵,做了 Sinkhorn-Knopp 变换 。
计算效率优化
DeepSeek的工作除了对 H r e s H^{res} Hres 矩阵的变换外,他们还通过一些列工程的手段来提升了模型的计算效率,同时解决了HC和FC中存在的显存开销问题,这一点也是该工作的一个重大意义。
由于工程优化不是本文的重点,这里就不再进行具体的展开。这里引用字节提出的“Hyper-Connections”,被DeepSeek救活了?公众号比较简洁的总结:
对于 Transformer 架构,HC 带来的额外 I/O 开销约为 ( 5 n + 1 ) C (5n+1)C (5n+1)C。为了解决这个问题,DeepSeek 实施了三项系统优化:
- Kernel Fusion:使用 TileLang 手写算子,将 RMSNorm、MatMul 和加法操作深度融合,减少显存读写次数。
- Recomputing(重计算):在前向传播中丢弃中间激活值,反向传播时重新计算。通过优化 Block Size ( L r ≈ n L / ( n + 2 ) L_r \approx \sqrt{nL/(n+2)} Lr≈nL/(n+2)),降低显存占用。
- DualPipe Overlap:在流水线并行中,将 mHC 带来的额外计算与通信时间重叠掩盖。
最终,在 n=4 的配置下,mHC 在大规模训练中的额外时间开销被压缩至 6.7%,使其具备了工程可用性。
思考

虽然mHC通过对 H r e s H^{res} Hres 矩阵进行约束,维持了恒等变换的特性,但这似乎也会弱化模型的表征能力,即弱化不同隐藏状态之间交互的作用。HC的论文中也指出如果 H r e s H^{res} Hres 矩阵变为了一个单位矩阵,但不同的隐藏状态h之间都没有任何交互了,这就违背了HC的设计初衷。
从mHC论文中的图8中,可以发现在深层的时候(l=30/60)的时候, H r e s H^{res} Hres 矩阵已经非常像一个单位矩阵了,所以在深层的时候,不同h之间的交互“似乎”已经变得非常的弱了。
所以,个人觉得:应该在限制 H r e s H^{res} Hres 矩阵满足双随机矩阵特性的同时,也要让其不要直接堕化为一个简单的单位矩阵。 也就是在满足恒等变换的同时也不要弱化不同隐藏状态h的交互。这里可以通过正则化等操作来对这些可学习的参数进行一定的限制。
mHC的价值和意义是什么
价值
- 理论意义:mHC的提出使得HC可以应用在更深、更大的模型上,这在当前的大模型时代是一个非常刚需的要求。简单来说,残差连接的提出使得卷积block可以堆叠的更深,而mHC的提出使得带HC的block可以堆叠的更深。
- 工程意义:从工程的角度解决了HC/FC一直被诟病的显存开销的问题,可以推动HC的研究。
- HC变革了模型架构设计:之前的大模型为了增强表征能力,只能通过加大网络深度,或者每个token使用更大的维度(即dim),但这两个方法都会显著的导致模型的计算复杂度被提升。HC的提出,引出了一个新的维度,即残差流宽度(n),它的优势在于扩大n,不会带来显著的参数量和计算复杂度的提升。
- mHC会不会引发新的一轮模型架构革新?这是mHC论文在摘要中有明确提到的一个期望。目前的mHC只是偏理论的一些证明,可能很快DeepSeek就会放出mHC在实际应用上的价值,可能就会应用在他们的下一版本发布的模型上。
未来的研究
- 字节seed对HC的探索仅在7B的模型上,deepseek提出的mHC在27B的模型上验证了可能,那在更大的模型上,mHC是否还同样有效呢?
- 残差流的宽度探索(即n参数),目前都是使用默认的设置4,那更大的n会有什么问题呢?效果会更上一层么?
- HC架构应用在更多的场景模型上,当前还是主要在语言基础模型上,图像理解、生成等大模型上表现会如何?
参考
- 2026 第一枪:字节提出的超连接,被 DeepSeek 救活了
- 字节提出的“Hyper-Connections”,被DeepSeek救活了?
- 刚刚,梁文锋署名,DeepSeek元旦新论文要开启架构新篇章
- 详细解读DeepSeek新年的第一篇论文,他们就是这个时代的真神。
- 如何理解 DeepSeek 最新提出的 mHC 架构?
- Deep Residual Learning for Image Recognition
- Hyper-Connections
- Frac-Connections: Fractional Extension of Hyper-Connections
- mHC: Manifold-Constrained Hyper-Connections
附录
Pre-Norm Block 实现
class PreNormBlock(nn.Module):
def __init__(self, d_model):
super().__init__()
self.attn = MultiHeadAttention(d_model)
self.ffn = FeedForward(d_model)
# 定义两个归一化层,分别用于 Attn 和 FF 之前
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x):
# 1. Self-Attention 子层
# 先 Norm,再进子层,最后加残差
residual = x
x = self.norm1(x)
x = self.attn(x)
x = x + residual # 残差连接
# 2. Feed-Forward 子层
# 同样:先 Norm,再进子层,最后加残差
residual = x
x = self.norm2(x)
x = self.ffn(x)
x = x + residual # 残差连接
return x
Post-Norm Block 实现
class PostNormBlock(nn.Module):
def __init__(self, d_model):
super().__init__()
self.attn = MultiHeadAttention(d_model)
self.ffn = FeedForward(d_model)
# 定义两个归一化层,分别用于 Attn 和 FF 之后
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x):
# 1. Self-Attention 子层
# 先进子层,加残差,最后 Norm
residual = x
x = self.attn(x)
x = residual + x # 残差连接
x = self.norm1(x) # Post-Norm
# 2. Feed-Forward 子层
# 同样:先进子层,加残差,最后 Norm
residual = x
x = self.ffn(x)
x = residual + x # 残差连接
x = self.norm2(x) # Post-Norm
return x
一个完整的Hyper-Connection网络demo实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
def __init__(self, d_model):
super().__init__()
# 简化版,实际应用通常使用 nn.MultiheadAttention
self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=4, batch_first=True)
def forward(self, x):
# x shape: [batch_size, seq_len, d_model]
output, _ = self.attn(x, x, x)
return output
class FeedForward(nn.Module):
def __init__(self, d_model, d_ff=None):
super().__init__()
d_ff = d_ff or 4 * d_model
self.net = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
def forward(self, x):
return self.net(x)
class HyperConnection(nn.Module):
def __init__(self, dim, rate, layer_id, dynamic, device=None):
super(HyperConnection, self).__init__()
self.rate = rate # 对应论文中的扩展率n
self.layer_id = layer_id # 网络第几层
self.dynamic = dynamic # 是否使用动态HC
self.static_beta = nn.Parameter(torch.ones((rate,), device=device)) # 对应论文中的矩阵B [n]
init_alpha0 = torch.zeros((rate, 1), device=device) # 对应论文中的矩阵Am [n, 1]
init_alpha0[layer_id % rate, 0] = 1.0 # 对矩阵Am初始化
self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye((rate), device=device)], dim=1)) # 对应论文中的矩阵Am+Ar [n, n+1]
if self.dynamic: # 动态HC的初始化,对应论文中的公式(10)-(13),注意这里的alpha是对Am和Ar矩阵拼接之后的表示
self.dynamic_alpha_fn = nn.Parameter(torch.zeros((dim, rate+1), device=device))
self.dynamic_alpha_scale = nn.Parameter(torch.ones(1, device=device) * 0.01)
self.dynamic_beta_fn = nn.Parameter(torch.zeros((dim, ), device=device))
self.dynamic_beta_scale = nn.Parameter(torch.ones(1, device=device) * 0.01)
self.layer_norm = nn.LayerNorm(dim)
# 宽度连接
def width_connection(self, h):
if self.dynamic:
norm_h = self.layer_norm(h)
if self.dynamic:
wc_weight = norm_h @ self.dynamic_alpha_fn
wc_weight = F.tanh(wc_weight)
dynamic_alpha = wc_weight * self.dynamic_alpha_scale
alpha = dynamic_alpha + self.static_alpha[None, None, ...]
else:
alpha = self.static_alpha[None, None, ...]
if self.dynamic:
dc_weight = norm_h @ self.dynamic_beta_fn
dc_weight = F.tanh(dc_weight)
dynamic_beta = dc_weight * self.dynamic_beta_scale
beta = dynamic_beta + self.static_beta[None, None, ...]
else:
beta = self.static_beta[None, None, ...]
mix_h = alpha.transpose(-1, -2) @ h
return mix_h, beta
# 深度连接
def depth_connection(self, mix_h, h_o, beta):
h = torch.einsum("blh,bln->blnh", h_o, beta) + mix_h[..., 1:, :]
return h
class HyperConnectionBlock(nn.Module):
def __init__(self, dim, rate, layer_id, dynamic, device=None):
super(HyperConnectionBlock, self).__init__()
self.atten_hyper_connection = HyperConnection(dim=dim, rate=rate, layer_id=layer_id, dynamic=dynamic, device=device)
self.ffn_hyper_connection = HyperConnection(dim=dim, rate=rate, layer_id=layer_id, dynamic=dynamic, device=device)
# 定义两个归一化层,分别用于 Attn 和 FF 之前
self.attn_norm = nn.LayerNorm(dim)
self.ffn_norm = nn.LayerNorm(dim)
# 定义两个子模块,分别用于 Attention 和 FFN
self.self_attention = MultiHeadAttention(d_model=dim)
self.feed_forward = FeedForward(d_model=dim)
# 定义一个 Dropout 层
self.dropout = nn.Dropout(0.1)
def forward(self, h):
# Attention Block
mix_h, beta = self.atten_hyper_connection.width_connection(h)
h = self.attn_norm(mix_h[...,0,:])
h = self.self_attention(h)
h = self.atten_hyper_connection.depth_connection(mix_h, self.dropout(h), beta)
# FFN Block
mix_h, beta = self.ffn_hyper_connection.width_connection(h)
h = self.ffn_norm(mix_h[...,0,:])
h = self.feed_forward(h)
h = self.ffn_hyper_connection.depth_connection(mix_h, self.dropout(h), beta)
return h
class HyperConnectionModel(nn.Module):
def __init__(self, vocab_size, d_model, n_layers, n_hyper=4, dynamic=True, device='cpu'):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.n_hyper = n_hyper
self.d_model = d_model
# 堆叠 HC Blocks
self.layers = nn.ModuleList([
HyperConnectionBlock(d_model, n_hyper, i, dynamic=dynamic, device=device) for i in range(n_layers)
])
self.final_norm = nn.LayerNorm(d_model)
self.head = nn.Linear(d_model, vocab_size)
def forward(self, input_ids):
# 1. Embedding
x = self.embedding(input_ids) # [batch, seq, d_model]
# 2. 初始化 Hyper Hidden States (Algorithm 1: Initialization)
# 将 x 复制 n 次形成初始矩阵 H
batch_size, seq_len, _ = x.shape
H = x.unsqueeze(2).expand(-1, -1, self.n_hyper, -1).clone() # [b, s, n, d]
# 3. 通过所有层
for layer in self.layers:
H = layer(H)
# 4. 最终输出聚合 (Algorithm 1: Final Output)
# 论文提到: Sum rows of H
H_final = H.sum(dim=2) # Sum over n dimension -> [b, s, d]
# 5. 最终分类
out = self.final_norm(H_final)
logits = self.head(out)
return logits
# --- 测试代码 ---
def test_hc_implementation():
batch_size = 2
seq_len = 10
vocab_size = 100
d_model = 32
n_layers = 2 # 网络层数
n_hyper = 4 # 论文推荐值
dynamic = True # 是否使用动态HC
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = HyperConnectionModel(vocab_size, d_model, n_layers=n_layers, n_hyper=n_hyper, dynamic=dynamic, device=device).to(device)
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len)).to(device)
output = model(input_ids)
print(f"Input shape: {input_ids.shape}")
print(f"Hyper-Connection Expansion Rate (n): {n_hyper}")
print(f"Dynamic HC: {dynamic}")
print(f"device: {device}")
print(f"Output Logits shape: {output.shape}")
# 验证是否包含 NaN (检查 tanh 稳定性)
if torch.isnan(output).any():
print("Error: Output contains NaNs!")
else:
print("Pass: Forward pass successful.")
if __name__ == "__main__":
test_hc_implementation()
Sinkhorn-Knopp 算法
更多推荐


所有评论(0)