二分类和多分类交叉熵函数区别详解
二分类和多分类交叉熵函数区别详解写在前面查了下百度,交叉熵,是度量两个分布间差异的概念。而在我们神经网络中,两个分布也就是y的真实值分布和预测值分布。当两个分布越接近时,其交叉熵值也就越小。根据上面知识,也就转化为我们需要解决让预测值和真实值尽可能接近的问题,而这正与概率论数理统计中的最大似然分布一脉相承,进而目标转化为确定值的分布和求解最大似然估计问题。二分类问题表示分类任务中有两个类别,比如我
二分类和多分类交叉熵函数区别详解
写在前面
查了下百度,交叉熵,是度量两个分布间差异的概念。而在我们神经网络中,两个分布也就是y的真实值分布和预测值分布。当两个分布越接近时,其交叉熵值也就越小。
根据上面知识,也就转化为我们需要解决让预测值和真实值尽可能接近的问题,而这正与概率论数理统计中的最大似然分布一脉相承,进而目标转化为确定值的分布和求解最大似然估计问题。
二分类问题
表示分类任务中有两个类别,比如我们想判断一张图片是不是猫。也就是说,训练一个分类器,输入一张图片,用特征向量x表示,输出是不是猫用y=0或1表示,其中1表示是,0表示不是。
这样的问题,我们完全可以用0-1分布来进行表示:
yiy_iyi | 1−yi1-y_i1−yi |
---|---|
yi^\hat{y_i}yi^ | 1−yi^1-\hat{y_i}1−yi^ |
注:其中yi为真实值,yi^\hat{y_i}yi^为预测值,且yiy_iyi的值为0或1
此时求解最大似然估计过程如下:
L(yi^)=Πi=1nyi^yi(1−yi^)1−yi L(\hat{y_i})=\Pi_{i=1}^{n}\hat{y_i}^{y_i}(1-\hat{y_i})^{1-y_i} L(yi^)=Πi=1nyi^yi(1−yi^)1−yi
两边同时取对数
log(L(yi^))=∑i=1n(yilog(yi^)+(1−yi)log(1−yi^)) log(L(\hat{y_i}))=\sum_{i=1}^{n}(y_ilog(\hat{y_i})+(1-y_i)log(1-\hat{y_i})) log(L(yi^))=i=1∑n(yilog(yi^)+(1−yi)log(1−yi^))
最大似然估计要求数越大越好,而损失函数要求越小越好,因而损失函数在前面加上负号,因而也得到了二分类问题使用的交叉熵损失函数。
Loss=−∑i=1n(yilog(yi^)+(1−yi)log(1−yi^)) Loss=-\sum_{i=1}^{n}(y_ilog(\hat{y_i})+(1-y_i)log(1-\hat{y_i})) Loss=−i=1∑n(yilog(yi^)+(1−yi)log(1−yi^))
多分类问题
表示分类任务有多个类别,如对一堆水果分类,它们可能是橘子、苹果、梨等,每个样本有且只有一个标签。
这种情况与二分类类似,只是可能的情况增多了,可以描述为一个离散分布
y1y_{1}y1 | y2y_2y2 | … | yky_kyk |
---|---|---|---|
y1^\hat{y_1}y1^ | y2^\hat{y_2}y2^ | … | yk^\hat{y_k}yk^ |
注:y1、y2...yky_1、y_2...y_ky1、y2...yk为真实值,其中有且只有一个为1,其余为0。(采用one-hot编码)
此时求解最大似然函数过程如下:
L(yi^)=Πi=1n(y(i,1)^y(i,1)y(i,2)^y(i,2)...y(i,n)^y(i,n)) L(\hat{y_i})=\Pi_{i=1}^{n}(\hat{y_{(i,1)}}^{y_{(i,1)}}\hat{y_{(i,2)}}^{y_{(i,2)}}...\hat{y_{(i,n)}}^{y_{(i,n)}}) L(yi^)=Πi=1n(y(i,1)^y(i,1)y(i,2)^y(i,2)...y(i,n)^y(i,n))
因为真实值只有一个为1,其余为0,因而只有1项值非零,可化简为:
L(yi^)=Πi=1ny(i,m)^y(i,m) L(\hat{y_i})=\Pi_{i=1}^{n}\hat{y_{(i,m)}}^{y_{(i,m)}} L(yi^)=Πi=1ny(i,m)^y(i,m)
注:y(i,m)^\hat{y_{(i,m)}}y(i,m)^表示含义为第i个样本,属于第m个类别(m值会随样本的变化动态改变)。
两边同时取对数:
log(L(yi^))=∑i=1ny(i,m)log(yi,m^) log(L(\hat{y_i}))=\sum_{i=1}^{n}y_{(i,m)}log(\hat{y_{i,m}}) log(L(yi^))=i=1∑ny(i,m)log(yi,m^)
与二元分类同理,此时多分类的交叉熵损失函数即为:
Loss=−∑i=1ny(i,m)log(yi,m^) Loss=-\sum_{i=1}^{n}y_{(i,m)}log(\hat{y_{i,m}}) Loss=−i=1∑ny(i,m)log(yi,m^)
参考文献
[1] https://www.bilibili.com/video/BV1a5411W7Dn?t=47
[2] https://juejin.cn/post/6844903630479294477
更多推荐
所有评论(0)