The Transformer Upgrade Path: From the "Original" Mechanism to V3, 5 Key Optimizations of Attention

Since "Attention Is All You Need" burst onto the scene in 2017, the Transformer has reshaped natural language processing (NLP) and deep learning as a whole. Much of its power comes from its core self-attention mechanism, which lets the model capture long-range dependencies in a sequence efficiently.

"Perfect" is relative, however, and progress never stops. Powerful as the original Transformer is, its attention design shows its limits when confronted with massive data, very long sequences, and efficiency bottlenecks. Researchers have therefore introduced a series of clever optimizations around the attention mechanism, letting the Transformer level up step by step into something stronger and more efficient.

This article walks through that upgrade path, focusing on five key optimizations of the attention mechanism: starting from the original Scaled Dot-Product Attention and moving toward more modern, more efficient variants, with conceptual code to illustrate each core improvement.

1. The Foundation: Scaled Dot-Product Attention, the Starting Point of Everything

Let's first revisit the cornerstone of the Transformer: Scaled Dot-Product Attention. It removes the information bottleneck RNNs hit on long sequences by using a "global", dot-product-similarity-based alignment that lets the model flexibly pull relevant information from anywhere in the input sequence.

Core logic:

Input: three sets of vectors, the queries (Query, Q), keys (Key, K), and values (Value, V). Typically all three are derived from the same input sequence (e.g., token embeddings) through different linear projections.

Similarity computation: take the dot product of Q with K, producing a score matrix that measures how well each query vector matches each key vector.

Scaling: divide the score matrix by $\sqrt{d_k}$ (the square root of the key dimension) to keep the dot products from growing too large and to stabilize gradients.

Softmax normalization: apply Softmax to each row of the score matrix, turning it into attention weights that sum to 1 and indicate how much attention to allocate to each value vector.

Weighted sum: multiply the value (V) vectors by the attention weights to obtain the weighted-average output.

Mathematical form:

$$\text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Limitations of the original attention:

Computational complexity: $O(\text{seq\_len}^2 \cdot d)$; time and memory both grow quadratically with sequence length, which is the bottleneck for long sequences.

Memory footprint: a full $\text{seq\_len} \times \text{seq\_len}$ attention-weight matrix must be stored, which is memory-intensive.

Conceptual code (Python/NumPy):

<PYTHON>

import numpy as np

def softmax(x):
    # Numerically stable softmax along the last axis
    e_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e_x / e_x.sum(axis=-1, keepdims=True)

def scaled_dot_product_attention(Q, K, V, d_k):
    # Q: (batch_size, seq_len_q, d_k)
    # K: (batch_size, seq_len_k, d_k)
    # V: (batch_size, seq_len_k, d_v)

    # 1. Calculate scores: Q @ K^T / sqrt(d_k)
    K_T = np.swapaxes(K, 1, 2)                 # (batch_size, d_k, seq_len_k)
    scores = np.matmul(Q, K_T) / np.sqrt(d_k)  # (batch_size, seq_len_q, seq_len_k)

    # 2. Apply softmax to get attention weights
    attention_weights = softmax(scores)        # (batch_size, seq_len_q, seq_len_k)

    # 3. Weighted sum of values
    output = np.matmul(attention_weights, V)   # (batch_size, seq_len_q, d_v)
    return output, attention_weights

# --- Example parameters ---
batch_size = 1
seq_len = 10   # sequence length
d_k = 64
d_v = 128

# Simulated Q, K, V
Q_sample = np.random.randn(batch_size, seq_len, d_k)
K_sample = np.random.randn(batch_size, seq_len, d_k)
V_sample = np.random.randn(batch_size, seq_len, d_v)

output, weights = scaled_dot_product_attention(Q_sample, K_sample, V_sample, d_k)
print(f"Scaled Dot-Product Attention Output Shape: {output.shape}")  # (1, 10, 128)

2. Optimization 1: Multi-Head Attention, Examining Information from Multiple Angles

The original Transformer paper already introduced Multi-Head Attention (MHA). It is less an "upgrade" of Scaled Dot-Product Attention than a structural, parallelized extension of it, yet its impact on model capability is enormous.

Core logic:

MHA linearly projects Q, K, and V into multiple "heads"; each head runs Scaled Dot-Product Attention independently, and the outputs of all heads are concatenated and passed through one more linear projection to produce the final output.

Advantages:

Attending to different representation subspaces: each head can learn to focus on a different subspace of the input. For example, one head might track syntactic relations while another tracks semantic relatedness, letting the model understand the input from multiple angles.

Greater expressive power: it is effectively several "narrower" attention mechanisms running in parallel, which strengthens the model overall.

Parameters and computation:

If a Transformer layer has $h$ attention heads and each head has dimension $d_k$ (i.e., $d_{model} = h \cdot d_k$), the total compute of MHA is similar to that of a single "wide" attention head (over the $d_k$ and $d_v$ dimensions), but it adds the query/key/value projection parameters plus the concatenation and final linear projection.

Limitations of plain MHA:

Compute and memory: although each head costs $O(n^2 d_k)$, the total compute and parameter count grow with the number of heads $h$. At inference time, the Key-Value (KV) cache (K and V must be stored for every head) becomes the main memory bottleneck, especially for long sequences.

Conceptual code (pseudocode-style, showing the parallel heads and the concatenation):

<PYTHON>

import numpy as np

# Assume scaled_dot_product_attention from Section 1 is available

# --- Example parameters ---
batch_size = 1
seq_len = 10
d_model = 128
num_heads = 8
d_k = d_model // num_heads  # per-head dimension
d_v = d_model // num_heads

# Simulated input (e.g., token embeddings)
X = np.random.randn(batch_size, seq_len, d_model)

# 1. Linear projections for Q, K, V.
#    In practice a single large matrix per projection is used, then reshaped into heads.
W_Q = np.random.randn(d_model, num_heads * d_k)
W_K = np.random.randn(d_model, num_heads * d_k)
W_V = np.random.randn(d_model, num_heads * d_v)

Q_all = np.matmul(X, W_Q)  # (batch_size, seq_len, num_heads * d_k)
K_all = np.matmul(X, W_K)  # (batch_size, seq_len, num_heads * d_k)
V_all = np.matmul(X, W_V)  # (batch_size, seq_len, num_heads * d_v)

# Reshape to separate heads: (batch_size, num_heads, seq_len, d_k / d_v)
Q_heads = np.swapaxes(Q_all.reshape(batch_size, seq_len, num_heads, d_k), 1, 2)
K_heads = np.swapaxes(K_all.reshape(batch_size, seq_len, num_heads, d_k), 1, 2)
V_heads = np.swapaxes(V_all.reshape(batch_size, seq_len, num_heads, d_v), 1, 2)

# 2. Parallel attention computation (one call per head)
outputs = []
for i in range(num_heads):
    head_output, _ = scaled_dot_product_attention(Q_heads[:, i], K_heads[:, i], V_heads[:, i], d_k)
    outputs.append(head_output)  # each is (batch_size, seq_len, d_v)

# 3. Concatenate heads; a final linear projection W_O would follow
multi_head_output = np.concatenate(outputs, axis=-1)  # (batch_size, seq_len, num_heads * d_v)
# W_O = np.random.randn(num_heads * d_v, d_model)
# final_output = np.matmul(multi_head_output, W_O)

print(f"Multi-Head Attention Output Shape: {multi_head_output.shape}")  # (1, 10, 128)

3. Optimization 2: FlashAttention, a Revolution in Memory I/O

The main bottleneck of plain MHA on long sequences is memory bandwidth, not raw compute. FlashAttention (and its successor FlashAttention-2) exploits GPU hardware characteristics to deliver large speed and memory gains without changing the underlying math.

Core improvements:

Tiling and recomputation: the attention computation is broken into small tiles that are processed inside the GPU's SRAM (fast but small). The intermediate attention-weight matrix is never fully stored in HBM (large but slow); whenever it is needed again, it is recomputed on the fly.

Fewer HBM reads/writes: this "trade compute for memory traffic" strategy drastically cuts the number of HBM accesses, which are usually the real bottleneck of GPU attention.

Parallelization: further tuning of the SRAM tiling strategy plus kernel fusion reduces kernel-launch overhead.

Resulting benefits:

Time complexity: still $O(\text{seq\_len}^2 \cdot d)$ FLOPs, but with a much smaller constant factor in practice; the speedup on long sequences is dramatic.

Memory complexity: the extra memory drops from $O(\text{seq\_len}^2)$ to roughly $O(\text{seq\_len} \cdot d)$, since the full attention matrix is never materialized, which makes much longer sequences feasible.

Conceptual code:

FlashAttention's real implementation lives in CUDA kernels, and its I/O optimizations cannot be reproduced directly in Python/NumPy, but the idea can still be sketched:

<PYTHON>

# --- Simulating the core idea of FlashAttention ---
# Conceptual only; the real implementation is a fused CUDA kernel that manages
# SRAM explicitly.

def flash_attention_concept(Q, K, V, d_k, block_size=128):
    batch_size, seq_len, _ = Q.shape

    # Core idea: tile K and V into blocks held in SRAM, iterate over Q blocks,
    # and accumulate the result incrementally:
    #   1. Load a block of Q into SRAM (or stream Q chunk by chunk).
    #   2. Load K and V blocks into SRAM.
    #   3. For each Q block:
    #      a. Compute Q K^T against the K blocks currently in SRAM.
    #      b. Compute the softmax incrementally (tracking the running max and the
    #         running sum of exponentials in SRAM).
    #      c. Compute the weighted sum with V and accumulate it into an output
    #         buffer in SRAM.
    #   4. Only the final accumulated output is written back to HBM; the full
    #      N x N attention matrix is never stored, and intermediate values are
    #      recomputed when needed (e.g., in the backward pass).
    # This simulation only shows that the computation is broken into blocks;
    # the real speedup comes from optimized memory access.

    print(f"Simulating FlashAttention with block size: {block_size}")
    accumulated_output_parts = []  # conceptually, this accumulation lives in SRAM
    for i in range(0, seq_len, block_size):
        q_block = Q[:, i:min(i + block_size, seq_len), :]
        # In a real FlashAttention kernel, K and V blocks would be streamed in here.
        # We simply pretend a block output has been computed efficiently in SRAM.
        block_output = np.random.randn(q_block.shape[0], q_block.shape[1], V.shape[-1])
        accumulated_output_parts.append(block_output)

    # Combine the block results (the real algorithm does this with the tracked
    # softmax statistics, so the result is mathematically exact).
    final_output = np.concatenate(accumulated_output_parts, axis=1)
    return final_output

# --- Example parameters ---
batch_size = 1
seq_len_long = 2048  # a long sequence
d_k = 64
d_v = 128

Q_long = np.random.randn(batch_size, seq_len_long, d_k)
K_long = np.random.randn(batch_size, seq_len_long, d_k)
V_long = np.random.randn(batch_size, seq_len_long, d_v)

output_flash = flash_attention_concept(Q_long, K_long, V_long, d_k)
print(f"FlashAttention Conceptual Output Shape: {output_flash.shape}")  # (1, 2048, 128)
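Although the real speedup needs a fused GPU kernel, the mathematical trick FlashAttention relies on, computing softmax block by block with a running max and running sum, can be reproduced exactly in plain NumPy. The sketch below is our own illustration (the name blockwise_attention and the block_size parameter are not from any library); it returns the same result as the Section 1 reference without ever materializing the full seq_len x seq_len score matrix:

<PYTHON>

def blockwise_attention(Q, K, V, block_size=128):
    # Exact attention computed block by block over K/V using the running-max /
    # running-sum ("online softmax") bookkeeping that FlashAttention relies on.
    # The full (seq_len_q x seq_len_k) score matrix is never materialized.
    batch, seq_q, d_k = Q.shape
    d_v = V.shape[-1]
    out = np.zeros((batch, seq_q, d_v))
    row_max = np.full((batch, seq_q, 1), -np.inf)  # running max of the scores
    row_sum = np.zeros((batch, seq_q, 1))          # running sum of exp(score - max)
    for j in range(0, K.shape[1], block_size):
        K_blk = K[:, j:j + block_size, :]
        V_blk = V[:, j:j + block_size, :]
        scores = np.matmul(Q, np.swapaxes(K_blk, 1, 2)) / np.sqrt(d_k)  # (batch, seq_q, blk)
        new_max = np.maximum(row_max, scores.max(axis=-1, keepdims=True))
        correction = np.exp(row_max - new_max)  # rescale old statistics to the new max
        p = np.exp(scores - new_max)
        row_sum = row_sum * correction + p.sum(axis=-1, keepdims=True)
        out = out * correction + np.matmul(p, V_blk)
        row_max = new_max
    return out / row_sum

# Sanity check against the reference implementation from Section 1
out_ref, _ = scaled_dot_product_attention(Q_long, K_long, V_long, d_k)
out_blk = blockwise_attention(Q_long, K_long, V_long, block_size=256)
print(np.allclose(out_ref, out_blk))  # True (up to floating-point rounding)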

4. Optimization 3: Grouped-Query Attention (GQA), Making Every Byte of KV Memory Count

FlashAttention removes the memory-I/O bottleneck, but in autoregressive tasks such as text generation the Key-Value (KV) cache (the K and V vectors of previous tokens that the Transformer decoder keeps around to avoid recomputation) still consumes a huge amount of GPU memory and becomes the limiting factor on context length. Grouped-Query Attention (GQA) is an efficient answer to this.

Core idea:

In MHA, every attention head has its own K and V projections. GQA instead groups the query (Q) heads and lets all Q heads within a group share a single set of Key (K) and Value (V) projections.

MHA: $H$ Q heads, $H$ K heads, $H$ V heads.

GQA: $H$ Q heads but only $G$ K heads and $G$ V heads ($G < H$); each K/V head serves $H/G$ Q heads.

Resulting benefits:

Smaller KV cache: fewer K/V projections and cache entries mean lower memory use (see the back-of-the-envelope calculation after this list).

Faster inference: the memory-bandwidth demand of the KV cache shrinks, which speeds up long-sequence generation noticeably.

Minimal quality loss: experiments show GQA stays close to MHA in quality while delivering clear efficiency gains.
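To make the savings concrete, here is a rough back-of-the-envelope calculation. The layer count, head count, head dimension, context length, and fp16 precision below are illustrative assumptions we chose for the example, not the configuration of any particular model:

<PYTHON>

def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_elem=2):
    # K and V are each cached once per layer:
    # 2 (K and V) * layers * batch * kv_heads * seq_len * head_dim * bytes
    return 2 * num_layers * batch_size * num_kv_heads * seq_len * head_dim * bytes_per_elem

# Hypothetical model: 32 layers, 32 query heads, head_dim 128, 8192-token context, fp16
num_layers, num_q_heads, head_dim, ctx_len = 32, 32, 128, 8192
mha_cache = kv_cache_bytes(num_layers, num_q_heads, head_dim, ctx_len)  # MHA: one K/V per Q head
gqa_cache = kv_cache_bytes(num_layers, 8, head_dim, ctx_len)            # GQA: 8 shared K/V heads
print(f"MHA KV cache: {mha_cache / 2**30:.2f} GiB")  # 4.00 GiB
print(f"GQA KV cache: {gqa_cache / 2**30:.2f} GiB")  # 1.00 GiB, a 4x reduction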

Conceptual code (focusing on the K/V sharing):

<PYTHON>

# --- Simulating the core idea of GQA ---
# MHA: every query head has its own K and V projections.
# GQA: query heads are split into groups; each group shares one K/V head.

def gqa_concept(Q_heads, K_heads, V_heads, d_k):
    # Q_heads: (batch, num_query_heads, seq_len, d_k)
    # K_heads: (batch, num_kv_heads, seq_len, d_k)
    # V_heads: (batch, num_kv_heads, seq_len, d_v)
    num_query_heads = Q_heads.shape[1]
    num_kv_heads = K_heads.shape[1]
    group_size = num_query_heads // num_kv_heads  # query heads per shared K/V head

    all_head_outputs = []
    for q_idx in range(num_query_heads):
        kv_idx = q_idx // group_size  # which shared K/V head this query head uses
        head_output, _ = scaled_dot_product_attention(
            Q_heads[:, q_idx], K_heads[:, kv_idx], V_heads[:, kv_idx], d_k)
        all_head_outputs.append(head_output)

    # Concatenate all head outputs; a final linear projection W_O would follow.
    return np.concatenate(all_head_outputs, axis=-1)

# --- Example parameters ---
batch_size = 1
seq_len = 512
num_heads_q = 12
num_heads_kv = 4
d_k = 64
d_v = 64

# Simulated Q heads and the shared K/V heads
Q_heads_arr = np.random.randn(batch_size, num_heads_q, seq_len, d_k)
K_heads_arr = np.random.randn(batch_size, num_heads_kv, seq_len, d_k)
V_heads_arr = np.random.randn(batch_size, num_heads_kv, seq_len, d_v)

gqa_output = gqa_concept(Q_heads_arr, K_heads_arr, V_heads_arr, d_k)
print(f"GQA concept: {num_heads_q} Q heads use {num_heads_kv} shared K/V heads.")
print(f"Each shared K/V head serves {num_heads_q // num_heads_kv} Q heads.")
print(f"GQA Output Shape: {gqa_output.shape}")  # (1, 512, 768)

5. Transformer V2 & V3: Systematic Upgrades to Architecture and Training

After Transformer V1 (the original Transformer), researchers did not stop at single-point improvements to attention (such as MHA); they began making more systematic changes to the architecture and the training recipe. "V2" and "V3" are not official names, but we can treat them as a series of milestones in how the Transformer family has kept evolving in practice:

Key improvements (spread across models such as RoBERTa, Longformer, Reformer, Performer, LLaMA, and others):

Local attention / sparse attention:

Motivation: escape the quadratic $O(\text{seq\_len}^2)$ complexity.

Core: each token attends only to its neighbors (windowed attention), or to a local and sparsely selected subset of tokens (e.g., Longformer's dilated sliding window, or the LSH/kernel methods of Reformer/Performer).

Effect: complexity and memory drop to $O(\text{seq\_len} \cdot w \cdot d)$ or $O(\text{seq\_len} \cdot \log(\text{seq\_len}) \cdot d)$, making extremely long sequences tractable.

More effective positional encodings:

Original Transformer: fixed sinusoidal/cosine position encodings.

Improvement: relative positional encodings. RoPE (Rotary Positional Embedding, used by LLaMA and GPT-NeoX) encodes position as a rotation applied to Q and K, which behaves better on long sequences and when extrapolating; a minimal sketch follows this list.

Learnable position embeddings: make the position representation trainable.
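The rotation at the heart of RoPE can be sketched in a few lines of NumPy. This is a minimal illustration of the idea only; the function name rope and the interleaved-pair layout are our choices, and real implementations (e.g., in LLaMA) differ in details such as how dimensions are paired and how the rotation angles are cached:

<PYTHON>

import numpy as np

def rope(x, base=10000.0):
    # x: (batch, seq_len, d) with d even. Each consecutive pair of dimensions
    # (x_{2i}, x_{2i+1}) is rotated by an angle that grows with the position,
    # so q.k after rotation depends only on the relative offset between positions.
    batch, seq_len, d = x.shape
    half = d // 2
    inv_freq = base ** (-np.arange(half) / half)               # (half,)
    angles = np.arange(seq_len)[:, None] * inv_freq[None, :]   # (seq_len, half)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]                        # even / odd dimensions
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

# RoPE is applied to Q and K (not V) before the dot product:
q = np.random.randn(1, 16, 64)
k = np.random.randn(1, 16, 64)
scores = np.matmul(rope(q), np.swapaxes(rope(k), 1, 2))  # position-aware attention scores
print(scores.shape)  # (1, 16, 16)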

Better activation functions and normalization (sketched after this list):

Original Transformer: ReLU activation and LayerNorm.

Improvement: GeLU/SwiGLU. GeLU (Gaussian Error Linear Unit) outperforms ReLU, and SwiGLU (a GLU variant) does even better in many large models, including LLaMA.

RMSNorm: Root Mean Square Layer Normalization is simpler and faster than LayerNorm while matching it on many tasks; it is also commonly used alongside FlashAttention in modern model stacks.
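Both ideas are small enough to sketch directly. The functions below (rmsnorm, swiglu_ffn) are minimal NumPy illustrations with toy dimensions, not the exact formulation or sizing of any specific model:

<PYTHON>

import numpy as np

def rmsnorm(x, weight, eps=1e-6):
    # RMSNorm: rescale by the root-mean-square of the features. Unlike LayerNorm
    # there is no mean subtraction and no bias, which makes it cheaper.
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return x / rms * weight

def swiglu_ffn(x, W_gate, W_up, W_down):
    # SwiGLU feed-forward block (LLaMA-style): down( silu(x W_gate) * (x W_up) )
    def silu(z):
        return z / (1.0 + np.exp(-z))  # SiLU / Swish activation
    return np.matmul(silu(np.matmul(x, W_gate)) * np.matmul(x, W_up), W_down)

# Toy dimensions, chosen only for illustration
d_model, d_ff = 64, 172
x = np.random.randn(1, 10, d_model)
x = rmsnorm(x, weight=np.ones(d_model))
y = swiglu_ffn(x,
               W_gate=np.random.randn(d_model, d_ff),
               W_up=np.random.randn(d_model, d_ff),
               W_down=np.random.randn(d_ff, d_model))
print(y.shape)  # (1, 10, 64)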

Attention variants in large models such as LLaMA:

SwiGLU activation + RMSNorm + RoPE: the LLaMA series integrates these optimizations and, combined with massive data and compute, achieves a big jump in both quality and efficiency.

Grouped-Query Attention (GQA): further shrinks the KV cache and speeds up long-sequence inference.

The overall directions of the "Transformer V2/V3" improvements:

Accuracy: better positional encodings, activation functions, and normalization, plus better training data and strategies.

Efficiency: sparse attention, FlashAttention, GQA, and related techniques that remove the compute and memory bottlenecks of long sequences.

Conceptual code (sparse attention):

<PYTHON>

# --- Simulating sparse (sliding-window) attention ---
# Goal: each query attends only to a subset of the keys (here, a local window).
# Note: this concept code still materializes the full score matrix and merely
# masks it; a real sparse-attention kernel would skip the masked blocks entirely.

def sparse_attention_concept(Q, K, V, window_size=10):
    batch_size, seq_len_q, d_k = Q.shape
    _, seq_len_k, _ = K.shape

    scores = np.matmul(Q, np.swapaxes(K, 1, 2)) / np.sqrt(d_k)  # (batch, seq_q, seq_k)

    # Build a boolean mask: attention_mask[i, j] is True if query i may attend to key j.
    # Here: only keys within +/- window_size positions of the query.
    attention_mask = np.zeros((seq_len_q, seq_len_k), dtype=bool)
    for i in range(seq_len_q):
        left = max(0, i - window_size)
        right = min(seq_len_k, i + window_size + 1)
        attention_mask[i, left:right] = True
    # Optionally, a few "global" tokens could be allowed to attend everywhere,
    # as in Longformer: attention_mask[:, global_token_indices] = True

    # Disallowed positions get a very negative score, so softmax drives them to ~0.
    masked_scores = np.where(attention_mask, scores, -1e9)

    attention_weights = softmax(masked_scores)
    output = np.matmul(attention_weights, V)
    return output, attention_weights

# --- Example parameters ---
batch_size = 1
seq_len_huge = 4096  # a very long sequence
d_k = 64
d_v = 64

Q_sparse = np.random.randn(batch_size, seq_len_huge, d_k)
K_sparse = np.random.randn(batch_size, seq_len_huge, d_k)
V_sparse = np.random.randn(batch_size, seq_len_huge, d_v)

output_sparse, weights_sparse = sparse_attention_concept(Q_sparse, K_sparse, V_sparse)
print(f"Sparse Attention Output Shape: {output_sparse.shape}")  # (1, 4096, 64)
# Each row of weights_sparse has non-negligible mass only inside its window.

Conclusion: Attention Keeps Evolving

From the original Scaled Dot-Product Attention, to Multi-Head Attention's multi-angle view of the input, to FlashAttention's leap in speed and memory, to GQA's fine-grained gains in inference efficiency, and on to sparse attention, RoPE, SwiGLU, and RMSNorm, these pieces together form the foundation of the "Transformer V2/V3" generation of models.

Each optimization reflects a deep understanding of, and a careful balance between, compute budgets, model quality, and application scenarios. The Transformer's evolution is far from over; as research continues, we can expect more exciting attention variants that keep pushing the boundary of what AI can do.
