大模型中常见的算法
自注意力:通过 Q/K/V 计算元素间关联,是 Transformer 的核心;MLM:通过掩码预测训练模型的上下文理解能力,是 BERT 等模型的基础;束搜索:平衡生成质量与效率,广泛用于文本生成任务;INT8 量化:通过降低精度减少资源占用,是大模型部署的关键优化手段。
1. 自注意力机制(Self-Attention)
原理:计算序列中每个元素与其他元素的关联权重,动态关注重要信息。
示例:在句子 “猫追狗,它跑得很快” 中,模型通过注意力机制判断 “它” 更可能指代 “猫” 还是 “狗”
import torch
import torch.nn.functional as F
def scaled_dot_product_attention(Q, K, V, mask=None):
"""缩放点积注意力"""
d_k = Q.size(-1) # 特征维度
# 计算注意力分数:Q与K的转置相乘,再缩放
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
# 应用掩码(可选,如在 decoder 中防止关注未来信息)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# 计算注意力权重(softmax归一化)
attn_weights = F.softmax(scores, dim=-1)
# 加权求和得到输出
output = torch.matmul(attn_weights, V)
return output, attn_weights
def self_attention(query, key, value, num_heads=2):
"""多头自注意力"""
batch_size, seq_len, d_model = query.size()
d_k = d_model // num_heads # 每个头的维度
# 线性变换并分多头
Q = torch.nn.Linear(d_model, d_model)(query).view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
K = torch.nn.Linear(d_model, d_model)(key).view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
V = torch.nn.Linear(d_model, d_model)(value).view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
# 计算每个头的注意力
output, attn_weights = scaled_dot_product_attention(Q, K, V)
# 拼接多头结果
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
return output, attn_weights
# 测试
if __name__ == "__main__":
# 模拟输入:batch_size=1,seq_len=3(3个词),d_model=4(每个词的向量维度)
x = torch.randn(1, 3, 4)
output, attn = self_attention(x, x, x) # Q=K=V,自注意力
print("输入形状:", x.shape)
print("输出形状:", output.shape)
print("注意力权重(第一个头):\n", attn[0, 0]) # 每个词对其他词的关注程度
2. 掩码语言模型(MLM)
原理:随机掩盖输入中的部分 token,让模型预测被掩盖的内容,强制学习上下文理解。
示例:对句子 “北京是中国的 [MASK]”,模型需要预测 [MASK] 为 “首都”。
import torch
import torch.nn as nn
import torch.optim as optim
class SimpleMLM(nn.Module):
def __init__(self, vocab_size=1000, embedding_dim=128):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.fc = nn.Linear(embedding_dim, vocab_size) # 预测token
def forward(self, x):
x = self.embedding(x) # 词嵌入
logits = self.fc(x) # 输出每个位置的预测概率
return logits
def mask_tokens(input_ids, mask_prob=0.15, vocab_size=1000):
"""随机掩码输入token"""
batch_size, seq_len = input_ids.shape
mask = torch.rand(batch_size, seq_len) < mask_prob # 掩码位置
masked_input = input_ids.clone()
# 80%概率替换为[MASK],10%随机替换,10%保留原token
for i in range(batch_size):
for j in range(seq_len):
if mask[i, j]:
r = torch.rand(1).item()
if r < 0.8:
masked_input[i, j] = 1 # [MASK]的ID设为1
elif r < 0.9:
masked_input[i, j] = torch.randint(2, vocab_size, (1,)).item() # 随机token
return masked_input, mask # 返回掩码后的输入和掩码位置
# 测试
if __name__ == "__main__":
model = SimpleMLM()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
# 模拟输入:batch_size=2,seq_len=5(每个数字代表一个token)
input_ids = torch.randint(2, 1000, (2, 5)) # 避免[PAD]和[MASK]
masked_input, mask = mask_tokens(input_ids)
# 训练
model.train()
logits = model(masked_input)
# 只计算被掩码位置的损失
loss = criterion(logits[mask], input_ids[mask])
loss.backward()
optimizer.step()
print("掩码位置:", mask)
print("原始输入:", input_ids)
print("掩码后输入:", masked_input)
print("预测结果(前2个位置):", logits[0, :2].argmax(dim=-1))
3. 束搜索(Beam Search)
原理:生成文本时保留前 K 个概率最高的候选序列,逐步扩展并淘汰低分序列,平衡生成质量与效率。
示例:生成 “我喜欢吃” 的后续文本时,束宽 K=2 可能保留 “我喜欢吃苹果” 和 “我喜欢吃米饭” 两个候选。
import torch
import torch.nn.functional as F
from collections import defaultdict
def beam_search(initial_sequence, model, max_len=5, beam_width=2, eos_token=3):
"""
束搜索生成文本
initial_sequence: 初始序列(如[CLS] token)
model: 生成模型(输入序列,输出下一个token的概率)
eos_token: 结束符token ID
"""
# 初始化束:(序列, 累积概率)
beams = [(initial_sequence, 0.0)]
for _ in range(max_len):
candidates = []
for seq, score in beams:
# 如果序列已结束,直接加入候选
if seq[-1] == eos_token:
candidates.append((seq, score))
continue
# 模型预测下一个token的概率
input_tensor = torch.tensor(seq).unsqueeze(0) # 增加batch维度
logits = model(input_tensor) # 假设输出形状:(1, seq_len, vocab_size)
next_logits = logits[0, -1, :] # 取最后一个位置的预测
next_probs = F.softmax(next_logits, dim=-1) # 转为概率
# 取前beam_width个概率最高的token
top_probs, top_tokens = torch.topk(next_probs, beam_width)
for token, prob in zip(top_tokens, top_probs):
new_seq = seq + [token.item()]
new_score = score + torch.log(prob).item() # 累积对数概率(避免下溢)
candidates.append((new_seq, new_score))
# 按分数排序,保留前beam_width个候选
candidates.sort(key=lambda x: x[1], reverse=True)
beams = candidates[:beam_width]
# 如果所有束都已结束,提前退出
if all(seq[-1] == eos_token for seq, _ in beams):
break
# 返回分数最高的序列
return max(beams, key=lambda x: x[1])[0]
# 模拟生成模型
class DummyGenerator(nn.Module):
def forward(self, x):
vocab_size = 10
return torch.randn(x.shape[0], x.shape[1], vocab_size) # 随机输出,仅作示例
# 测试
if __name__ == "__main__":
model = DummyGenerator()
initial_seq = [0] # 初始序列(如[CLS])
generated_seq = beam_search(initial_seq, model, max_len=5, beam_width=2)
print("生成序列:", generated_seq)
4. 量化算法(INT8 量化)
原理:将 32 位浮点数参数转换为 8 位整数,减少显存占用并加速推理,通过缩放因子保留精度。
示例:将权重从 FP32(如 1.2345)量化为 INT8(如 123),推理时再通过缩放因子还原。
import torch
def quantize_int8(tensor):
"""将FP32张量量化为INT8"""
# 计算缩放因子:将张量范围映射到[-127, 127]
min_val = tensor.min().item()
max_val = tensor.max().item()
scale = (max_val - min_val) / 254 # 254 = 127 - (-127)
# 零偏移:确保0在量化后仍为0
zero_point = int(-min_val / scale + 127)
zero_point = max(0, min(255, zero_point)) # 限制在[0, 255]
# 量化:round( tensor / scale + zero_point - 127 )
quantized = torch.round(tensor / scale + zero_point - 127)
quantized = torch.clamp(quantized, -127, 127).to(torch.int8) # 截断到INT8范围
return quantized, scale, zero_point
def dequantize_int8(quantized_tensor, scale, zero_point):
"""将INT8张量反量化为FP32"""
return (quantized_tensor.to(torch.float32) + 127 - zero_point) * scale
# 测试
if __name__ == "__main__":
# 模拟模型权重(FP32)
weights = torch.tensor([-1.2, 0.5, 3.14, -0.8, 2.7])
print("原始权重:", weights)
# 量化
quantized, scale, zero_point = quantize_int8(weights)
print("量化后(INT8):", quantized)
print("缩放因子:", scale)
print("零偏移:", zero_point)
# 反量化
dequantized = dequantize_int8(quantized, scale, zero_point)
print("反量化后:", dequantized)
print("量化误差:", torch.mean(torch.abs(weights - dequantized)))
总结
以上代码展示了大模型核心算法的简化逻辑:
- 自注意力:通过 Q/K/V 计算元素间关联,是 Transformer 的核心;
- MLM:通过掩码预测训练模型的上下文理解能力,是 BERT 等模型的基础;
- 束搜索:平衡生成质量与效率,广泛用于文本生成任务;
- INT8 量化:通过降低精度减少资源占用,是大模型部署的关键优化手段。
更多推荐
所有评论(0)