PyTorch nn.KLDivLoss, reduction='none' | 'mean' | 'batchmean' explained

First, take a look at the official documentation.

Official documentation
https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
Reading the official documentation directly is enough; it explains this more clearly than this post does.
Then run the following example:

import torch
import torch.nn as nn

if __name__ == '__main__':
    # Note: KLDivLoss expects its first argument in log-space (e.g. log_softmax
    # output); raw probabilities are passed here only to compare the reductions.
    x = torch.tensor([[0.1, 0.3, 0.6],
                      [0.2, 0.4, 0.4]])
    y = torch.tensor([[0.3, 0.2, 0.5],
                      [0.2, 0.7, 0.1]])
    lxy_batchmean = nn.KLDivLoss(reduction='batchmean')(x, y)
    lxy_mean = nn.KLDivLoss(reduction='mean')(x, y)
    lxy_none = nn.KLDivLoss(reduction='none')(x, y)
    print(torch.sum(torch.sum(lxy_none, dim=1)) / 2)   # sum / batch size
    print(lxy_batchmean)
    print(torch.sum(lxy_none) / (2 * 3))               # sum / number of elements
    print(lxy_mean)
    print(lxy_none)

output:

tensor(-1.2907)
tensor(-1.2907)
tensor(-0.4302)
tensor(-0.4302)
tensor([[-0.3912, -0.3819, -0.6466],
        [-0.3619, -0.5297, -0.2703]])
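To see where these numbers come from: each entry of the reduction='none' output is target * (log(target) - input); 'batchmean' divides the sum of all entries by the batch size (2), while 'mean' divides it by the total number of elements (2 * 3 = 6). A minimal sketch reproducing the values above by hand:

import torch

x = torch.tensor([[0.1, 0.3, 0.6],
                  [0.2, 0.4, 0.4]])
y = torch.tensor([[0.3, 0.2, 0.5],
                  [0.2, 0.7, 0.1]])

# Element-wise KLDivLoss term: target * (log(target) - input)
pointwise = y * (torch.log(y) - x)
print(pointwise)                    # matches the reduction='none' output
print(pointwise.sum() / y.size(0))  # matches 'batchmean': sum / batch size
print(pointwise.mean())             # matches 'mean': sum / number of elements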

Conclusion: reduction='batchmean' gives the mathematically correct KL divergence (the sum over all elements divided by the batch size); reduction='mean' divides by the total number of elements instead and therefore does not match the KL definition.
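Also note that in real use the input to KLDivLoss must be log-probabilities (for example the output of log_softmax); the example above feeds raw probabilities only to compare the reduction modes. A minimal sketch of the usual usage, with hypothetical logits, cross-checked against the textbook definition of KL(target || prediction):

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 3)                 # hypothetical model outputs
target = torch.tensor([[0.3, 0.2, 0.5],
                       [0.2, 0.7, 0.1]])

log_probs = F.log_softmax(logits, dim=1)   # input must be log-probabilities
loss = nn.KLDivLoss(reduction='batchmean')(log_probs, target)

# Manual KL(target || pred) averaged over the batch
manual = (target * (target.log() - log_probs)).sum() / target.size(0)
print(loss, manual)                        # the two values agree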
