在此之前,我们已经学习了前馈网络的两种结构——多层感知器卷积神经网络,这两种结构有一个特点,就是假设输入是一个独立的没有上下文联系的单位,比如输入是一张图片,网络识别是狗还是猫。但是对于一些有明显的上下文特征的序列化输入,比如预测视频中下一帧的播放内容,那么很明显这样的输出必须依赖以前的输入, 也就是说网络必须拥有一定的”记忆能力”。为了赋予网络这样的记忆力,一种特殊结构的神经网络——递归神经网络(Recurrent Neural Network)便应运而生了。网上对于RNN的介绍多不胜数,这篇《Recurrent Neural Networks Tutorial》对于RNN的介绍非常直观,里面手把手地带领读者利用python实现一个RNN语言模型,强烈推荐。为了不重复作者 Denny Britz的劳动,本篇将简要介绍RNN,并强调RNN训练的过程与多层感知器的训练差异不大(至少比CNN简单),希望能给读者一定的信心——只要你理解了多层感知器,理解RNN便不是事儿:-)。

RNN的基本结构

首先有请读者看看我们的递归神经网络的容貌:
这里写图片描述
乍一看,好复杂的大家伙,没事,老样子,看我如何慢慢将其拆解,正所谓见招拆招,我们来各个击破。
上图左侧是递归神经网络的原始结构,如果先抛弃中间那个令人生畏的闭环,那其实就是简单”输入层=>隐藏层=>输出层”的三层结构,我们在多层感知器的介绍中已经非常熟悉,然而多了一个非常陌生的闭环,也就是说输入到隐藏层之后,隐藏层还会给自己也来一发,环环相扣,晕乱复杂。
我们知道,一旦有了环,就会陷入“先有蛋还是先有鸡”的逻辑困境,为了跳出困境我们必须人为定义一个起始点,按照一定的时间序列规定好计算顺序,做到有条不紊,于是实际上我们会将这样带环的结构展开成一个序列网络,也就是上图右侧被“unfold”之后的结构。先别急着能理解RNN,我们来点轻松的,先介绍这样的序列化网络结构包含的参数记号:

  • 网络某一时刻的输入 xt <script type="math/tex" id="MathJax-Element-1">x_t</script>,和之前介绍的多层感知器的输入一样, xt <script type="math/tex" id="MathJax-Element-2">x_t</script>是一个 n <script type="math/tex" id="MathJax-Element-3">n</script>维向量,不同的是递归网络的输入将是一整个序列,也就是x=[x1,...,xt1,xt,xt+1,...xT]<script type="math/tex" id="MathJax-Element-4">x=[x_1,...,x_{t-1},x_{t},x_{t+1},...x_T]</script>,对于语言模型,每一个 xt <script type="math/tex" id="MathJax-Element-5">x_t</script>将代表一个词向量,一整个序列就代表一句话。
  • ht <script type="math/tex" id="MathJax-Element-6">h_t</script>代表时刻 t <script type="math/tex" id="MathJax-Element-7">t</script>的隐藏状态
  • ot<script type="math/tex" id="MathJax-Element-8">o_t</script>代表时刻 t <script type="math/tex" id="MathJax-Element-9">t</script>的输出
  • 输入层到隐藏层直接的权重由U<script type="math/tex" id="MathJax-Element-10">U</script>表示,它将我们的原始输入进行抽象作为隐藏层的输入
  • 隐藏层到隐藏层的权重 W <script type="math/tex" id="MathJax-Element-11">W</script>,它是网络的记忆控制者,负责调度记忆。
  • 隐藏层到输出层的权重V<script type="math/tex" id="MathJax-Element-12">V</script>,从隐藏层学习到的表示将通过它再一次抽象,并作为最终输出。

RNN的Forward阶段

上一小节我们简单了解了网络的结构,并介绍了其中一些记号,是时候介绍它具体的运作过程了。首先在 t=0 <script type="math/tex" id="MathJax-Element-13">t=0</script>的时刻, U,V,W <script type="math/tex" id="MathJax-Element-14">U,V,W</script>都被随机初始化好, h0 <script type="math/tex" id="MathJax-Element-15">h_0</script>通常初始化为0,然后进行如下计算:

s1=Ux1+Wh0h1=f(s1)o1=g(Vh1)
<script type="math/tex; mode=display" id="MathJax-Element-16">s_1=Ux_1+Wh_0\\h_1=f(s_1)\\o_1=g(Vh_1)</script>这样时间就向前推进,此时的状态 h1 <script type="math/tex" id="MathJax-Element-17">h_1</script>作为时刻0的记忆状态将参与下一次的预测活动,也就是
s2=Ux2+Wh1h2=f(s2)o2=g(Vh2)
<script type="math/tex; mode=display" id="MathJax-Element-18">s_2=Ux_2+Wh_1\\h_2=f(s_2)\\o_2=g(Vh_2)</script>,以此类推
st=Uxt+Wht1ht=f(Uxt+Wht1)ot=g(Vht)
<script type="math/tex; mode=display" id="MathJax-Element-19">s_t=Ux_{t}+Wh_{t-1}\\h_t=f(Ux_{t}+Wh_{t-1})\\o_t=g(Vh_t)</script>其中 f <script type="math/tex" id="MathJax-Element-20">f</script>可以是tanh,relu,logistic<script type="math/tex" id="MathJax-Element-21">tanh,relu,logistic</script>任君选择, g <script type="math/tex" id="MathJax-Element-22">g</script>通常是softmax<script type="math/tex" id="MathJax-Element-23">softmax</script>也可以是其他,也是随君所欲。
值得注意的是,我们说递归神经网络拥有记忆能力,而这种能力就是通过 W <script type="math/tex" id="MathJax-Element-24">W</script>将以往的输入状态进行总结,而作为下次输入的辅助。可以这样理解隐藏状态:
h=f(+)
<script type="math/tex; mode=display" id="MathJax-Element-25">h=f(现有的输入+过去记忆总结)</script>

RNN的Backward阶段

上一小节我们说到了RNN如何做序列化预测,也就是如何一步步预测出 o1,o2,....ot1,ot,ot+1..... <script type="math/tex" id="MathJax-Element-3470">o_1,o_2,....o_{t-1},o_t,o_{t+1}.....</script>,接下来我们来了解网络的知识 U,V,W <script type="math/tex" id="MathJax-Element-3471">U,V,W</script>是如何炼成的。
其实没有多大新意,我们还是利用在之前讲解多层感知器卷积神经网络用到的backpropagation方法。也就是将输出层的误差 Cost <script type="math/tex" id="MathJax-Element-3472">Cost</script>,求解各个权重的梯度 U,V,W <script type="math/tex" id="MathJax-Element-3473">\nabla U,\nabla V ,\nabla W</script>,然后利用梯度下降法更新各个权重。现在问题就是如何求解各个权重的梯度,其它的所有东西都在之前介绍中谈到了,所有的trick都可以复用。
由于是序列化预测,那么对于每一时刻 t <script type="math/tex" id="MathJax-Element-3474">t</script>,网络的输出ot<script type="math/tex" id="MathJax-Element-3475">o_t</script>都会产生一定误差 et <script type="math/tex" id="MathJax-Element-3476">e_t</script>,误差的选择任君喜欢,可以是cross entropy也可以是平方误差等等。那么总的误差为 E=tet <script type="math/tex" id="MathJax-Element-3477">E=\sum_t e_t</script>,我们的目标就是要求取

U=EU=tetUV=EV=tetVW=EW=tetW
<script type="math/tex; mode=display" id="MathJax-Element-3478">\nabla U=\frac{\partial E}{\partial U}=\sum_t\frac{\partial e_t}{\partial U} \\\nabla V=\frac{\partial E}{\partial V}=\sum_t\frac{\partial e_t}{\partial V} \\\nabla W=\frac{\partial E}{\partial W}=\sum_t\frac{\partial e_t}{\partial W} </script>我们知道输出 ot=g(Vst) <script type="math/tex" id="MathJax-Element-3479">o_t=g(Vs_t)</script>,对于任意的 Cost <script type="math/tex" id="MathJax-Element-3480">Cost</script>函数,求取 V <script type="math/tex" id="MathJax-Element-3481">\nabla V</script>将是简单的,我们可以直接求取每个时刻的 etV <script type="math/tex" id="MathJax-Element-3482">\frac{\partial e_t}{\partial V} </script>,由于它不存在和之前的状态依赖,可以直接求导取得,然后简单地求和即可。我们重点关注 W,U <script type="math/tex" id="MathJax-Element-3483">\nabla W,\nabla U</script>的计算。
回忆之前我们介绍 多层感知器的backprop算法,我们知道算法的trick是定义一个 δ=es <script type="math/tex" id="MathJax-Element-3484">\delta=\frac{\partial e}{\partial s}</script>,首先计算出输出层的 δL <script type="math/tex" id="MathJax-Element-3485">\delta^L</script>,再向后传播到各层 δL1,δL2,.... <script type="math/tex" id="MathJax-Element-3486">\delta^{L-1},\delta^{L-2},....</script>,那么如何计算 δ <script type="math/tex" id="MathJax-Element-3487">\delta</script>呢?先看下图:
这里写图片描述
之前我们推导过,只要关注当前层次发射出去的链接即可,也就是
δht=(VTδot+WTδht+1).f(st)
<script type="math/tex; mode=display" id="MathJax-Element-3488">\delta_t^h=(V^T\delta_t^o+W^T\delta_{t+1}^h).*f'(s_t)</script>
只要计算出所有的 δot,δht <script type="math/tex" id="MathJax-Element-3489">\delta^o_t,\delta^h_t</script>,就可以通过以下计算出 W,U <script type="math/tex" id="MathJax-Element-3490">\nabla W,\nabla U</script>:
W=tδh0,th0,t1,...,δh0,thi,t1,...,δh0,thm,t1...δhj,th0,t1,...,δhj,thi,t1,...,δhj,thm,t1...δhn,th0,t1,...,δhn,thi,t1,...,δhn,thm,t1=tδht×ht1U=tδh0,tx0,t,...,δh0,txi,t,...,δh0,txm,t...δhj,tx0,t,...,δhj,txi,t,...,δhj,txm,t...δhn,tx0,t,...,δhn,txi,t,...,δhn,txm,t=tδht×xt
<script type="math/tex; mode=display" id="MathJax-Element-3491"> \nabla W=\sum_t \left[ \matrix{ {\delta_{0,t}^hh_{0,t-1},...,\delta_{0,t}^hh_{i,t-1},...,\delta_{0,t}^hh_{m,t-1}} \\... \\{\delta_{j,t}^hh_{0,t-1},...,\delta_{j,t}^hh_{i,t-1},...,\delta_{j,t}^hh_{m,t-1}} \\... \\{\delta_{n,t}^hh_{0,t-1},...,\delta_{n,t}^hh_{i,t-1},...,\delta_{n,t}^hh_{m,t-1}} } \right] = \sum_t \delta_t^h\times h_{t-1} \\\nabla U=\sum_t \left[ \matrix{ {\delta_{0,t}^hx_{0,t},...,\delta_{0,t}^hx_{i,t},...,\delta_{0,t}^hx_{m,t}} \\... \\{\delta_{j,t}^hx_{0,t},...,\delta_{j,t}^hx_{i,t},...,\delta_{j,t}^hx_{m,t}} \\... \\{\delta_{n,t}^hx_{0,t},...,\delta_{n,t}^hx_{i,t},...,\delta_{n,t}^hx_{m,t}} } \right] = \sum_t \delta_t^h\times x_t </script>
其中 × <script type="math/tex" id="MathJax-Element-3492">\times</script>表示两个向量的外积。这样看来,只要你熟悉MLP的backprop算法,RNN写起程序来和MLP根本没有多大差异!手写naive的demo至少比CNN容易很多。

RNN的训练困难

虽然上一节中,我们强调了RNN的训练程序和MLP没太大差异,虽然写程序容易,但是训练起来却是千难万阻。为什么呢?因为我们的网络是根据输入而展开的,输入越长,展开的网络越深,那么对于“深度”网络训练有什么困难呢?最常见的是“gradient explode”和“gradient vanish”。这种问题在RNN中如何体现呢?为了强调这个问题,我们模仿Yoshua Bengio的论文《On the difficulty of training recurrent neural networks》的推导,重写一下RNN的梯度求解过程,为了推导方便,我们人为地为 W,U <script type="math/tex" id="MathJax-Element-49">W,U</script>打上标签 Wt,Ut <script type="math/tex" id="MathJax-Element-50">W^t,U^t</script>,即认为当确定好时间长度 T <script type="math/tex" id="MathJax-Element-51">T</script>,RNN就变成普通的MLP。打上标签后的RNN变成如下:
这里写图片描述
假如对于时刻t+1<script type="math/tex" id="MathJax-Element-52">t+1</script>产生的误差 et+1 <script type="math/tex" id="MathJax-Element-53">e_{t+1}</script>,我们想计算它对于 W1,W2,....,WtWt+1 <script type="math/tex" id="MathJax-Element-54">W^1,W^2,....,W^{t},W^{t+1}</script>的梯度,可以如下计算:

et+1Wt+1=et+1ht+1ht+1Wt+1
<script type="math/tex; mode=display" id="MathJax-Element-55">\frac{\partial e_{t+1}}{\partial W^{t+1}}=\frac{\partial e_{t+1}}{\partial h^{t+1}}\frac{\partial h_{t+1}}{\partial W^{t+1}}</script>
et+1Wt=et+1ht+1ht+1hthtWt
<script type="math/tex; mode=display" id="MathJax-Element-56">\frac{\partial e_{t+1}}{\partial W^{t}}=\frac{\partial e_{t+1}}{\partial h^{t+1}}\frac{\partial h_{t+1}}{\partial h^{t}}\frac{\partial h_{t}}{\partial W^{t}}</script>
et+1Wt1=et+1ht+1ht+1hththt1ht1Wt1
<script type="math/tex; mode=display" id="MathJax-Element-57">\frac{\partial e_{t+1}}{\partial W^{t-1}}=\frac{\partial e_{t+1}}{\partial h^{t+1}}\frac{\partial h_{t+1}}{\partial h^{t}}\frac{\partial h_{t}}{\partial h^{t-1}}\frac{\partial h_{t-1}}{\partial W^{t-1}}</script>
......
<script type="math/tex; mode=display" id="MathJax-Element-58">......</script>
反复运用链式法则,我们可以求出每一个 W1,W2,....,WtWt+1 <script type="math/tex" id="MathJax-Element-59">\nabla W^1,\nabla W^2,....,\nabla W^{t},\nabla W^{t+1}</script>,需要注意的是,实际RNN模型对于 W,U <script type="math/tex" id="MathJax-Element-60">W,U</script>都是不打标签的,也就是在不同时刻都是共享同样的参数,这样可以大大减少训练参数,和CNN的共享权重类似。对于共享参数的RNN,我们只需将上述的一系列式子抹去标签并求和,就可以得到Yoshua Bengio论文中所推导的梯度计算式子:
etW=1ktethtk<ithihi1+hkW
<script type="math/tex; mode=display" id="MathJax-Element-61">\frac{\partial e_{t}}{\partial W}=\sum_{1\le k \le t} \frac{\partial e_{t}}{\partial h^{t}}\prod_{k 其中 +hkW <script type="math/tex" id="MathJax-Element-62">\frac{\partial^+ h_{k}}{\partial W}</script>代表不利用链式法则直接求导,也就是假如对于函数 f(h(x)) <script type="math/tex" id="MathJax-Element-63">f(h(x))</script>,对其直接求导结果如下:
f(h(x))x=f(h(x))
<script type="math/tex; mode=display" id="MathJax-Element-64">\frac{\partial f(h(x))}{\partial x}=f'(h(x))</script>也就是将 h(x) <script type="math/tex" id="MathJax-Element-65">h(x)</script>看成常数了。网上许多RNN教程都用Yoshua Bengio类似的推导,却省略了这个小步骤,使得初学者常常搞得晕头转向,摸不着头脑。论文中证明了:
||k<ithihi1||ηtk
<script type="math/tex; mode=display" id="MathJax-Element-66">||\prod_{k η<1 <script type="math/tex" id="MathJax-Element-67">\eta < 1</script>时,就会出现”gradient vanish”问题,而当 η>1 <script type="math/tex" id="MathJax-Element-68">\eta > 1</script>时,“gradient explode”也就产生了。
为了克服”gradient vanish”的问题,LSTM和GRU模型便后续被推出了,为什么LSTM和GRU可以克服gradient vanish问题呢?由于它们都有特殊的方式存储”记忆”,那么以前gradient比较大的”记忆”不会像简单的RNN一样马上被抹除,因此可以一定程度上克服gradient vanish问题。
另一个简单的技巧可以用来克服gradient explode的问题就是gradient clipping,也就是当你计算的gradient超过阈值 c <script type="math/tex" id="MathJax-Element-69">c</script>的或者小于阈值 c<script type="math/tex" id="MathJax-Element-70">-c</script>时候,便把此时的gradient设置成 c <script type="math/tex" id="MathJax-Element-71">c</script>或 c<script type="math/tex" id="MathJax-Element-72">-c</script>。这种trick的表现形式如下图虚线所示:
这里写图片描述
上图所示是RNN的Error Sufface,可以看到RNN的Error Sufface要么非常陡峭,要么非常平坦,如果不采取任何措施,当你的参数在某一次更新之后,刚好碰到陡峭的地方,此时gradient变得非常大,那么你的参数更新也会非常大,很容易导致震荡问题。而如果你采取了gradient clipping这个技巧,那么即使你不幸碰到陡峭的地方,gradient也不会explode,因为被你限制在某个阈值 c <script type="math/tex" id="MathJax-Element-73"></script>。
有趣的是,正是因为训练深度网络的困难,才导致神经网络这种古老模型沉寂了几十年,不过现在硬件的发展,训练数据的增多,神经网络重新得以复苏,并以重新以深度学习的外号杀出江湖。

参考引用

《Recurrent Neural Networks Tutorial》
《On the difficulty of training recurrent neural networks》

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐