第5章 循环神经网络--让AI拥有记忆力的魔法(下)
LSTM(长短期记忆网络)通过引入细胞状态和门控机制(遗忘门、输入门、输出门)解决了传统RNN的"健忘症"问题。其核心是选择性记忆重要信息(如"法国"与"法语"的关联),并通过加性更新缓解梯度消失。相比RNN的梯度问题,LSTM采用梯度裁剪应对爆炸,通过结构革新(如GRU)解决消失。RNN家族支持多对一(情感分析)、多对多(词性标注)和序
5.5 LSTM 网络原理
在上一节,我们认识了RNN的基本结构,也知道了它有一个致命的“健忘症”——难以学习长距离的依赖关系。比如,在预测“我在法国长大,……,我能说一口流利的__?”这句话的最后一个词时,基本的RNN很可能已经忘记了最开头的“法国”,从而无法准确预测出“法语”。
那么,有没有一种更强大的“记忆增强型”RNN呢?有!这就是我们今天的主角——长短期记忆网络,简称 LSTM。它被设计出来的初衷,就是为了解决RNN的“健忘”问题。
你可以把基本的RNN想象成一个只有短期记忆的人,只能记住最近几秒钟发生的事情。而LSTM则像是一个专业的学者,他拥有一个笔记本(细胞状态) 和一套精密的管理流程(门控机制),知道哪些信息需要牢记,哪些需要忘记,哪些需要立刻使用。
LSTM的核心:传送带与三道门
LSTM的关键在于它有一条贯穿始终的“信息传送带”,也就是细胞状态。你可以把它想象成一条在时间上流淌的河流,信息可以非常轻松地从上游流到下游,几乎不受影响。LSTM的所有精巧设计,都围绕着如何在这条传送带上“添加”或“移除”信息。
那么,谁来控制信息的流入和流出呢?答案是三道神奇的门:
-
遗忘门:决定从细胞状态中丢弃什么信息。
-
输入门:决定哪些新信息会被存放到细胞状态中。
-
输出门:决定当前时刻要输出什么信息。
这三道门都不是简单的开关,而是由Sigmoid函数构成的“软开关”,输出一个0到1之间的值。1代表“完全保留”,0代表“完全舍弃”,0.5则代表“保留一半”。这使得LSTM能够进行非常精细的信息调控。
让我们一步步拆解LSTM的内部工作流程:
假设我们正在处理上面那句关于法国的话。
第一步:遗忘门——“我们是否需要忘记之前的语言环境?”
首先,LSTM会查看新的输入(比如“能说一口流利的”)和上一时刻的隐藏状态,然后通过一个Sigmoid函数产生一个介于0到1之间的数值,这个数值作用于上一时刻的细胞状态(即传送带上的内容)。
-
计算方式:
f_t = σ(W_f · [h_{t-1}, x_t] + b_f) -
形象化:如果之前的细胞状态里记着“我出生在中国”,但新的信息表明我现在在法国,那么遗忘门可能会决定“忘记”关于中国的信息(输出一个接近0的值),为新的、更相关的信息腾出空间。
第二步:输入门——“我们现在需要将什么新信息存入记忆?”
这一步分为两个部分:
-
输入门层:一个Sigmoid层决定“我们要更新哪些值”。
-
候选值层:一个Tanh层创建一个新的候选值向量
~C_t,这些是可能会被加入到细胞状态的新内容。
-
计算方式:
-
i_t = σ(W_i · [h_{t-1}, x_t] + b_i)(决定更新哪些) -
~C_t = tanh(W_C · [h_{t-1}, x_t] + b_C)(候选的新记忆)
-
-
形象化:输入门识别到“法国”是一个重要的新信息,决定要更新关于“国籍/语言”的记忆。同时,候选值层生成了“法国”和“法语”等关联信息。
第三步:更新细胞状态——“现在,正式更新我们的长期记忆!”
现在,我们把前两步的结果结合起来,对细胞状态进行更新。
-
计算方式:
C_t = f_t * C_{t-1} + i_t * ~C_t -
形象化:
-
f_t * C_{t-1}:将旧的细胞状态乘以遗忘门的输出,忘掉我们决定要忘记的东西(比如“中国”)。 -
i_t * ~C_t:加上输入门输出与候选值的乘积,这相当于我们把关于“法国”的新信息存入细胞状态。 -
现在,细胞状态(传送带)已经从“我出生在中国”更新为“我在法国长大”。
-
第四步:输出门——“基于我们当前的记忆,现在应该输出什么?”
最后,我们需要决定当前时刻的隐藏状态 h_t(也就是输出)。这个隐藏状态是基于我们更新后的细胞状态,但会经过一个过滤。
-
首先,运行一个Sigmoid层(输出门)来决定细胞状态的哪些部分将被输出。
-
然后,我们把细胞状态通过Tanh函数(将其值规范到-1和1之间),并乘以Sigmoid门的输出,这样就只输出我们想要输出的部分。
-
计算方式:
-
o_t = σ(W_o · [h_{t-1}, x_t] + b_o) -
h_t = o_t * tanh(C_t)
-
-
形象化:当输入是“能说一口流利的”时,输出门会从细胞状态(现在记着“法国”)中,提取出最相关的信息“法语”,并将其作为当前隐藏状态
h_t的一部分,最终模型就能正确地预测出“法语”这个词。
总结一下LSTM的优势:
通过这套精密的“传送带+三门”系统,LSTM能够:
-
有选择地记忆:重要信息(如“法国”)可以跨越无数个时间步被保留。
-
有选择地遗忘:无关紧要的中间信息可以被及时清理。
-
有效缓解梯度消失:因为细胞状态的更新主要是加和乘的操作,梯度可以更稳定地流动,使得网络能够学习到长距离的依赖。
LSTM是RNN发展史上的一个里程碑,它让处理长序列数据变得可能,并被广泛应用于机器翻译、文本生成、语音识别等众多领域。理解了LSTM,你就掌握了现代序列建模中最核心的武器之一。
5.6 RNN 的学习方式
我们已经知道了RNN和LSTM是如何工作的,但这样一个复杂的网络,它的“智慧”是从哪里来的呢?答案就是:学习。和我们人类一样,RNN需要通过大量的“练习”来调整自己,让自己变得越来越聪明。这个过程在AI领域被称为训练,而其背后的核心算法,就是我们本节要讲的通过时间的反向传播。
回顾:什么是学习?
对于任何神经网络(包括RNN)来说,“学习”的本质就是寻找一组最优的参数(权重和偏置),使得网络对于给定的输入,能产生尽可能接近标准答案的输出。
这个过程就像一个调音师在调试一台复杂的音响,他通过不断聆听声音的效果,微调上百个旋钮,直到获得最佳音质。在网络中,这些“旋钮”就是成千上万个连接权重。
如何衡量“好坏”?——损失函数
首先,我们需要一个标准来衡量网络的输出是“好”是“坏”。这个标准就是损失函数。比如在文本分类任务中,如果网络应该输出“积极”却输出了“消极”,那么损失值就会很大。我们的目标就是在整个训练数据集上,让这个损失值最小化。
如何找到最优参数?——梯度下降与反向传播
要想最小化损失,最常用的方法就是梯度下降。想象你在一座山上,目标是找到山谷的最低点(损失最小的地方)。你环顾四周,找到最陡峭的下山方向(梯度),然后朝那个方向走一小步(学习率)。重复这个过程,你最终就能到达谷底。
“环顾四周找到最陡峭方向”这个过程,在数学上就是计算梯度——也就是损失函数对于每一个参数的偏导数。它告诉我们,每个参数应该向哪个方向、改变多少,才能最有效地降低损失。
对于普通的前馈神经网络,计算梯度的方法叫做反向传播:先从前向后计算一遍输出和损失,然后从后向前(反向)一层层地计算每个参数对损失的“贡献”(梯度),并根据梯度来更新参数。
RNN的挑战与BPTT的诞生
RNN是跨时间步共享参数的,这意味着它在处理一个序列时,同一个权重矩阵(如 W_xh, W_hh)在每个时间步都会被使用。这就带来了一个挑战:第10个时间步的损失,不仅受到当前 W_hh 的影响,还受到第9、8、7...1步 W_hh 的影响,因为隐藏状态是依次传递的。
为了解决这个问题,学者们提出了通过时间的反向传播,简称 BPTT。你可以把RNN按照时间线“展开”,把它看作一个非常深的、层与层之间共享参数的普通前馈神经网络。
BPTT的工作流程(以“I am Chinese”分类为例):
-
前向传播:
-
输入序列
[“I”, “am”, “Chinese”]。 -
依次计算每个时间步的隐藏状态
h1,h2,h3和最终的输出。 -
计算最终输出与真实标签(如“中文句子”)之间的损失。
-
-
反向传播(穿越时间):
-
从最后一个时间步(t=3)开始,计算损失对
h3的梯度。 -
然后,这个梯度会反向流动到前一个时间步 t=2。因为
h3是由h2和x3计算得来的,所以我们需要计算损失对h2的梯度。这个梯度由两部分组成:一部分直接来自最终的损失,另一部分来自于h3传回来的梯度。 -
同理,梯度会继续反向传播到 t=1。
-
关键一步:在计算每个时间步的梯度时,我们会累积损失对于共享参数(如
W_hh)的梯度。也就是说,W_hh的总梯度 = 在 t=1 的梯度 + 在 t=2 的梯度 + 在 t=3 的梯度。
-
-
参数更新:
-
当所有时间步的梯度都计算完毕后,我们得到了每个参数的总梯度。
-
最后,我们使用梯度下降算法,用这个总梯度来一次性更新所有的共享参数。
-
形象化理解:
把BPTT想象成一段倒放的电影。我们先正常播放(前向传播),看到结局(损失)。然后我们从结局开始倒放(反向传播),仔细分析电影中的每一帧(每个时间步),找出导致这个结局的每一个关键因素(梯度),并记录下来。倒放结束后,我们综合所有分析结果(总梯度),然后去指导演员和导演(更新参数),告诉他们下次如何表演和拍摄,才能得到一个更好的结局(更低的损失)。
正是通过BPTT这种精巧的学习方式,RNN才能够调整其内部参数,学会捕捉序列中的模式和依赖关系,从而完成各种复杂的任务。
5.7 如何解决 RNN 的梯度消失或梯度爆炸问题
在上一节讲解BPTT时,我们提到了梯度会像接力棒一样在时间线上反向传播。然而,这个传递过程并不总是顺利的,它面临着两大“杀手”:梯度消失 和 梯度爆炸。这正是导致普通RNN“健忘”的根本原因。
什么是梯度消失和梯度爆炸?
这要从反向传播中的链式法则说起。在计算梯度时,我们需要将许多个偏导数连乘起来。如果这些偏导数大部分是 小于1 的数,那么连乘的结果会指数级地减小,最终趋近于零,这就是梯度消失。反之,如果这些偏导数大部分是 大于1 的数,连乘结果就会指数级地增大,变成天文数字,这就是梯度爆炸。
-
梯度消失的后果:距离当前时间步越远的时刻,其梯度信号越微弱。这意味着网络无法根据远距离的输入来调整其权重,即“学不会”长距离依赖。这就是普通RNN健忘的根源。
-
梯度爆炸的后果:梯度值过大,导致参数更新步长巨大,网络会变得极其不稳定,权重值会“飞”到溢出,最终无法收敛。
解决方案
1. 应对梯度爆炸:梯度裁剪
梯度爆炸相对容易处理。思路很简单:如果梯度向量的大小超过了某个阈值,我们就将它缩放到这个阈值之内。
-
形象化:这就像给梯度装上一个“安全阀”。当下山的速度太快,快要飞出去的时候,这个安全阀会把你拉回来,让你以安全的速度下降。虽然方向可能不是当前最陡的,但能保证过程的稳定。
-
方法:计算梯度的L2范数(模长),如果大于阈值,就让梯度向量除以它的模长,再乘以阈值。
2. 应对梯度消失:革新网络结构——LSTM与GRU
这是治本的方法。正如我们在5.5节详细讨论的,LSTM 通过引入细胞状态 和门控机制,创造了一条梯度传播的“高速公路”。
-
核心思想:LSTM中细胞状态的更新是 “加性” 的:
C_t = f_t * C_{t-1} + i_t * ~C_t。-
在反向传播时,梯度通过细胞状态这条路径回传,主要进行的是加法操作和逐元素乘法(乘以遗忘门
f_t)。加法操作使得梯度可以几乎不受衰减地直接流过,而遗忘门如果接近1(决定记住),梯度也能很好地保持。 -
这极大地缓解了梯度消失问题,让网络能够学习长程依赖。
-
-
GRU:是LSTM的一个变体,它更加简洁。它将LSTM的三个门合并为两个(更新门和重置门),将细胞状态和隐藏状态合并。但其核心思想与LSTM一致,通过更新门来控制过去信息保留多少,同样创造了梯度流动的捷径,从而有效缓解梯度消失。
3. 其他技术
-
更优的激活函数:使用ReLU或其变种(如Leaky ReLU)代替Sigmoid/Tanh,因为其在正区间的导数为1,有助于梯度流动。
-
权重初始化:使用精心设计的初始化策略(如Xavier初始化、He初始化),可以在训练开始时让梯度的传播处于一个良好的状态。
-
Skip Connections:类似于ResNet中的思想,通过添加跳跃连接,让梯度可以直接绕过一些层,这也是Transformer等现代架构成功的关键之一。
总结
面对梯度问题,我们并非束手无策。对于梯度爆炸,我们有简单有效的梯度裁剪。对于更根本的梯度消失问题,我们通过引入LSTM 和 GRU 这类具有“记忆高速公路”的门控结构,成功地让RNN拥有了长期记忆的能力,使其成为处理序列任务的利器。
5.8 RNN 的其他应用
之前我们主要围绕文本分类(多对一)和序列生成(一对一)来介绍RNN。但实际上,凭借其灵活的输入输出结构,RNN家族(包括LSTM, GRU)可以应对多种多样的任务模式。让我们来看看RNN在其他场景下的精彩应用。
5.8.1 多对一序列
这种模式我们已经非常熟悉了:输入是一个序列,输出是一个单独的值。
-
核心思想:我们只关心最后一个时间步的隐藏状态
h_n,因为它理论上已经编码了整个输入序列的上下文信息。我们将这个h_n传递给一个全连接层和Softmax函数,即可得到一个分类结果或回归值。 -
典型应用:
-
情感分析:输入一段影评文本(序列),输出“积极”或“消极”的情感标签(单个值)。
-
文本分类:输入一篇新闻文章(序列),输出其类别,如“体育”、“财经”、“科技”(单个值)。
-
视频分类:输入一段视频的逐帧特征(序列),输出视频所描述的动作,如“游泳”、“打篮球”(单个值)。
-
(示意图:多个时间步的输入,最终汇聚到一个输出)
5.8.2 多对多序列
这种模式下,输入是一个序列,输出也是一个序列,并且输入和输出的时间步完全对应。这通常被称为序列标注任务。
-
核心思想:每一个时间步都接受一个输入,并产生一个对应的输出。每个时间步的输出都依赖于当前输入以及之前的上下文。
-
典型应用:
-
词性标注:输入一个句子(单词序列),为每个单词输出其词性标签(名词、动词等)。
-
输入:
[“I”, “love”, “Deep”, “Learning”] -
输出:
[“代词”, “动词”, “形容词”, “名词”]
-
-
命名实体识别:输入一个句子,识别出其中的实体(如人名、地名、组织机构名)并分类。
-
输入:
[“乔布斯”, “创立了”, “苹果”, “公司”] -
输出:
[“B-PER”, “O”, “B-ORG”, “I-ORG”](B-开始, I-内部, O-其他)
-
-
股价预测:输入过去10天的股价序列,输出对未来10天的股价预测序列(输入输出长度相等)。
-
(示意图:每个时间步的输入,都对应一个同时刻的输出)
5.8.3 序列到序列
这是RNN最强大、最迷人的应用模式之一,也称为 Seq2Seq。它的特点是:输入一个序列,输出另一个序列,并且两个序列的长度可以完全不同。
Seq2Seq模型通常由两部分组成:
-
编码器:一个RNN(通常是LSTM或GRU),负责读取并“理解”整个输入序列,将其压缩成一个上下文向量。这个向量可以看作是输入序列的“思想”或“摘要”,通常是编码器最后一个时间步的隐藏状态。
-
解码器:另一个RNN(也通常是LSTM或GRU),以编码器产生的上下文向量作为其初始状态,然后开始一步步地生成输出序列。解码器的每一个时间步生成一个输出,并且当前时刻的输出会被作为下一时刻的输入(类似于5.4节的语言模型)。
-
典型应用:
-
机器翻译:这是Seq2Seq的经典应用。输入一种语言的句子(如中文序列),输出另一种语言的句子(如英文序列)。
-
输入:
[“我”, “爱”, “你”] -
输出:
[“I”, “love”, “you”]
-
-
文本摘要:输入一篇长文章(序列),输出一个简短的摘要(另一个序列)。
-
语音识别:输入音频信号的序列,输出对应的文字序列。
-
问答系统:输入一个问题(序列),输出答案(另一个序列)。
-
(示意图:编码器将输入序列编码为上下文向量,解码器再将该向量解码为输出序列)
注意:现代最先进的Seq2Seq模型(如用于翻译的Transformer)已经不再使用简单的RNN,而是基于自注意力机制,但其“编码-解码”的核心思想正是源于此。理解Seq2Seq是通往理解现代自然语言处理技术的大门。
通过以上这些灵活的应用模式,我们可以看到,RNN及其变体几乎统治了2018年之前的所有序列处理任务,为AI理解语言、语音和时间序列数据立下了汗马功劳。
更多推荐


所有评论(0)