第6章 自注意力机制--现代AI基石
本文深入解析了自注意力机制这一现代AI基石技术。文章首先介绍了自注意力机制处理的三种典型序列任务类型:1)输入输出数量相同的任务(如词性标注);2)序列到标签的任务(如情感分析);3)序列到序列的任务(如机器翻译)。随后详细阐述了自注意力机制的工作原理,包括Query-Key-Value三元组的生成、注意力分数计算和加权求和过程,并通过"编辑部审稿"的生动比喻帮助理解。文章还重
目录
各位读者,欢迎来到我们技术解析的新篇章!在之前的章节中,我们了解了深度学习的基础构件。今天,我们将一起揭开一个被誉为“现代AI基石”的技术——自注意力机制 的神秘面纱。
正是这个机制,驱动了像ChatGPT、BERT这样的革命性模型,让机器在理解语言、图像甚至蛋白质结构方面取得了前所未有的突破。它解决了传统模型在处理长距离依赖关系时的痛点。别担心它听起来很复杂,本章的目标就是让它变得像“拼乐高”一样清晰明了。
6.1 输入是向量序列的情况
在深入自注意力本身之前,我们必须先理解它处理的数据形式。在自然语言处理(NLP)和许多其他任务中,我们面对的往往不是单个的、孤立的数据点,而是一个序列。
什么是序列?
一个序列就是一组按特定顺序排列的数据。最典型的例子就是一句话:
“我 爱 吃 苹果”
这句话中的四个词“我”、“爱”、“吃”、“苹果”就构成了一个序列。它们的顺序至关重要,“苹果吃我爱”就变得毫无意义。
如何将序列输入给模型?
计算机无法直接理解文字,所以我们需要先将每个词转换成一个数值向量,这个过程叫做词嵌入。我们可以把每个词向量想象成这个词的“身份证”,里面编码了它的语义信息。
于是,一句话就变成了一个向量序列:[向量“我”, 向量“爱”, 向量“吃”, 向量“苹果”]
现在,关键问题来了:我们如何处理这个向量序列?根据任务的不同,主要有以下三种经典的场景,理解了它们,你就能明白自注意力机制大展身手的舞台在哪里。
6.1.1 类型 1:输入与输出数量相同
在这种任务中,我们为序列中的每一个输入元素,都生成一个对应的输出。
典型任务:词性标注、命名实体识别
-
词性标注:给句子中的每个词打上词性标签(名词、动词、形容词等)。
-
输入:
[我, 爱, 吃, 苹果] -
输出:
[代词, 动词, 动词, 名词]
-
-
命名实体识别:识别句子中具有特定意义的实体(人名、地名、机构名等)。
-
输入:
[马云, 在, 杭州, 创建了, 阿里巴巴] -
输出:
[人名, 非实体, 地名, 非实体, 机构名]
-
模型是如何思考的?
模型在处理“苹果”这个词时,不能只看“苹果”本身。它需要结合上下文来判断:
-
在“我吃苹果”里,“苹果”是水果,是名词。
-
在“苹果手机很贵”里,“苹果”是公司,可能被标注为机构名。
因此,模型需要一种机制,让序列中的每个词都能“感知”到其他所有词的信息,从而做出最准确的判断。这就像是编辑部在审阅一篇文章,每个编辑(输出)在评判某个段落(输入词)时,都需要通读全文,了解上下文。
6.1.2 类型 2:输入是一个序列,输出是一个标签
这种任务可以看作是“整体判断”,我们将整个序列压缩、汇总,最终得出一个总的结论。
典型任务:文本分类、情感分析
-
情感分析:判断一段影评是积极的还是消极的。
-
输入:
[这, 部, 电影, 简直, 太, 精彩, 了](一个序列) -
输出:
“积极”(一个标签)
-
-
垃圾邮件识别:判断一封邮件是否是垃圾邮件。
-
输入:整封邮件内容的词序列。
-
输出:
“是”或“否”
-
模型是如何思考的?
模型需要从整个句子中捕捉到那个决定整体性质的“关键信号”。在上面的影评例子中,“精彩”这个词无疑是一个强烈的积极信号,但“简直”和“太”这两个词起到了加强语气的作用,使得“精彩”的程度更深。模型需要权衡所有词的重要性,并最终将所有信息“融合”成一个代表整体情感的判断。
这好比是评审团在看一部电影,他们看完整部影片(输入序列)后,经过内部讨论,最终举牌给出一个综合评分(输出标签)。
6.1.3 类型 3:序列到序列任务
这是最复杂也最有趣的一种情况,输入和输出都是序列,但它们的长度和内容可能完全不同。
典型任务:机器翻译、文本摘要
-
机器翻译:
-
输入(中文):
[我, 爱, 你] -
输出(英文):
[I, love, you]
-
-
文本摘要:
-
输入:一篇长文章(长序列)。
-
输出:一段简短摘要(短序列)。
-
模型是如何思考的?
这类任务通常需要一个编码器-解码器 架构。
-
编码器:负责“理解”输入序列。它像是一个读者,仔细阅读完整篇原文,并将其核心含义和关键信息提取并压缩成一个复杂的“上下文向量”(可以理解为文章的“思想精华”)。
-
解码器:负责“生成”输出序列。它像是一个作者,根据编码器提供的“思想精华”,一个词一个词地构造出目标语言的句子或摘要。
在这个过程中,无论是编码器理解原文,还是解码器根据“精华”生成新词,都需要强大的机制来管理不同词之间的关系。比如在翻译时,解码器生成“love”这个词时,必须高度关注输入序列中的“爱”,同时也要参考“我”和“你”来确保语法正确。
小结
无论输入和输出的形式如何变化,这些任务都面临一个共同的、根本性的挑战:如何有效地让序列中的元素彼此交互,捕捉它们之间复杂且长距离的依赖关系。 传统的循环神经网络(RNN)和卷积神经网络(CNN)在处理这个问题时各有短板,而这,正是自注意力机制横空出世,一举成为王者的原因。
6.2 自注意力机制的运作原理
现在,让我们进入最核心的部分。请放心,我们会用一个最生动的比喻,一步步拆解它。
编辑部审稿比喻
想象你是一家科技杂志的编辑,今天收到了一篇关于“人工智能”的文章。你的任务是审阅这篇文章,并对文中的每一个句子给出修改意见。
作为一个负责任的编辑,你不会孤立地看每个句子。你会:
-
通读全文,理解整篇文章在讲什么。
-
在评审某个特定句子时,你会思考:“这个句子和文章里的其他哪些句子关系最紧密?哪些句子能帮助我更好地理解它?”
自注意力机制所做的,就是把这个过程自动化、数学化。它让序列中的每个词(句子)都能“注意”到序列中所有其他的词(句子),从而获得一个更丰富的上下文表示。
三步拆解自注意力
假设我们的输入序列是:“Thinking Machines”。为了简化,我们已经有它们的词向量表示 x1 (Thinking) 和 x2 (Machines)。
第一步:创建三大角色——Query, Key, Value
对于输入序列中的每一个词向量,自注意力机制会生成三个新的向量:
-
Query(查询向量):可以理解为“当前词”提出的一个问题:“我是谁?我应该关注谁?”
-
Key(键向量):可以理解为“每个词”(包括自己)身上携带的一个“标签”或“身份证”,用来回应Query的查询。
-
Value(值向量):可以理解为“每个词”所代表的“真实信息”或“内涵”。
如何生成?很简单,通过矩阵乘法。我们有三组可学习的权重矩阵:WQ, WK, WV。
-
q1 = x1 * WQ(Thinking的Query) -
k1 = x1 * WK(Thinking的Key) -
v1 = x1 * WV(Thinking的Value) -
q2 = x2 * WQ(Machines的Query) -
k2 = x2 * WK... 以此类推。
第二步:计算注意力分数——谁和我最相关?
现在,我们聚焦于第一个词“Thinking”(q1)。我们想知道在整句话中,它应该与哪些词(包括自己)建立更强的联系。
我们用它的Query向量 q1 去和序列中所有词的Key向量(k1, k2)进行点积运算。点积的结果就是一个分数,它表示了“Thinking”与每个词之间的相关性强度。
-
分数1 =
q1 · k1(Thinking与自己的相关性) -
分数2 =
q1 · k2(Thinking与Machines的相关性)
第三步:加权求和,生成输出
-
标准化:将上一步得到的分数进行缩放(除以Key向量维度的平方根,为了梯度稳定)并通过Softmax函数,使得所有分数之和为1。这样,分数就变成了0到1之间的权重。权重越高,代表相关性越强。
-
假设Softmax后:权重1 = 0.6, 权重2 = 0.4
-
这意味着,在理解“Thinking”时,模型认为“Thinking”本身的信息占60%的重要性,“Machines”的信息占40%的重要性。
-
-
合成:将这些权重分别乘到对应词的Value向量上,然后加起来。
-
输出向量z1 = (权重1 * v1) + (权重2 * v2)
-
这个最终的 z1,就是“Thinking”这个词经过自注意力机制处理后的新表示!它不再是一个孤立的词向量,而是一个融入了全局上下文信息(特别是“Machines”的信息)的“增强版”词向量。
同理,我们再用 q2 去和所有 k 计算,就能得到“Machines”的新表示 z2。
为什么这个机制如此强大?
-
并行计算:所有词的Query, Key, Value都可以通过矩阵运算同时产生,所有输出
z也可以同时计算。这比RNN必须一步步计算快得多。 -
全局视野:每个词在计算新表示时,都直接“看到”了序列中的所有其他词,无论距离多远。它彻底解决了长距离依赖问题。在“The animal didn't cross the street because it was too tired”这个句子里,“it”可以轻松地直接关联到“animal”。
-
动态权重:注意力权重是动态计算的。在不同的上下文里,同一个词关注的重点会不同。比如“苹果”在“吃苹果”和“苹果公司”中,会与不同的词建立强连接。
至此,你已经理解了自注意力最核心的思想。它就像一个精妙的信息调配系统,让序列中的每个元素都能主动地、有选择地从同伴那里吸收养分,最终全面提升整个序列的表示质量。
6.3 多头自注意力
如果标准自注意力已经如此强大,为什么还需要“多头”呢?我们继续用编辑部的比喻。
一个顶尖的编辑部不会只有一位编辑。他们会组建一个专家小组:
-
语法专家:专注于句子的结构、语法是否正确。
-
逻辑专家:专注于段落之间的逻辑连贯性。
-
事实核查专家:专注于文中所提及的数据和事实是否准确。
当这篇文章同时经过这个专家小组的审阅时,每个专家都会从自己独特的视角提出修改意见。最终,主编会把这些不同角度的意见综合起来,形成一份最全面、最深刻的最终审稿报告。
多头自注意力做的就是这件事!
它不满足于只学习一种角度的关联关系。它通过多组不同的 WQ, WK, WV 矩阵,将模型投影到多个不同的“表示子空间”,让模型同时关注来自不同位置的不同信息。
具体流程:
-
并行多个“头”:假设我们设置8个头(
h=8)。那么对于输入x1,我们会得到:-
头1:
q1_1, k1_1, v1_1 -
头2:
q1_2, k1_2, v1_2 -
...
-
头8:
q1_8, k1_8, v1_8
每个头都有自己的权重矩阵,因此它们会独立地计算注意力,得到8个不同的输出向量z1_1, z1_2, ..., z1_8。
-
-
拼接与线性变换:将这8个输出向量拼接成一个很长的向量。
-
拼接后的向量 = Concat(z1_1, z1_2, ..., z1_8)
然后,再通过一个可学习的线性变换矩阵WO,将这个长向量映射到我们想要的输出维度。
-
-
得到最终输出:经过线性变换后,我们得到了“Thinking”这个词的最终多头自注意力输出。
这样做的优势是什么?
-
增强模型表达能力:不同的头可以学习到不同类型的依赖关系。例如,在处理一个句子时,有些头可能专门关注“指代关系”(如it指向谁),有些头可能专门关注“动词-宾语”关系,还有些头可能专门关注“并列结构”。
-
提供更丰富的表示:最终每个词的表示,是从多个语义子空间汇总而来的信息,比单一注意力头提供的表示更加细腻和强大。
可以说,多头自注意力是Transformer模型性能卓越的关键设计之一,它让模型拥有了像“八爪鱼”一样同时捕捉多种复杂模式的能力。
6.4 位置编码
细心的你可能已经发现了一个问题:自注意力机制是置换不变的。也就是说,它把输入序列看作一个集合,而不是一个序列。
举个例子:
-
序列 A: “我爱妈妈”
-
序列 B: “妈妈爱我”
对于自注意力机制来说,这两个序列的输入是完全一样的(都是“我”、“爱”、“妈妈”这三个词),如果不加任何处理,它计算出的输出表示也会是完全一样的!但这显然是两个意思完全不同的句子。
顺序,是序列的灵魂!
为了解决这个问题,Transformer的设计者引入了位置编码。它的任务非常简单:为每个词的位置信息进行编码,并把这个信息加入到最初的词向量中。
它是如何工作的?
-
对于序列中的每个位置(例如,第一个词,第二个词...),我们生成一个独一无二的、与词向量同维度的位置向量。
-
将这个位置向量与对应的词向量相加,然后再输入到自注意力层。
新输入 = 词嵌入向量 + 位置编码向量
这样一来,模型在初始阶段就能知道“我”在第一位,“爱”在第二位,“妈妈”在第三位。它不再是一个无序的集合,而是一个有序的序列了。
这个位置向量长什么样?
原始Transformer论文使用了一种非常巧妙的正弦和余弦函数来生成位置编码。它之所以被广泛使用,是因为它有两个很好的性质:
-
能够编码绝对位置:每个位置都有独一无二的编码。
-
能够自然地 extrapolate 到更长的序列:由于正弦余弦函数的周期性,模型学到的位置关系可以一定程度上推广到在训练时未见过的更长序列上。
当然,现在也有很多模型使用可学习的位置编码,即把位置编码也作为模型参数,让模型在训练中自己学会什么位置该是什么样子。
无论如何,位置编码就像一个“定位器”,为失去了顺序感知的自注意力机制装上了“GPS”,确保了词序信息不被丢失。
6.5 截断自注意力
自注意力机制有一个显著的缺点:计算复杂度随序列长度呈平方级增长。
计算注意力时,需要生成一个 N x N 的注意力分数矩阵(N是序列长度)。如果序列长度是100,需要计算10000个分数;如果长度是1000,就需要计算1000000个分数!这在对长文档、长视频进行建模时,会带来巨大的计算和内存开销。
截断自注意力就是一种应对策略。 其核心思想是:一个词不一定需要关注整个序列的所有词,只需要关注一个局部窗口内的词就够了。
这就像我们读一篇很长的论文时,理解某一个段落,主要看它的前后几段就够了,不需要时时刻刻把全文500页都放在眼前。
具体实现:
设定一个固定的窗口大小 k。对于序列中的每个词,它只与左右各 k/2 个词(或者只与前面的 k 个词,在解码器中)计算注意力。
-
优点:
-
计算复杂度从
O(N²)降为O(N * k),k是一个常数。 -
大大减少了内存占用。
-
-
缺点:
-
牺牲了全局视野。如果某个长距离依赖关系恰好落在窗口之外,模型就无法捕捉到它。
-
因此,截断自注意力是效率与效果之间的一种权衡,通常在处理极长序列(如长达数千字的文档)时会被采用。像Longformer、BigBird等模型就使用了类似的思想,并设计了更复杂的稀疏注意力模式来兼顾效率和全局信息。
6.6 对比自注意力与卷积神经网络
CNN是计算机视觉领域的王者,它的核心是卷积核。让我们来看看这两位“大神”的异同。
| 特性 | 卷积神经网络 | 自注意力机制 |
|---|---|---|
| 感受野 | 局部起步。一个小卷积核(如3x3)只看到图像的局部区域。要通过堆叠很多层,信息才能逐步传递,最终让深层神经元拥有“全局感受野”。 | 天生全局。在第一层,每个词就能直接和序列中所有其他词交互,拥有完整的全局感受野。 |
| 权重 | 静态/不变的。一个训练好的3x3卷积核,无论扫描到图像的哪个部位,其权重参数都是固定不变的。 | 动态/依输入而变。注意力权重完全取决于当前输入的Query和Key,对于不同的输入序列,权重是动态计算的。 |
| 关系处理 | 内容无关。卷积核只关心位置关系(比如左上、正中等),不关心那个位置上具体是什么内容。 | 内容高度相关。它关注的是“语义”上的相关性。两个词即使离得很远,只要语义相关,也能产生高权重。 |
| 数据假设 | 基于平移不变性和局部性先验。认为一个在图片左上角学到的模式,在右下角也适用。 | 几乎没有强假设。它更灵活,让数据自己说话,从数据中学习所有的模式。 |
一个生动的比喻:
-
CNN 像一个拿着固定模板(卷积核)的质检员,在流水线上用同一个标准去检查每一个局部区域的产品。
-
自注意力 像一个在开会讨论的团队,每个人发言时,都会根据所有与会者(所有词)的具体内容,动态地决定该听取和采纳谁的意见。
互补与融合:
正因如此,现在很多研究开始探索将两者结合。比如在视觉领域,Vision Transformer (ViT) 用纯Transformer处理图像,而也有一些工作会在CNN中引入注意力机制,让卷积也变得“动态”起来。
6.7 对比自注意力与循环神经网络
在Transformer出现之前,RNN及其变体LSTM、GRU是处理序列任务的默认选择。让我们来一场“新旧王者的对决”。
| 特性 | 循环神经网络 | 自注意力机制 |
|---|---|---|
| 计算方式 | 顺序/串行。必须一个一个词地处理。计算第t个词的状态 h_t,必须等待第t-1个词的状态 h_{t-1} 算完。难以并行,训练慢。 |
并行。所有词的Query, Key, Value和输出都可以同时计算。高度并行,充分利用GPU,训练极快。 |
| 信息流动 | 一步一步传递。信息像接力棒一样,从一个时间步传递到下一个。距离越远,信息传递路径越长,容易造成梯度消失/爆炸,导致模型难以学习长距离依赖。 | 直接连接。任意两个词之间都是一步直达,无论它们相隔多远。完美解决了长距离依赖问题。 |
| 内存占用 | 计算时需要存储中间状态,内存占用与序列长度呈线性关系。 | 需要存储巨大的 N x N 注意力矩阵,内存占用与序列长度呈平方关系。这是自注意力的主要瓶颈。 |
| 解释性 | 隐藏状态 h_t 是一个混合了所有历史信息的黑箱,很难说清它具体记住了什么。 |
注意力权重矩阵 提供了一个清晰的、可视化的解释工具。我们可以直接看到模型在做决策时,关注了输入序列的哪些部分。 |
一个生动的比喻:
-
RNN 像一个有严重健忘症的读者,他必须用手指着文字,一个字一个字地读。每读一个新字,他对前面文字的记忆就会模糊一点。读长篇文章时,他很可能忘了开头讲了什么。
-
自注意力 像一个拥有摄影式记忆的天才读者。他能够瞬间扫视全文,并将所有文字及其关系印在脑中。在回答关于文中任何一个词的问题时,他都能瞬间回忆起全文的所有细节。
总结与展望
自注意力机制,凭借其强大的全局信息捕获能力、高度的并行性以及卓越的性能,已经深刻地改变了深度学习,特别是自然语言处理领域的格局。它并非完美(如计算复杂度高),但其设计思想无疑是开创性的。
从本章开始,你已经掌握了理解现代大模型(如GPT、BERT)最核心的原理。在接下来的章节中,我们将以此为基础,一步步搭建起完整的Transformer模型,并探索它在各个领域的惊艳应用。
以下是一个完整的Python示例程序,涵盖了第6章自注意力机制的核心内容。这个程序使用PyTorch实现,包含了自注意力、多头自注意力、位置编码等关键组件,并提供了可视化展示。
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
class PositionalEncoding(nn.Module):
"""位置编码 - 使用正弦和余弦函数"""
def __init__(self, d_model, max_len=5000):
super(PositionalEncoding, self).__init__()
# 创建位置编码矩阵
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term) # 偶数位置使用正弦
pe[:, 1::2] = torch.cos(position * div_term) # 奇数位置使用余弦
pe = pe.unsqueeze(0).transpose(0, 1) # 形状: [max_len, 1, d_model]
self.register_buffer('pe', pe)
def forward(self, x):
# x: [seq_len, batch_size, d_model]
return x + self.pe[:x.size(0), :]
class ScaledDotProductAttention(nn.Module):
"""缩放点积注意力机制"""
def __init__(self, dropout=0.1):
super(ScaledDotProductAttention, self).__init__()
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
# query, key, value: [batch_size, seq_len, d_model]
d_k = query.size(-1) # 获取key的维度
# 计算注意力分数: Q * K^T / sqrt(d_k)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
# 应用mask(如果有)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# 应用softmax得到注意力权重
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
# 加权求和: 注意力权重 * Value
output = torch.matmul(attention_weights, value)
return output, attention_weights
class MultiHeadAttention(nn.Module):
"""多头自注意力机制"""
def __init__(self, d_model, num_heads, dropout=0.1):
super(MultiHeadAttention, self).__init__()
assert d_model % num_heads == 0, "d_model必须能被num_heads整除"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# 线性变换层
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
self.w_o = nn.Linear(d_model, d_model)
self.attention = ScaledDotProductAttention(dropout)
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
# 线性变换并分头
Q = self.w_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = self.w_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = self.w_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# 应用注意力机制
attn_output, attn_weights = self.attention(Q, K, V, mask)
# 合并多头
attn_output = attn_output.transpose(1, 2).contiguous().view(
batch_size, -1, self.d_model)
# 输出线性变换
output = self.w_o(attn_output)
return output, attn_weights
class SimpleTransformerBlock(nn.Module):
"""简化的Transformer块(仅包含多头注意力和前馈网络)"""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super(SimpleTransformerBlock, self).__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# 自注意力 + 残差连接 + 层归一化
attn_output, attn_weights = self.self_attn(x, x, x, mask)
x = self.norm1(x + self.dropout(attn_output))
# 前馈网络 + 残差连接 + 层归一化
ff_output = self.feed_forward(x)
x = self.norm2(x + self.dropout(ff_output))
return x, attn_weights
def visualize_attention(attention_weights, words, title="注意力权重可视化"):
"""可视化注意力权重"""
plt.figure(figsize=(10, 8))
# 取第一个batch的第一个头的注意力权重
attn = attention_weights[0, 0].detach().numpy()
sns.heatmap(attn,
xticklabels=words,
yticklabels=words,
cmap="YlOrRd",
annot=True,
fmt=".2f",
cbar_kws={'label': '注意力权重'})
plt.title(title)
plt.xlabel("Key")
plt.ylabel("Query")
plt.tight_layout()
plt.show()
def visualize_positional_encoding():
"""可视化位置编码"""
d_model = 64
max_len = 50
pe = PositionalEncoding(d_model, max_len)
# 获取位置编码
positional_encoding = pe.pe.squeeze(1)[:max_len, :]
plt.figure(figsize=(12, 8))
plt.imshow(positional_encoding, cmap='coolwarm', aspect='auto')
plt.colorbar(label='位置编码值')
plt.title('位置编码可视化 (正弦余弦编码)')
plt.xlabel('维度索引')
plt.ylabel('位置索引')
plt.show()
def demonstrate_attention_patterns():
"""演示不同的注意力模式"""
print("=== 注意力模式演示 ===")
# 示例1: 局部注意力(类似CNN)
print("\n1. 局部注意力模式 (类似CNN):")
local_attn = torch.tensor([
[0.8, 0.2, 0.0, 0.0],
[0.3, 0.4, 0.3, 0.0],
[0.0, 0.3, 0.4, 0.3],
[0.0, 0.0, 0.2, 0.8]
])
print("每个词主要关注相邻的词")
# 示例2: 全局注意力(类似自注意力)
print("\n2. 全局注意力模式 (类似自注意力):")
global_attn = torch.tensor([
[0.4, 0.3, 0.2, 0.1],
[0.1, 0.4, 0.3, 0.2],
[0.2, 0.1, 0.4, 0.3],
[0.3, 0.2, 0.1, 0.4]
])
print("每个词都关注序列中的所有词")
# 示例3: 特定模式注意力
print("\n3. 特定模式注意力 (如句法结构):")
syntax_attn = torch.tensor([
[0.6, 0.3, 0.1, 0.0], # 主语关注动词
[0.1, 0.4, 0.4, 0.1], # 动词关注主语和宾语
[0.0, 0.3, 0.6, 0.1], # 宾语关注动词
[0.0, 0.1, 0.2, 0.7] # 修饰词关注被修饰词
])
print("注意力反映句法关系")
def compare_rnn_vs_attention():
"""比较RNN和自注意力机制"""
print("\n=== RNN vs 自注意力机制对比 ===")
seq_len = 5
d_model = 8
batch_size = 2
# 创建示例输入
x = torch.randn(batch_size, seq_len, d_model)
# RNN处理(顺序处理)
print("\n1. RNN处理方式:")
rnn = nn.RNN(d_model, d_model, batch_first=True)
rnn_output, _ = rnn(x)
print(f"RNN输出形状: {rnn_output.shape}")
print("RNN需要顺序处理,无法并行计算")
# 自注意力处理(并行处理)
print("\n2. 自注意力处理方式:")
self_attn = MultiHeadAttention(d_model, num_heads=2)
attn_output, attn_weights = self_attn(x, x, x)
print(f"自注意力输出形状: {attn_output.shape}")
print("自注意力可以并行计算所有位置")
# 计算复杂度对比
print("\n3. 计算复杂度对比:")
print(f"RNN复杂度: O(seq_len × d_model²)")
print(f"自注意力复杂度: O(seq_len² × d_model)")
print(f"当序列长度 {seq_len} 较小时,两者都可接受")
print(f"当序列长度很大时,自注意力的平方复杂度成为瓶颈")
def main():
"""主函数:演示自注意力机制的完整流程"""
print("第6章:自注意力机制完整演示")
print("=" * 50)
# 设置随机种子以便重现结果
torch.manual_seed(42)
# 1. 创建示例数据
print("\n1. 创建示例输入序列...")
# 假设我们有一个包含4个词的序列,每个词用8维向量表示
batch_size = 1
seq_len = 4
d_model = 8
# 模拟词嵌入
word_embeddings = torch.randn(batch_size, seq_len, d_model)
words = ["我", "爱", "自然", "语言"]
print(f"输入序列: {words}")
print(f"输入张量形状: {word_embeddings.shape}")
# 2. 演示位置编码
print("\n2. 应用位置编码...")
pos_encoder = PositionalEncoding(d_model)
# 调整输入形状以适应位置编码 [seq_len, batch_size, d_model]
input_sequence = word_embeddings.transpose(0, 1)
encoded_sequence = pos_encoder(input_sequence)
encoded_sequence = encoded_sequence.transpose(0, 1) # 恢复形状
print("位置编码已添加到词嵌入中")
# 3. 单头自注意力演示
print("\n3. 单头自注意力计算...")
single_head_attn = ScaledDotProductAttention()
# 使用相同的输入作为Q, K, V(自注意力)
output, attn_weights = single_head_attn(
encoded_sequence, encoded_sequence, encoded_sequence
)
print(f"自注意力输出形状: {output.shape}")
print("注意力权重矩阵:")
print(attn_weights.squeeze().detach().numpy())
# 可视化单头注意力
visualize_attention(attn_weights.unsqueeze(1), words, "单头自注意力权重")
# 4. 多头自注意力演示
print("\n4. 多头自注意力计算...")
num_heads = 2
multi_head_attn = MultiHeadAttention(d_model, num_heads)
multi_output, multi_attn_weights = multi_head_attn(
encoded_sequence, encoded_sequence, encoded_sequence
)
print(f"多头自注意力输出形状: {multi_output.shape}")
print(f"多头注意力权重形状: {multi_attn_weights.shape}") # [batch, heads, seq_len, seq_len]
# 可视化多头注意力(第一个头)
visualize_attention(multi_attn_weights, words, "多头自注意力 - 头1")
# 5. 完整的Transformer块演示
print("\n5. 完整的Transformer块处理...")
d_ff = 16 # 前馈网络隐藏层维度
transformer_block = SimpleTransformerBlock(d_model, num_heads, d_ff)
final_output, final_attn_weights = transformer_block(encoded_sequence)
print(f"Transformer块输出形状: {final_output.shape}")
# 6. 演示不同的序列处理类型
print("\n6. 序列处理类型演示:")
# 类型1: 输入输出数量相同(如词性标注)
print("\n 类型1 - 输入输出数量相同:")
print(" 应用: 词性标注、命名实体识别")
print(f" 输入: {len(words)}个词 → 输出: {len(words)}个标签")
# 类型2: 序列到标签(如情感分析)
print("\n 类型2 - 序列到单个标签:")
print(" 应用: 情感分析、文本分类")
print(f" 输入: {len(words)}个词 → 输出: 1个情感标签")
# 类型3: 序列到序列(如机器翻译)
print("\n 类型3 - 序列到序列:")
print(" 应用: 机器翻译、文本摘要")
print(f" 输入: {len(words)}个中文词 → 输出: {len(['I', 'love', 'nature', 'language'])}个英文词")
# 7. 可视化位置编码
print("\n7. 位置编码可视化...")
visualize_positional_encoding()
# 8. 演示不同的注意力模式
demonstrate_attention_patterns()
# 9. 比较RNN和自注意力
compare_rnn_vs_attention()
print("\n" + "=" * 50)
print("演示完成!")
print("\n关键要点总结:")
print("✓ 自注意力让序列中的每个元素都能直接关注所有其他元素")
print("✓ 多头注意力从不同子空间捕捉多种类型的关系")
print("✓ 位置编码为模型提供顺序信息")
print("✓ 相比RNN,自注意力支持并行计算且能更好处理长距离依赖")
print("✓ 相比CNN,自注意力具有全局感受野和动态权重")
if __name__ == "__main__":
main()
这个完整的Python程序涵盖了第6章自注意力机制的核心内容:
程序特点:
-
模块化实现:
-
PositionalEncoding: 正弦余弦位置编码 -
ScaledDotProductAttention: 缩放点积注意力 -
MultiHeadAttention: 多头自注意力 -
SimpleTransformerBlock: 简化的Transformer块
-
-
可视化功能:
-
注意力权重热力图
-
位置编码模式展示
-
不同注意力模式对比
-
-
全面演示:
-
三种序列处理类型
-
RNN与自注意力对比
-
多头注意力的实际效果
-
-
教育性注释:
-
详细的代码注释
-
运行时的解释输出
-
关键概念的总结
-
运行说明:
-
确保安装所需库:
pip install torch matplotlib seaborn numpy -
直接运行程序,可以看到:
-
注意力权重的可视化
-
不同机制的对比
-
完整的处理流程
-
-
这个程序通过实际代码和可视化,生动地展示了自注意力机制的工作原理和优势,完美总结了第6章的核心概念。
更多推荐


所有评论(0)