Transformer核心架构详解:自注意力与多头注意力 | 吴恩达2025最新课程笔记

本文深度剖析Transformer的核心机制——自注意力(Self-Attention)和多头注意力(Multi-Head Attention)。通过数学推导、可视化图表和PyTorch代码实现,详细讲解QKV矩阵计算、注意力分数、缩放点积注意力等关键技术。涵盖Transformer Block完整结构、残差连接、层归一化等工程实践要点,是理解现代大语言模型架构的必读教程。

一、Transformer架构全景

1.1 Transformer诞生的背景

2017年Google发表的《Attention Is All You Need》彻底改变了NLP领域。其核心创新:完全抛弃RNN/CNN,只用注意力机制

Transformer革命

传统方法

2017年突破

RNN/LSTM

串行计算
梯度消失
难以并行

Self-Attention

并行计算
长距离依赖
可扩展性强

1.2 完整架构图

解码器 N×

编码器 N×

输入嵌入

位置编码

多头注意力

Add & Norm

前馈网络

Add & Norm

输出嵌入

位置编码

Masked
多头注意力

Add & Norm

交叉注意力

Add & Norm

前馈网络

Add & Norm

线性层

Softmax

输出概率

1.3 三大核心组件

组件 作用 关键技术
Self-Attention 捕捉序列内部依赖 QKV矩阵、缩放点积
Multi-Head Attention 多视角信息提取 多个注意力头并行
Feed-Forward Network 非线性变换 两层全连接+激活

二、自注意力机制(Self-Attention)详解

2.1 核心思想

Self-Attention的本质:让序列中的每个元素都能"看到"其他所有元素,并决定关注程度。

输入句子

0.2

0.5

0.3

自注意力计算

我(I)

爱(love)

AI

增强表示:
爱' = 0.2×我 + 0.5×爱 + 0.3×AI

2.2 QKV三剑客:Query、Key、Value

核心概念:

  • Query (Q): “我要找什么?”(查询向量)
  • Key (K): “我是什么?”(键向量)
  • Value (V): “我有什么信息?”(值向量)

类比理解:就像图书馆检索系统

图书内容 图书索引 你的查询 图书内容 图书索引 你的查询 《深度学习》: 0.9 《数据结构》: 0.2 《Python》: 0.6 Query: "机器学习" 计算相关度(Key匹配) 按权重提取Value 返回加权信息

2.3 数学公式推导

Step 1: 生成QKV矩阵

对于输入序列 X∈Rn×dX \in \mathbb{R}^{n \times d}XRn×d (n个词,每个d维):

Q=XWQ,WQ∈Rd×dkK=XWK,WK∈Rd×dkV=XWV,WV∈Rd×dv \begin{align} Q &= XW^Q, \quad W^Q \in \mathbb{R}^{d \times d_k} \\ K &= XW^K, \quad W^K \in \mathbb{R}^{d \times d_k} \\ V &= XW^V, \quad W^V \in \mathbb{R}^{d \times d_v} \end{align} QKV=XWQ,WQRd×dk=XWK,WKRd×dk=XWV,WVRd×dv

输入矩阵 X
[n × d]

权重 W^Q
[d × d_k]

权重 W^K
[d × d_k]

权重 W^V
[d × d_v]

Query
[n × d_k]

Key
[n × d_k]

Value
[n × d_v]

Step 2: 计算注意力分数

Score=QKTdk \text{Score} = \frac{QK^T}{\sqrt{d_k}} Score=dk QKT

为什么要除以 dk\sqrt{d_k}dk ?

问题: 维度d_k增大时
点积值会急剧增长

导致Softmax梯度消失

解决: 缩放因子 √d_k
保持数值稳定

Step 3: Softmax归一化

Attention Weights=softmax(Score)=softmax(QKTdk) \text{Attention Weights} = \text{softmax}(\text{Score}) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) Attention Weights=softmax(Score)=softmax(dk QKT)

Step 4: 加权求和

Output=Attention Weights×V \text{Output} = \text{Attention Weights} \times V Output=Attention Weights×V

完整公式:

Attention(Q,K,V)=softmax(QKTdk)V \boxed{\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V} Attention(Q,K,V)=softmax(dk QKT)V

2.4 直观示例:计算"爱"的新表示

输入句子:“我 爱 学习 AI”

Step 4: 加权求和

Step 2-3: 计算注意力

Step 1: 生成QKV

输入:
我 爱 学习 AI

Q矩阵
(查询)

K矩阵
(键)

V矩阵
(值)

QK^T / √d_k

Softmax归一化

权重:
我:0.1 爱:0.2
学习:0.4 AI:0.3

新表示:
爱' = Σ(权重×值)

假设计算"爱"对其他词的注意力:

Query·Key Score Softmax 最终贡献
2.1 2.1/√64=0.26 0.1 0.1×V(我)
3.5 0.44 0.2 0.2×V(爱)
学习 7.2 0.90 0.4 0.4×V(学习)
AI 5.8 0.73 0.3 0.3×V(AI)

最终:"爱"的新表示 = 0.1×V(我) + 0.2×V(爱) + 0.4×V(学习) + 0.3×V(AI)


三、多头注意力(Multi-Head Attention)

3.1 为什么需要多头?

单头的局限:只能捕捉一种关系模式

多头注意力

我爱学习AI

Head1
语义关系

Head2
语法关系

Head3
位置关系

Head4
...

单头注意力

我爱学习AI

只关注
语义相关性

3.2 多头注意力架构

输入 X

线性变换并分割

Head 1
Attention(Q1,K1,V1)

Head 2
Attention(Q2,K2,V2)

Head 3
Attention(Q3,K3,V3)

Head h
Attention(Qh,Kh,Vh)

拼接 Concat

线性变换 W^O

输出

3.3 数学公式

对于 hhh 个注意力头:

headi=Attention(QWiQ,KWiK,VWiV)MultiHead(Q,K,V)=Concat(head1,...,headh)WO \begin{align} \text{head}_i &= \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) \\ \text{MultiHead}(Q,K,V) &= \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O \end{align} headiMultiHead(Q,K,V)=Attention(QWiQ,KWiK,VWiV)=Concat(head1,...,headh)WO

其中:

  • WiQ∈Rdmodel×dkW_i^Q \in \mathbb{R}^{d_{model} \times d_k}WiQRdmodel×dk
  • WiK∈Rdmodel×dkW_i^K \in \mathbb{R}^{d_{model} \times d_k}WiKRdmodel×dk
  • WiV∈Rdmodel×dvW_i^V \in \mathbb{R}^{d_{model} \times d_v}WiVRdmodel×dv
  • WO∈Rhdv×dmodelW^O \in \mathbb{R}^{hd_v \times d_{model}}WORhdv×dmodel

典型配置(如BERT):

  • dmodel=768d_{model} = 768dmodel=768 (模型维度)
  • h=12h = 12h=12 (注意力头数)
  • dk=dv=dmodel/h=64d_k = d_v = d_{model}/h = 64dk=dv=dmodel/h=64 (每个头的维度)

3.4 多头的优势

不同注意力头的关注点

句子: The animal didn't cross the street because it was too tired

it 指代什么?

Head 1
it → animal
(主语指代)

Head 2
it → street
(位置关系)

Head 3
didn't → tired
(因果关系)

综合判断:
it = animal


四、PyTorch完整实现

4.1 Scaled Dot-Product Attention

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    参数:
        Q: [batch_size, n_heads, seq_len, d_k]
        K: [batch_size, n_heads, seq_len, d_k]
        V: [batch_size, n_heads, seq_len, d_v]
        mask: [batch_size, 1, 1, seq_len] 可选
    返回:
        output: [batch_size, n_heads, seq_len, d_v]
        attention_weights: [batch_size, n_heads, seq_len, seq_len]
    """
    d_k = Q.size(-1)
    
    # 1. 计算注意力分数: QK^T / √d_k
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # scores: [batch_size, n_heads, seq_len, seq_len]
    
    # 2. 可选:应用mask(用于Decoder中的自回归)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    # 3. Softmax归一化
    attention_weights = F.softmax(scores, dim=-1)
    
    # 4. 加权求和
    output = torch.matmul(attention_weights, V)
    
    return output, attention_weights


# 测试
batch_size, n_heads, seq_len, d_k = 2, 8, 10, 64
Q = torch.randn(batch_size, n_heads, seq_len, d_k)
K = torch.randn(batch_size, n_heads, seq_len, d_k)
V = torch.randn(batch_size, n_heads, seq_len, d_k)

output, attn_weights = scaled_dot_product_attention(Q, K, V)
print(f"输出形状: {output.shape}")  # torch.Size([2, 8, 10, 64])
print(f"注意力权重: {attn_weights.shape}")  # torch.Size([2, 8, 10, 10])

4.2 Multi-Head Attention完整实现

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        """
        参数:
            d_model: 模型维度(如768)
            n_heads: 注意力头数(如12)
        """
        super().__init__()
        assert d_model % n_heads == 0, "d_model必须能被n_heads整除"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # 每个头的维度
        
        # QKV的线性变换矩阵
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        
        # 输出的线性变换
        self.W_O = nn.Linear(d_model, d_model)
        
    def split_heads(self, x):
        """
        将输入分割成多个头
        x: [batch_size, seq_len, d_model]
        返回: [batch_size, n_heads, seq_len, d_k]
        """
        batch_size, seq_len, d_model = x.size()
        return x.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
    
    def forward(self, Q, K, V, mask=None):
        """
        参数:
            Q, K, V: [batch_size, seq_len, d_model]
            mask: [batch_size, 1, 1, seq_len]
        """
        batch_size = Q.size(0)
        
        # 1. 线性变换
        Q = self.W_Q(Q)  # [batch_size, seq_len, d_model]
        K = self.W_K(K)
        V = self.W_V(V)
        
        # 2. 分割成多个头
        Q = self.split_heads(Q)  # [batch_size, n_heads, seq_len, d_k]
        K = self.split_heads(K)
        V = self.split_heads(V)
        
        # 3. 缩放点积注意力
        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
        # attn_output: [batch_size, n_heads, seq_len, d_k]
        
        # 4. 合并多个头
        attn_output = attn_output.transpose(1, 2).contiguous()
        # [batch_size, seq_len, n_heads, d_k]
        attn_output = attn_output.view(batch_size, -1, self.d_model)
        # [batch_size, seq_len, d_model]
        
        # 5. 最终线性变换
        output = self.W_O(attn_output)
        
        return output, attn_weights


# 使用示例
d_model = 512
n_heads = 8
batch_size = 2
seq_len = 10

mha = MultiHeadAttention(d_model, n_heads)

# 输入
x = torch.randn(batch_size, seq_len, d_model)

# 自注意力:Q=K=V
output, attn_weights = mha(x, x, x)

print(f"输入形状: {x.shape}")              # torch.Size([2, 10, 512])
print(f"输出形状: {output.shape}")         # torch.Size([2, 10, 512])
print(f"注意力权重: {attn_weights.shape}") # torch.Size([2, 8, 10, 10])

4.3 可视化注意力权重

import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention(attn_weights, tokens, head_idx=0):
    """
    可视化某个注意力头的权重
    attn_weights: [batch_size, n_heads, seq_len, seq_len]
    tokens: 词列表
    head_idx: 要可视化的头索引
    """
    # 提取第一个样本的指定头
    attn = attn_weights[0, head_idx].detach().cpu().numpy()
    
    plt.figure(figsize=(10, 8))
    sns.heatmap(attn, 
                xticklabels=tokens,
                yticklabels=tokens,
                cmap='YlOrRd',
                annot=True,
                fmt='.2f',
                cbar_kws={'label': '注意力权重'})
    plt.title(f'Attention Head {head_idx}')
    plt.xlabel('Key')
    plt.ylabel('Query')
    plt.tight_layout()
    plt.show()


# 示例:可视化英译中的注意力
tokens = ['I', 'love', 'learning', 'AI', '<EOS>']
seq_len = len(tokens)

# 模拟注意力权重
mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(1, seq_len, 512)
_, attn_weights = mha(x, x, x)

# 可视化第0个头
visualize_attention(attn_weights, tokens, head_idx=0)

输出效果:

注意力矩阵 (Head 0)
         I    love  learning  AI   <EOS>
I      0.20  0.15    0.10   0.35  0.20
love   0.10  0.50    0.30   0.05  0.05
learning 0.05 0.30    0.40   0.20  0.05
AI     0.15  0.10    0.25   0.45  0.05
<EOS>  0.05  0.05    0.05   0.10  0.75

五、Transformer Block完整结构

5.1 单个Block的组成

残差连接

残差连接

输入 X

多头注意力

Add

Layer Norm

前馈网络
2层全连接

Add

Layer Norm

输出 Y

5.2 残差连接(Residual Connection)

为什么需要?

ResNet思想

深层网络问题

引入残差

梯度消失

难以训练

Y = X + F(X)

梯度直通
易于优化

数学表示:
Output=LayerNorm(X+MultiHeadAttention(X)) \text{Output} = \text{LayerNorm}(X + \text{MultiHeadAttention}(X)) Output=LayerNorm(X+MultiHeadAttention(X))

5.3 Layer Normalization

与Batch Norm的区别:

特性 Batch Norm Layer Norm
归一化维度 跨batch维度 跨特征维度
适用场景 CNN(固定batch) NLP(变长序列)
依赖性 依赖batch大小 独立于batch

Layer Normalization

样本1
[d1,d2,d3]

对每个样本
跨特征归一化

Batch Normalization

样本1
[d1,d2,d3]

对每个特征
跨样本归一化

样本2
[d1,d2,d3]

样本3
[d1,d2,d3]

Layer Norm公式:
LayerNorm(x)=γx−μσ2+ϵ+β \text{LayerNorm}(x) = \gamma \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta LayerNorm(x)=γσ2+ϵ xμ+β

其中 μ,σ\mu, \sigmaμ,σ 是当前层所有特征的均值和标准差。

5.4 前馈网络(Feed-Forward Network)

结构:两层全连接+激活函数

FFN(x)=ReLU(xW1+b1)W2+b2 \text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2 FFN(x)=ReLU(xW1+b1)W2+b2

输入
[seq_len, d_model]

全连接1
d_model → d_ff

ReLU激活

全连接2
d_ff → d_model

输出
[seq_len, d_model]

典型配置:

  • BERT: dmodel=768d_{model}=768dmodel=768, dff=3072d_{ff}=3072dff=3072 (4倍)
  • GPT-3: dmodel=12288d_{model}=12288dmodel=12288, dff=49152d_{ff}=49152dff=49152 (4倍)

5.5 完整Transformer Block实现

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        """
        参数:
            d_model: 模型维度
            n_heads: 注意力头数
            d_ff: 前馈网络隐藏层维度
            dropout: Dropout比率
        """
        super().__init__()
        
        # 多头注意力
        self.mha = MultiHeadAttention(d_model, n_heads)
        
        # 前馈网络
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        
        # Layer Normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        # Dropout
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        """
        x: [batch_size, seq_len, d_model]
        """
        # 1. 多头自注意力 + 残差连接 + LayerNorm
        attn_output, _ = self.mha(x, x, x, mask)
        attn_output = self.dropout1(attn_output)
        x = self.norm1(x + attn_output)  # 残差连接
        
        # 2. 前馈网络 + 残差连接 + LayerNorm
        ffn_output = self.ffn(x)
        ffn_output = self.dropout2(ffn_output)
        x = self.norm2(x + ffn_output)   # 残差连接
        
        return x


# 测试
d_model = 512
n_heads = 8
d_ff = 2048
block = TransformerBlock(d_model, n_heads, d_ff)

x = torch.randn(2, 10, d_model)  # [batch, seq_len, d_model]
output = block(x)

print(f"输入形状: {x.shape}")      # torch.Size([2, 10, 512])
print(f"输出形状: {output.shape}")  # torch.Size([2, 10, 512])

六、位置编码(Positional Encoding)

6.1 为什么需要位置编码?

问题:自注意力是置换不变的(permutation-invariant)

sentence1 = "我 爱 AI"
sentence2 = "AI 爱 我"

# 如果没有位置编码,Self-Attention会给出相同的结果!

问题: 自注意力无法区分词序

解决: 添加位置信息

方案1: 学习位置嵌入
(BERT)

方案2: 固定位置编码
(原始Transformer)

6.2 正弦位置编码

公式:
PE(pos,2i)=sin⁡(pos100002i/dmodel)PE(pos,2i+1)=cos⁡(pos100002i/dmodel) \begin{align} PE_{(pos, 2i)} &= \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) \\ PE_{(pos, 2i+1)} &= \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) \end{align} PE(pos,2i)PE(pos,2i+1)=sin(100002i/dmodelpos)=cos(100002i/dmodelpos)

其中:

  • pospospos: 词的位置(0, 1, 2, …)
  • iii: 维度索引(0, 1, …, d_model/2)

优点:

  1. ✅ 可以处理任意长度的序列
  2. ✅ 不需要训练参数
  3. ✅ 相对位置关系可以通过线性变换表达

6.3 实现代码

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        
        # 创建位置编码矩阵
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        
        # 计算分母
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                            (-math.log(10000.0) / d_model))
        
        # 计算正弦和余弦
        pe[:, 0::2] = torch.sin(position * div_term)  # 偶数维度
        pe[:, 1::2] = torch.cos(position * div_term)  # 奇数维度
        
        # 添加batch维度
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        """
        x: [batch_size, seq_len, d_model]
        """
        # 添加位置编码
        x = x + self.pe[:, :x.size(1), :]
        return x


# 可视化位置编码
def visualize_positional_encoding(d_model=128, max_len=100):
    pe = PositionalEncoding(d_model, max_len)
    encoding = pe.pe[0, :max_len, :].numpy()
    
    plt.figure(figsize=(15, 5))
    plt.imshow(encoding.T, aspect='auto', cmap='RdBu', 
               interpolation='nearest')
    plt.colorbar(label='编码值')
    plt.xlabel('位置')
    plt.ylabel('维度')
    plt.title('正弦位置编码可视化')
    plt.tight_layout()
    plt.show()

visualize_positional_encoding()

输出效果:会看到规律的波浪状图案,不同频率编码不同维度。


七、完整Encoder实现

class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, d_ff, 
                 n_layers, max_len=5000, dropout=0.1):
        """
        参数:
            vocab_size: 词汇表大小
            d_model: 模型维度
            n_heads: 注意力头数
            d_ff: 前馈网络维度
            n_layers: Transformer Block层数
            max_len: 最大序列长度
            dropout: Dropout比率
        """
        super().__init__()
        
        # 词嵌入
        self.embedding = nn.Embedding(vocab_size, d_model)
        
        # 位置编码
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        
        # 多层Transformer Block
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        """
        x: [batch_size, seq_len] (token indices)
        """
        # 1. 词嵌入 + 位置编码
        x = self.embedding(x) * math.sqrt(self.embedding.embedding_dim)
        x = self.pos_encoding(x)
        x = self.dropout(x)
        
        # 2. 通过多层Transformer Block
        for layer in self.layers:
            x = layer(x, mask)
        
        return x


# 使用示例:构建一个小型BERT
vocab_size = 30000
d_model = 768
n_heads = 12
d_ff = 3072
n_layers = 12

encoder = TransformerEncoder(vocab_size, d_model, n_heads, d_ff, n_layers)

# 输入token indices
input_ids = torch.randint(0, vocab_size, (2, 20))  # [batch=2, seq_len=20]
output = encoder(input_ids)

print(f"输入形状: {input_ids.shape}")  # torch.Size([2, 20])
print(f"输出形状: {output.shape}")     # torch.Size([2, 20, 768])
print(f"参数量: {sum(p.numel() for p in encoder.parameters())/1e6:.1f}M")
# 输出: 参数量: 110.1M (接近BERT-Base的110M)

八、关键概念总结

8.1 公式总结

组件 公式 说明
Self-Attention Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q,K,V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})VAttention(Q,K,V)=softmax(dk QKT)V 缩放点积注意力
Multi-Head Concat(head1,...,headh)WO\text{Concat}(\text{head}_1,...,\text{head}_h)W^OConcat(head1,...,headh)WO 多视角融合
FFN ReLU(xW1+b1)W2+b2\text{ReLU}(xW_1+b_1)W_2+b_2ReLU(xW1+b1)W2+b2 两层全连接
Layer Norm γx−μσ2+ϵ+β\gamma\frac{x-\mu}{\sqrt{\sigma^2+\epsilon}}+\betaγσ2+ϵ xμ+β 特征归一化
Positional sin⁡(pos100002i/d)\sin(\frac{pos}{10000^{2i/d}})sin(100002i/dpos) 位置编码

8.2 配置对比

模型 d_model n_heads n_layers d_ff 参数量
BERT-Base 768 12 12 3072 110M
BERT-Large 1024 16 24 4096 340M
GPT-2 768 12 12 3072 117M
GPT-3 12288 96 96 49152 175B

8.3 架构流程图

每个Block内部

多头注意力

Add & Norm

前馈网络

Add & Norm

输入Token IDs

词嵌入层

位置编码

Transformer Block 1

Transformer Block 2

... Block N

输出表示


九、实战练习

练习1:计算注意力权重

题目:给定QKV矩阵,手工计算注意力输出

# 已知(简化为2x2矩阵方便计算)
Q = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0]])

K = torch.tensor([[1.0, 0.0],
                  [0.5, 0.5]])

V = torch.tensor([[2.0, 0.0],
                  [1.0, 1.0]])

# 步骤:
# 1. 计算 QK^T / √d_k
# 2. Softmax
# 3. 乘以V

# 你的答案:

练习2:实现Masked Self-Attention

任务:修改Self-Attention,实现Decoder中的mask机制(当前词不能看到未来词)

def masked_self_attention(Q, K, V):
    """
    TODO: 实现masked attention
    提示: 使用torch.tril创建下三角mask
    """
    pass

练习3:分析注意力头

任务:加载预训练BERT,可视化不同注意力头关注的模式

from transformers import BertModel, BertTokenizer

model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

text = "The animal didn't cross the street because it was too tired."
inputs = tokenizer(text, return_tensors='pt')

with torch.no_grad():
    outputs = model(**inputs)
    attentions = outputs.attentions  # 12层,每层12个头


Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐