快速上手大模型:深度学习14(束搜索、注意力、Transformer)
本文系统介绍了自然语言处理中的注意力机制及其相关技术。主要内容包括:1)束搜索算法原理及其在序列预测中的应用;2)注意力机制的核心思想与实现方法,包括Nadaraya-Watson核回归示例;3)自注意力机制和位置编码;4)Transformer架构详解,涵盖多头注意力、前馈网络、层归一化等关键组件;5)完整的编码器-解码器实现及训练过程。文章通过理论分析和代码实现相结合的方式,深入阐述了现代NL
目录
束搜索、使用注意力机制的seq2seq了解,注意力机制、注意力分散、自注意力、Transformer必学。
1 束搜索(beam search)
贪心搜索:用于seq2seq的序列预测。原理是对于输出序列的每一时间步𝑡′, 都将基于贪心搜索从y中找到具有最高条件概率的词元。
穷举搜索:穷举地列举所有可能的输出序列及其条件概率, 然后计算输出条件概率最高的一个。
束搜索:它有一个超参数束宽k。 在时间步1,我们选择具有最高条件概率的𝑘个词元。 这𝑘个词元将分别是𝑘个候选输出序列的第一个词元。 在随后的每个时间步,基于上一时间步的𝑘个候选输出序列, 我们将继续从𝑘|y|个可能的选择中挑出具有最高条件概率的𝑘个候选输出序列。
2 注意力机制
2.1 定义
注意力机制指模型在处理每个元素时,会自动挑选“最重要的信息”,并给重要的东西更大的权重。
核心思想:
2.2 Nadaraya-Watson核回归
注意力机制中权重是模型自动学习出来的。
(1)库
import torch from torch import nn from d2l import torch as d2l(2)生成数据集
n_train = 50 # 训练样本数 x_train, _ = torch.sort(torch.rand(n_train) * 5) # 排序后的训练样本 def f(x): return 2 * torch.sin(x) + x**0.8 y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,)) # 训练样本的输出 x_test = torch.arange(0, 5, 0.1) # 测试样本 y_truth = f(x_test) # 测试样本的真实输出 n_test = len(x_test) # 测试样本数 n_test def plot_kernel_reg(y_hat): d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'], xlim=[0, 5], ylim=[-1, 5]) d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);(3)平均汇聚
y_hat = torch.repeat_interleave(y_train.mean(), n_test) plot_kernel_reg(y_hat)
(4)非参数注意力汇聚
# X_repeat的形状:(n_test,n_train), # 每一行都包含着相同的测试输入(例如:同样的查询) X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train)) # x_train包含着键。attention_weights的形状:(n_test,n_train), # 每一行都包含着要在给定的每个查询的值(y_train)之间分配的注意力权重 attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1) # y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重 y_hat = torch.matmul(attention_weights, y_train) plot_kernel_reg(y_hat)(5)注意力权重
d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0), xlabel='Sorted training inputs', ylabel='Sorted testing inputs')
3 注意力分散
注意力分数是Query 和Key之间的相关性程度(相似度),分数越高 → 代表越应该关注那个信息,分数越低 → 应该忽略。
4 使用注意力机制的seq2seq
在翻译的每个步骤,Decoder不再只靠一个固定向量,而是学会“看回”输入序列中的所有词,并自动找出最相关的部分,再用于生成当前词。
5 自注意力
5.1 定义
是一种让模型自动关注输入中最重要部分的机制,尤其在 Transformer 架构中最核心。
原理:计算每个词在理解另一个词时应该关注哪些词、关注多少。每个输入向量生成三个向量(Query查询、Key键、Value值)
5.2 位置编码
与CNN/RNN不同,自注意力并没有记录位置信息。位置编码将位置信息注入到输入里,让模型“知道序列中每个元素的位置”的方法。
6 Transformer
6.1 架构
基于编码器-解码器架构来处理序列对,纯注意力架构。
6.2 多头注意力
(1)多头注意力:
对同一key、value、query,抽取不同信息(如短距离关系和长距离关系);多头注意力使用h个独立注意力池化。
(2)有掩码的多头注意力:
解码器对序列中一个元素输出时,不应该考虑元素之后的元素,使用掩码实现。
6.3 逐位前馈网络
将输入形状由(b,n,d)变换成(bn,d),作用两个全连接层,输出形状由(bn,d)变换成(b,n,d),等价于两层核窗口为1的一维卷积层。
6.4 层归一化Add&Norm
批量归一化对每个特征/通道里元素进行归一化,该法不适合序列长度会边的NLP应用;
层归一化对每个样本里的元素进行归一化。
6.5 信息传递
编码器和解码器中块的个数和输出维度都是一样的。
6.6 预测
预测t+1个输出时,解码器中输入前t个预测值。在自注意力中,前t个预测值作为key和value,第t个预测值还作为query。
6.7 代码
6.7.1 多头注意力
import math import torch from torch import nn from d2l import torch as d2l #@save class MultiHeadAttention(nn.Module): """多头注意力""" def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs): super(MultiHeadAttention, self).__init__(**kwargs) self.num_heads = num_heads self.attention = d2l.DotProductAttention(dropout) self.W_q = nn.Linear(query_size, num_hiddens, bias=bias) self.W_k = nn.Linear(key_size, num_hiddens, bias=bias) self.W_v = nn.Linear(value_size, num_hiddens, bias=bias) self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias) def forward(self, queries, keys, values, valid_lens): # queries,keys,values的形状: # (batch_size,查询或者“键-值”对的个数,num_hiddens) # valid_lens 的形状: # (batch_size,)或(batch_size,查询的个数) # 经过变换后,输出的queries,keys,values 的形状: # (batch_size*num_heads,查询或者“键-值”对的个数, # num_hiddens/num_heads) queries = transpose_qkv(self.W_q(queries), self.num_heads) keys = transpose_qkv(self.W_k(keys), self.num_heads) values = transpose_qkv(self.W_v(values), self.num_heads) if valid_lens is not None: # 在轴0,将第一项(标量或者矢量)复制num_heads次, # 然后如此复制第二项,然后诸如此类。 valid_lens = torch.repeat_interleave( valid_lens, repeats=self.num_heads, dim=0) # output的形状:(batch_size*num_heads,查询的个数, # num_hiddens/num_heads) output = self.attention(queries, keys, values, valid_lens) # output_concat的形状:(batch_size,查询的个数,num_hiddens) output_concat = transpose_output(output, self.num_heads) return self.W_o(output_concat) #@save def transpose_qkv(X, num_heads): """为了多注意力头的并行计算而变换形状""" # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens) # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads, # num_hiddens/num_heads) X = X.reshape(X.shape[0], X.shape[1], num_heads, -1) # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数, # num_hiddens/num_heads) X = X.permute(0, 2, 1, 3) # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数, # num_hiddens/num_heads) return X.reshape(-1, X.shape[2], X.shape[3]) #@save def transpose_output(X, num_heads): """逆转transpose_qkv函数的操作""" X = X.reshape(-1, num_heads, X.shape[1], X.shape[2]) X = X.permute(0, 2, 1, 3) return X.reshape(X.shape[0], X.shape[1], -1) num_hiddens, num_heads = 100, 5 attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5) attention.eval() batch_size, num_queries = 2, 4 num_kvpairs, valid_lens = 6, torch.tensor([3, 2]) X = torch.ones((batch_size, num_queries, num_hiddens)) Y = torch.ones((batch_size, num_kvpairs, num_hiddens)) attention(X, Y, Y, valid_lens).shape
6.7.2 Transformer
import math import pandas as pd import torch from torch import nn from d2l import torch as d2l(1)基于位置的前馈网络
#@save class PositionWiseFFN(nn.Module): """基于位置的前馈网络""" def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs): super(PositionWiseFFN, self).__init__(**kwargs) self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens) self.relu = nn.ReLU() self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs) def forward(self, X): return self.dense2(self.relu(self.dense1(X))) ffn = PositionWiseFFN(4, 4, 8) ffn.eval() ffn(torch.ones((2, 3, 4)))[0](2)残差连接和层归一化
ln = nn.LayerNorm(2) bn = nn.BatchNorm1d(2) X = torch.tensor([[1, 2], [2, 3]], dtype=torch.float32) # 在训练模式下计算X的均值和方差 print('layer norm:', ln(X), '\nbatch norm:', bn(X)) #@save class AddNorm(nn.Module): """残差连接后进行层规范化""" def __init__(self, normalized_shape, dropout, **kwargs): super(AddNorm, self).__init__(**kwargs) self.dropout = nn.Dropout(dropout) self.ln = nn.LayerNorm(normalized_shape) def forward(self, X, Y): return self.ln(self.dropout(Y) + X) add_norm = AddNorm([3, 4], 0.5) add_norm.eval() add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4))).shape(3)编码器
#@save class EncoderBlock(nn.Module): """Transformer编码器块""" def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias=False, **kwargs): super(EncoderBlock, self).__init__(**kwargs) self.attention = d2l.MultiHeadAttention( key_size, query_size, value_size, num_hiddens, num_heads, dropout, use_bias) self.addnorm1 = AddNorm(norm_shape, dropout) self.ffn = PositionWiseFFN( ffn_num_input, ffn_num_hiddens, num_hiddens) self.addnorm2 = AddNorm(norm_shape, dropout) def forward(self, X, valid_lens): Y = self.addnorm1(X, self.attention(X, X, X, valid_lens)) return self.addnorm2(Y, self.ffn(Y)) X = torch.ones((2, 100, 24)) valid_lens = torch.tensor([3, 2]) encoder_blk = EncoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5) encoder_blk.eval() encoder_blk(X, valid_lens).shape #@save class TransformerEncoder(d2l.Encoder): """Transformer编码器""" def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, use_bias=False, **kwargs): super(TransformerEncoder, self).__init__(**kwargs) self.num_hiddens = num_hiddens self.embedding = nn.Embedding(vocab_size, num_hiddens) self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout) self.blks = nn.Sequential() for i in range(num_layers): self.blks.add_module("block"+str(i), EncoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias)) def forward(self, X, valid_lens, *args): # 因为位置编码值在-1和1之间, # 因此嵌入值乘以嵌入维度的平方根进行缩放, # 然后再与位置编码相加。 X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens)) self.attention_weights = [None] * len(self.blks) for i, blk in enumerate(self.blks): X = blk(X, valid_lens) self.attention_weights[ i] = blk.attention.attention.attention_weights return X encoder = TransformerEncoder( 200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5) encoder.eval() encoder(torch.ones((2, 100), dtype=torch.long), valid_lens).shape(4)解码器
class DecoderBlock(nn.Module): """解码器中第i个块""" def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, i, **kwargs): super(DecoderBlock, self).__init__(**kwargs) self.i = i self.attention1 = d2l.MultiHeadAttention( key_size, query_size, value_size, num_hiddens, num_heads, dropout) self.addnorm1 = AddNorm(norm_shape, dropout) self.attention2 = d2l.MultiHeadAttention( key_size, query_size, value_size, num_hiddens, num_heads, dropout) self.addnorm2 = AddNorm(norm_shape, dropout) self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens) self.addnorm3 = AddNorm(norm_shape, dropout) def forward(self, X, state): enc_outputs, enc_valid_lens = state[0], state[1] # 训练阶段,输出序列的所有词元都在同一时间处理, # 因此state[2][self.i]初始化为None。 # 预测阶段,输出序列是通过词元一个接着一个解码的, # 因此state[2][self.i]包含着直到当前时间步第i个块解码的输出表示 if state[2][self.i] is None: key_values = X else: key_values = torch.cat((state[2][self.i], X), axis=1) state[2][self.i] = key_values if self.training: batch_size, num_steps, _ = X.shape # dec_valid_lens的开头:(batch_size,num_steps), # 其中每一行是[1,2,...,num_steps] dec_valid_lens = torch.arange( 1, num_steps + 1, device=X.device).repeat(batch_size, 1) else: dec_valid_lens = None # 自注意力 X2 = self.attention1(X, key_values, key_values, dec_valid_lens) Y = self.addnorm1(X, X2) # 编码器-解码器注意力。 # enc_outputs的开头:(batch_size,num_steps,num_hiddens) Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens) Z = self.addnorm2(Y, Y2) return self.addnorm3(Z, self.ffn(Z)), state decoder_blk = DecoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5, 0) decoder_blk.eval() X = torch.ones((2, 100, 24)) state = [encoder_blk(X, valid_lens), valid_lens, [None]] decoder_blk(X, state)[0].shape class TransformerDecoder(d2l.AttentionDecoder): def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, **kwargs): super(TransformerDecoder, self).__init__(**kwargs) self.num_hiddens = num_hiddens self.num_layers = num_layers self.embedding = nn.Embedding(vocab_size, num_hiddens) self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout) self.blks = nn.Sequential() for i in range(num_layers): self.blks.add_module("block"+str(i), DecoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, i)) self.dense = nn.Linear(num_hiddens, vocab_size) def init_state(self, enc_outputs, enc_valid_lens, *args): return [enc_outputs, enc_valid_lens, [None] * self.num_layers] def forward(self, X, state): X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens)) self._attention_weights = [[None] * len(self.blks) for _ in range (2)] for i, blk in enumerate(self.blks): X, state = blk(X, state) # 解码器自注意力权重 self._attention_weights[0][ i] = blk.attention1.attention.attention_weights # “编码器-解码器”自注意力权重 self._attention_weights[1][ i] = blk.attention2.attention.attention_weights return self.dense(X), state @property def attention_weights(self): return self._attention_weights(5)训练
num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10 lr, num_epochs, device = 0.005, 200, d2l.try_gpu() ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4 key_size, query_size, value_size = 32, 32, 32 norm_shape = [32] train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps) encoder = TransformerEncoder( len(src_vocab), key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout) decoder = TransformerDecoder( len(tgt_vocab), key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout) net = d2l.EncoderDecoder(encoder, decoder) d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device) engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .'] fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .'] for eng, fra in zip(engs, fras): translation, dec_attention_weight_seq = d2l.predict_seq2seq( net, eng, src_vocab, tgt_vocab, num_steps, device, True) print(f'{eng} => {translation}, ', f'bleu {d2l.bleu(translation, fra, k=2):.3f}') enc_attention_weights = torch.cat(net.encoder.attention_weights, 0).reshape((num_layers, num_heads, -1, num_steps)) enc_attention_weights.shape enc_attention_weights = torch.cat(net.encoder.attention_weights, 0).reshape((num_layers, num_heads, -1, num_steps)) enc_attention_weights.shape dec_attention_weights_2d = [head[0].tolist() for step in dec_attention_weight_seq for attn in step for blk in attn for head in blk] dec_attention_weights_filled = torch.tensor( pd.DataFrame(dec_attention_weights_2d).fillna(0.0).values) dec_attention_weights = dec_attention_weights_filled.reshape((-1, 2, num_layers, num_heads, num_steps)) dec_self_attention_weights, dec_inter_attention_weights = \ dec_attention_weights.permute(1, 2, 3, 0, 4) dec_self_attention_weights.shape, dec_inter_attention_weights.shape # Plusonetoincludethebeginning-of-sequencetoken d2l.show_heatmaps( dec_self_attention_weights[:, :, :, :len(translation.split()) + 1], xlabel='Key positions', ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)], figsize=(7, 3.5)) d2l.show_heatmaps( dec_inter_attention_weights, xlabel='Key positions', ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)], figsize=(7, 3.5))
更多推荐












所有评论(0)