Softmax函数及其导数
Softmax函数及其导数本文翻译自The Softmax function and its derivative基础概念Softmax函数的输入是N维的随机真值向量,输出是另一个N维的真值向量,且值的范围是(0,1)(0,1),和为1.0。即映射:S(a)=RN→RNS(\textbf{a})=\mathbb{R}^N\rightarrow \mathbb{R}^N:S(a):⎡⎣⎢⎢⎢a
Softmax函数及其导数
本文翻译自The Softmax function and its derivative
基础概念
Softmax函数的输入是N维的随机真值向量,输出是另一个N维的真值向量,
且值的范围是(0,1)<script id="MathJax-Element-3287" type="math/tex">(0,1)</script>,和为1.0。即映射:S(a)=RN→RN<script id="MathJax-Element-3288" type="math/tex">S(\textbf{a})=\mathbb{R}^N\rightarrow \mathbb{R}^N</script>:
其中每一个元素的公式为:
显然Sj<script id="MathJax-Element-3291" type="math/tex">S_j</script>总是正的~(因为指数);因为所有的Sj<script id="MathJax-Element-3292" type="math/tex">S_j</script>的和为1,所以有Sj<1<script id="MathJax-Element-3293" type="math/tex">S_j<1</script>,因此它的范围是(0,1)<script id="MathJax-Element-3294" type="math/tex">(0,1)</script>。例如,一个含有三个元素的向量[1.0,2.0,3.0]<script id="MathJax-Element-3295" type="math/tex">\left [ 1.0,\,2.0,\,3.0 \right ]</script>被转化为[0.09,0.24,0.67]<script id="MathJax-Element-3296" type="math/tex">\left [ 0.09,\,0.24,\,0.67 \right ]</script>。
转化后的元素与原始的对应的元素位置上保持一致,且和为1。我们将原始的向量拉伸为[1.0,2.0,5.0]<script id="MathJax-Element-3297" type="math/tex">\left [ 1.0,\,2.0,\,5.0 \right ]</script>,得到变换后的[0.02,0.05,0.93]<script id="MathJax-Element-3298" type="math/tex">\left [ 0.02,\,0.05,\,0.93 \right ]</script>,同样具有前面的性质。注意此时因为最后一个元素(5.0)距离前面两个元素(1.0和2.0)较远,因此它的输出的softmax值占据了和1.0的大部分(0.93)。softmax并不是只选择一个最大元素,而是将向量分解为整个(1.0)的几部分,最大的输入元素得到一个比例较大的部分,但其他元素各自也获得对应的部分。
概率解释
softmax的性质(所有输出的值范围是(0,1)<script id="MathJax-Element-3299" type="math/tex">(0,1)</script>且和为1.0)使其在机器学习的概率解释中广泛使用。尤其是在多类别分类任务中,我们总是给输出结果对应的类别附上一个概率,即如果我们的输出类别有N种,我们就输出一个N维的概率向量且和为1.0。每一维的值对应一种类别的概率。我们可以将softmax解释如下:
其中,y<script id="MathJax-Element-3301" type="math/tex">y</script>是输出的N个类别中的某个(取值为
向量计算的准备
在深入理解计算softmax的导数之前,我们先了解向量计算的一些基础知识。
Softmax从根本上来说是一种向量函数。它将向量作为输入并输出另一个向量。换言之,它有多个输入和输出,因此我们不能直接就尝试求”softmax的导数”,我们首先要明确:
- 我们想要计算softmax的哪个组成成分(输出的某元素)的导数。
- 由于softmax具有多个输入,所以要计算关于哪个输入元素的偏导数。
听起来好像很复杂,但这正是为什么定义向量计算的原因。 我们正在寻找的偏导数是:
这是第i<script id="MathJax-Element-3306" type="math/tex">i</script>个输出关于第
因为softmax函数是一个RN→RN<script id="MathJax-Element-3309" type="math/tex">\mathbb{R}^N\rightarrow \mathbb{R}^N</script>的函数,所以我们计算得到的导数是一个雅可比矩阵:
在机器学习的文献中,常常用术语梯度来表示通常所说的导数。严格来说,梯度只是为标量函数来定义的,例如机器学习中的损失函数;对于像softmax这样的向量函数,说是“梯度”是不准确的;雅可比是一个向量函数的全部的导数,大多数情况下我们会说“导数”。
softmax的导数
对任意的i<script id="MathJax-Element-3311" type="math/tex">i</script>和
我们将使用链式法则来计算导数,即对于f(x)=g(x)h(x)<script id="MathJax-Element-3315" type="math/tex">f(x)=\frac{g(x)}{h(x)}</script>:
在我们的情况下,有:
注意对于hi<script id="MathJax-Element-3318" type="math/tex">h_i</script>,无论求其关于哪个aj<script id="MathJax-Element-3319" type="math/tex">a_j</script>的导数,结果都是eaj<script id="MathJax-Element-3320" type="math/tex">e^{a_j}</script>,但是对于gi<script id="MathJax-Element-3321" type="math/tex">g_i</script>就不同了。
gi<script id="MathJax-Element-3322" type="math/tex">g_i</script>关于aj<script id="MathJax-Element-3323" type="math/tex">a_j</script>的导数是eaj<script id="MathJax-Element-3324" type="math/tex">e^{a_j}</script>当且仅当i=j<script id="MathJax-Element-3325" type="math/tex">i=j</script>;否则结果为0。
让我们回到DjSi<script id="MathJax-Element-3326" type="math/tex">D_jS_i</script>;我们先考虑i=j<script id="MathJax-Element-3327" type="math/tex">i=j</script>的情况。根据链式法则我们有:
简单起见,我们使用∑<script id="MathJax-Element-3329" type="math/tex">\sum</script>表示∑Nk=1eak<script id="MathJax-Element-3330" type="math/tex">\sum_{k=1}^{N}e^{a_k}</script>。继续化简下:
最后的公式使用其自身来表示(Si<script id="MathJax-Element-3332" type="math/tex">(S_i</script>和Sj)<script id="MathJax-Element-3333" type="math/tex">S_j)</script>,这在包含指数函数时是一个常用的技巧。
类似的,考虑i≠j<script id="MathJax-Element-3334" type="math/tex">i\neq j</script>的情况:
总结如下:
在文献中我们常常会见到各种各样的”浓缩的”公式,一个常见的例子是使用克罗内克函数:
于是我们有:
在文献中也有一些其它的表述:
- 在雅可比矩阵中使用单位矩阵I<script id="MathJax-Element-3339" type="math/tex">I</script>来替换
δ <script id="MathJax-Element-3340" type="math/tex">\delta</script>,I<script id="MathJax-Element-3341" type="math/tex">I</script>使用元素的矩阵形式表示了δ <script id="MathJax-Element-3342" type="math/tex">\delta</script>。 - 使用”1”作为函数名而不是克罗内克δ<script id="MathJax-Element-3343" type="math/tex">\delta</script>,如下所示:DjSi=Si(1(i=j)−Sj)<script id="MathJax-Element-3344" type="math/tex">D_jS_i=S_i(1(i=j)-S_j)</script>。这里1(i=j)意味着当i=j<script id="MathJax-Element-3345" type="math/tex">i=j</script>时值为1,否则为0。
当我们想要计算依赖于softmax导数的更复杂的导数时,“浓缩”符号会很有用; 否则我们必须在任何地方完整的写出公式。
计算softmax和数值稳定性
对于一个给定的向量,使用Python来计算softmax的简单方法是:
def softmax(x):
"""Compute the softmax of vector x."""
exps = np.exp(x)
return exps / np.sum(exps)
使用前面定义的softmax函数计算一个三维的向量:
In [146]: softmax([1, 2, 3])
Out[146]: array([ 0.09003057, 0.24472847, 0.66524096])
然而当我们使用该函数计算较大的值时(或者大的负数时),会出现一个问题:
In [148]: softmax([1000, 2000, 3000])
Out[148]: array([ nan, nan, nan])
Numpy使用的浮点数的数值范围是有限的。对于float64,最大可表示数字的大小为10308<script id="MathJax-Element-3376" type="math/tex">10^{308}</script>。
softmax函数中的求幂运算可以轻松超过这个数字,即使是相当适中的输入。避免这个问题的一个好方法是通过规范输入使其不要太大或者太小,通过观察我们可以使用任意的常量C,如下所示:
然后将这个变量转换到指数上:
因为C是一个随机的常量,所以我们可以写为:
D也是一个任意常量。对任意D,这个公式等价于前面的式子,这让我们能够更好的进行计算。对于D,一个比较好的选择是所有输入的最大值的负数:
假定输入本身彼此相差不大,这会使输入转换到接近于0的范围。最重要的是,它将所有的输入转换为负数(除最大值外,最大值变为0)。很大的负指数结果会趋于0而不是无穷,这就让我们很好的避免了出现NaN的结果。
def stablesoftmax(x):
"""Compute the softmax of vector x in a numerically
stable way."""
shiftx = x - np.max(x)
exps = np.exp(shiftx)
return exps / np.sum(exps)
现在我们有:
In [150]: stablesoftmax([1000, 2000, 3000])
Out[150]: array([ 0., 0., 1.])
请注意,这仍然是不完美的,因为数学上softmax永远不会真的产生零,但这比NaN好得多,且由于输入之间的距离非常大,所以无论如何都会得到非常接近于零的结果。
softmax层及其导数
softmax常用于机器学习中,特别是逻辑斯特回归:softmax层,其中我们将softmax应用于全连接层(矩阵乘法)的输出,如图所示。
在这个图中,我们有一个具有N个特征的输入x和T个可能的输出类别。权重矩阵W用于将x转换成具有T元素的向量(在机器学习的文献中称为“logits”),并且softmax函数用于将logits转换成表示属于某一类别的概率。
我们如何计算这个“softmax层”的导数(先进行全连接矩阵乘法,然后是softmax)?当然是使用链式规则!
在我们开始之前的一个重要的观点:你可能会认为x是计算其导数的自然变量(natural variable)。但事实并非如此。实际上,在机器学习中,我们通常希望找到最佳的权重矩阵W,因此我们希望用梯度下降的每一步来更新权重。因此,我们将计算该层的关于W的导数。
我们首先将这个图改写为向量函数的组合。首先我们定义矩阵乘法g(W)<script id="MathJax-Element-3623" type="math/tex">g(W)</script>,即映射:RNT→RT<script id="MathJax-Element-3624" type="math/tex">\mathbb{R}^{NT}\rightarrow \mathbb{R}^T</script>。因为输入(矩阵W)N×T<script id="MathJax-Element-3625" type="math/tex">N\times T</script>个元素,输出有T个元素。
接下来我们来考虑softmax,如果我们定义logits的向量是λ<script id="MathJax-Element-3626" type="math/tex">\lambda </script>,我们有:RT→RT<script id="MathJax-Element-3627" type="math/tex">\mathbb{R}^{T}\rightarrow \mathbb{R}^T</script>。总体来说,我们有:
使用多变量的链式法则,得到P(W)<script id="MathJax-Element-3629" type="math/tex">P(W)</script>的雅可比矩阵:
我们之前已经计算过雅可比矩阵;只不过此时是对g(W)求解。因此g是一个非常简单的函数,因此计算雅可比矩阵很简单。唯一要注意的是
正确计算相应的索引。因为g(W)<script id="MathJax-Element-3631" type="math/tex">g(W)</script>:RNT→RT<script id="MathJax-Element-3632" type="math/tex">\mathbb{R}^{NT}\rightarrow \mathbb{R}^T</script>,所以它的雅可比矩阵是T<script id="MathJax-Element-3633" type="math/tex">T</script>行,
在某种意义上,权重矩阵W被“线性化”为长度为NT的向量。 如果您熟悉多维数组的内存布局,应该很容易理解它是如何完成的。
在我们的例子中,我们可以做的一件事就是按照行主次序对其进行线性化处理,第一行是连续的,接着是第二行,等等。Wij<script id="MathJax-Element-3636" type="math/tex">W_{ij}</script>
在雅可比矩阵中的列号是(i−1)N+j<script id="MathJax-Element-3637" type="math/tex">(i-1)N+j</script>。为了计算Dg<script id="MathJax-Element-3638" type="math/tex">Dg</script>,让我们回顾g1<script id="MathJax-Element-3639" type="math/tex">g_1</script>:
因此:
我们使用同样的策略来计算g2⋯gT<script id="MathJax-Element-3642" type="math/tex">g_2\cdots g_T</script>,我们可以得到雅可比矩阵:
最后从另一个角度来这个问题,如果我们将W的索引分解为i和j,可以得到:
在雅可比矩阵中表示第t<script id="MathJax-Element-4758" type="math/tex">t</script>行,
最后,为了计算softmax层的完整的雅可比矩阵,我们只需要计算DS<script id="MathJax-Element-4760" type="math/tex">DS</script>和Dg<script id="MathJax-Element-4761" type="math/tex">Dg</script>间的乘积。注意P(W)<script id="MathJax-Element-4762" type="math/tex">P(W)</script>:RNT→RT<script id="MathJax-Element-4763" type="math/tex">\mathbb{R}^{NT}\rightarrow \mathbb{R}^T</script>,因此雅可比矩阵的维度可以确定。因此DS<script id="MathJax-Element-4764" type="math/tex">DS</script>是T×T<script id="MathJax-Element-4765" type="math/tex">T\times T</script>,Dg<script id="MathJax-Element-4766" type="math/tex">Dg</script>是T×NT<script id="MathJax-Element-4767" type="math/tex">T\times NT</script>的,它们的乘积DP<script id="MathJax-Element-4768" type="math/tex">DP</script>是T×NT<script id="MathJax-Element-4769" type="math/tex">T\times NT</script>的。
在文献中,你会看到softmax层的导数大大减少了。因为涉及的两个函数很简单而且很常用。 如果我们仔细计算DS<script id="MathJax-Element-4770" type="math/tex">DS</script>的行和Dg<script id="MathJax-Element-4771" type="math/tex">Dg</script>的列之间的乘积:
Dg<script id="MathJax-Element-4773" type="math/tex">Dg</script>大多数为0,所以最终的结果很简单,仅当i=k<script id="MathJax-Element-4774" type="math/tex">i=k</script>时Dijgk<script id="MathJax-Element-4775" type="math/tex">D_{ij}g_k</script>不为0;然后它等于xj<script id="MathJax-Element-4776" type="math/tex">x_j</script>。因此:
因此完全可以在没有实际雅可比矩阵乘法的情况下计算softmax层的导数; 这很好,因为矩阵乘法很耗时!由于全连接层的雅可比矩阵是稀疏的,我们可以避免大多数计算。
Softmax和交叉熵损失
我们刚刚看到softmax函数如何用作机器学习网络的一部分,以及如何使用多元链式规则计算它的导数。当我们处理这个问题的时候,经常看到损失函数和softmax一起使用来训练网络:交叉熵。
交叉熵有一个有趣的概率和信息理论解释,但在这里我只关注其使用机制。对于两个离散概率分布p<script id="MathJax-Element-4778" type="math/tex">p</script>和
其中k<script id="MathJax-Element-4781" type="math/tex">k</script>遍历分布定义的随机变量的所有的可能的值。具体而言,在我们的例子中有
如果我们从softmax的输出P<script id="MathJax-Element-4785" type="math/tex">P</script>(一个概率分布)来考量。其它的概率分布是”正确的”类别输出,通常定义为
其中k<script id="MathJax-Element-4789" type="math/tex">k</script>遍历所有的输出类别,
实际上,我们把y<script id="MathJax-Element-4797" type="math/tex">y</script>当作一个常量,仅使用
第
xent<script id="MathJax-Element-4805" type="math/tex">xent</script>的雅可比矩阵是1×T<script id="MathJax-Element-4806" type="math/tex">1\times T</script>的矩阵(一个行向量)。因为输出是一个标量且我们有T<script id="MathJax-Element-4807" type="math/tex">T</script>个输出(向量
现在回顾下P<script id="MathJax-Element-4811" type="math/tex">P</script>可以表示为输入为权值的函数:
我们可以再次使用多元链式法则来计算xent<script id="MathJax-Element-4814" type="math/tex">xent</script>关于W<script id="MathJax-Element-4815" type="math/tex">W</script>的梯度:
我们来检查一下雅可比行矩阵的维数。我们已经计算过了DP(W)<script id="MathJax-Element-4817" type="math/tex">DP(W)</script>,它是T×NT<script id="MathJax-Element-4818" type="math/tex">T\times NT</script>的。Dxent(P(W))<script id="MathJax-Element-4819" type="math/tex">Dxent(P(W))</script>是1×T<script id="MathJax-Element-4820" type="math/tex">1\times T</script>的,所以得到的
雅可比矩阵Dxent(W)<script id="MathJax-Element-4821" type="math/tex">Dxent(W)</script>是1×NT<script id="MathJax-Element-4822" type="math/tex">1\times NT</script>的。这是有意义的,因为整个网络有一个输出(交叉熵损失,是一个标量)和NT<script id="MathJax-Element-4823" type="math/tex">NT</script>个输入(权重)。
同样的,有一个简单的方式来找到Dxent(W)<script id="MathJax-Element-4824" type="math/tex">Dxent(W)</script>的简单公式,因为矩阵乘法中的许多元素最终会被消除。注意到xent(P)<script id="MathJax-Element-4825" type="math/tex">xent(P)</script>只依赖于P<script id="MathJax-Element-4826" type="math/tex">P</script>的
第
其中,Dyxent=−1Py<script id="MathJax-Element-4830" type="math/tex">D_yxent=-\frac{1}{P_y}</script>。回到整个的雅可比矩阵Dxent(W)<script id="MathJax-Element-4831" type="math/tex">Dxent(W)</script>,使Dxent(P)<script id="MathJax-Element-4832" type="math/tex">Dxent(P)</script>乘以D(P(W))<script id="MathJax-Element-4833" type="math/tex">D(P(W))</script>的每一列,得到结果的行向量的每一个
元素。回顾用行向量表示的按行优先的“线性化”的整个权重矩阵W。清晰起见,我们将使用i<script id="MathJax-Element-4834" type="math/tex">i</script>和
iN+j<script id="MathJax-Element-4837" type="math/tex">iN+j</script>个元素):
因为在Dkxent(P)<script id="MathJax-Element-4839" type="math/tex">D_kxent(P)</script>中只有第y<script id="MathJax-Element-4840" type="math/tex">y</script>个元素是非0的,所以我们可以得到下式:
根据我们的定义,Py=Sy<script id="MathJax-Element-4842" type="math/tex">P_y=S_y</script>,所以可得:
即使最终的结果很简洁清楚,但是我们不一定非要这样做。公式Dijxent(W)<script id="MathJax-Element-4844" type="math/tex">D_{ij}xent(W)</script>可能最终成为一个和的形式(或者某些和的和)。关于雅可比矩阵的这些技巧可能并没有太大意义,因为计算机可以完成所有的工作。我们需要做的就是计算出单个的雅矩阵,这通常毕竟容易,因为它们是更简单的非复合函数。这技术体现了多元链式法则的美妙和实用性。
更多推荐



所有评论(0)