Pytorch nn.KLDivLoss, reduction=‘none‘|‘mean‘|‘batchmean‘详解
Pytorch nn.KLDivLoss, reduction=‘none’|‘mean’|'batchmean’详解先看下官方文档https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html然后运行下这个例子import torchimport torch.nn as nnif __name__ == '__main__':x
·
Pytorch nn.KLDivLoss, reduction=‘none’|‘mean’|'batchmean’详解
先看下官方文档
官方文档
https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
直接看官方文档即可。官方文档比本文讲解的更清楚。
然后运行下这个例子
import torch
import torch.nn as nn
if __name__ == '__main__':
x = torch.tensor([[0.1,0.3,0.6],
[0.2,0.4,0.4]])
y = torch.tensor([[0.3,0.2,0.5],
[0.2,0.7,0.1]])
lxy_batchmean = nn.KLDivLoss(reduction = 'batchmean')(x,y)
lxy_mean = nn.KLDivLoss(reduction='mean')(x,y)
lxy_none = nn.KLDivLoss(reduction='none')(x,y)
print(torch.sum(torch.sum(lxy_none,dim=1)) /2 )
print(lxy_batchmean)
print(torch.sum(lxy_none) / (2*3))
print(lxy_mean)
print(lxy_none)
output:
tensor(-1.2907)
tensor(-1.2907)
tensor(-0.4302)
tensor(-0.4302)
tensor([[-0.3912, -0.3819, -0.6466],
[-0.3619, -0.5297, -0.2703]])
结论: batchmean是正确的KLDiv的计算方式。
更多推荐
所有评论(0)