Triplet Loss的动机

一个好的特征提取器,应该尽可能的做到同类别样本映射出来的特征会聚集在一起,而不同类别的样本映射出来的特征应该要相互远离。

为了达到这个目标,Triplet Loss显式的在Loss里面要求:不同类别之间的距离至少要超过同类别之间距离的某个阈值。如果能够做到这一点,那么类内距和类间距之间差就有一个明显的鸿沟,那么也可以达到上面提到的目标。

Triplet Loss的定义

Triplet Loss里面包含若干三元组:

  • 锚点 anchor
  • 正例 positive
  • 负例 negative

要求:锚点和正例是处于相同的类别,锚点和负例处于不同的类别。

a,p,n都不是原始样本,而是原始样本被神经网络做特征提取后的得到的特征向量。即:a=f(xa),b=f(xb),c=f(xc)a=f(x_a), b=f(x_b), c=f(x_c)a=f(xa),b=f(xb),c=f(xc)f(⋅)f(·)f()是神经网络特征提取器。

对于一个三元组triplet (a,p,n),它的triplet loss写作:
L=max(d(a,p)−d(a,n)+margin,0)L=max(d(a,p)- d(a,n)+margin, 0)L=max(d(a,p)d(a,n)+margin,0),其中d(x,y)d(x,y)d(x,y) 是自定义的距离函数。

这个东西写的还是很直观的,它想表达的意思为:

  1. 如果 d(a,p)−d(a,n)+margin>0d(a,p)- d(a,n)+margin>0d(a,p)d(a,n)+margin>0,那么loss就是 d(a,p)−d(a,n)+margind(a,p)- d(a,n)+margind(a,p)d(a,n)+margin,否则就是0。0的时候就没有产生实际loss,就不会有梯度,意味着模型无需优化。
  2. d(a,p)−d(a,n)+margin>0d(a,p)- d(a,n)+margin>0d(a,p)d(a,n)+margin>0,有 d(a,n)−d(a,p)<margind(a,n) - d(a,p) < margind(a,n)d(a,p)<margin,此时锚点和负例之间的距离和锚点与正例之间的距离之差还没有超过阈值,于是就要会产生LOSS。
  3. 又因为优化的目标是让loss越小越好,于是模型就会千方百计的优化fff,使得 d(a,p)−d(a,n)+margind(a,p)−d(a,n)+margind(a,p)d(a,n)+margin 越小越好,直到d(a,p)−d(a,n)+margind(a,p)−d(a,n)+margind(a,p)d(a,n)+margin 小于等于0,就不优化了。

Triplet Loss在pytorch里面的实现

__author__ = 'dk'
import torch
import torch as th
from torch.nn import functional as F
from torch import nn

class TripletLoss(nn.Module):
    """Triplet loss with hard positive/negative mining.

    Reference:
        Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.

    Imported from `<https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py>`_.

    Args:
        margin (float, optional): margin for triplet. Default is 0.3.
    """

    def __init__(self, margin=0.3, batch_size=128, view_num=3):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, inputs,targets = None):
        """
        Args:
            inputs (torch.Tensor): feature matrix with shape (batch_size, feat_dim).
            targets (torch.LongTensor): ground truth labels with shape (num_classes).
        """
        if targets == None:
            targets = self.targets
        n = inputs.size(0)

        # Compute pairwise distance, replace by the official when merged
        dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = dist + dist.t()
        dist.addmm_(1, -2, inputs, inputs.t())
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability

        # For each anchor, find the hardest positive and negative
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())
        dist_ap, dist_an = [], []
        for i in range(n):
            dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
            dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
        dist_ap = torch.cat(dist_ap)
        dist_an = torch.cat(dist_an)

        # Compute ranking hinge loss
        y = torch.ones_like(dist_an)
        return self.ranking_loss(dist_an, dist_ap, y)

这是我在网上随便找到的衣服triplet loss实现,这个是基于l2范数实现的,写的比较隐蔽。
这个triplet loss是接受n个样本的特征作为inputs,然后返回最难优化的hard example。

我们解析这几句关键的,为了方便起见,记Inputs为VVV,里面的第i样本为 viv_ivi

        # Compute pairwise distance, replace by the official when merged
        dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)

计算输入的各个样本自己的l2范数。

        dist = dist + dist.t()

将每个样本自己的l2范数加上别人的l2范数。此时dist是个nxn的矩阵,dist[i,j]表示第i样本与第j样本的l2范数之和,也就是:dist[i,j]=vi2+vj2dist[i,j]=v^2_i + v^2_jdist[i,j]=vi2+vj2

        dist.addmm_(1, -2, inputs, inputs.t())

这个式子展开就是:A=dist−2×VVTA=dist -2 \times VV^TA=dist2×VVT
于是A[i,j]=vi2+vj2−2vivj=(vi−vj)2A[i,j]=v^2_i+v^2_j-2v_iv_j=(v_i-v_j)^2A[i,j]=vi2+vj22vivj=(vivj)2,妙啊。这就是等价于想把每个样本相互减,然后计算差值向量的l2范数。

        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability

上面这句话就是在开方,这没有啥好说的。

        # For each anchor, find the hardest positive and negative
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())

上面这句话的作用是:mask[i,j]mask[i,j]mask[i,j]表示i个样本和第j个样本的label是否相同,相同为True,不同为false, 注意这是一个bool的tensor矩阵。

        dist_ap, dist_an = [], []
        for i in range(n):
            dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
            dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))

dist[i][mask[i]] 表示把与第i样本具有相同标签那些样本的dist拿出来,max()是取最大的那个。
dist[i][mask[i] == 0].min()表示把与第i个样本不同标签的样本的dist拿出来,取距离最小的那个。

于是dist_ap保存了那些距离各自锚点最远的同类标签。
dist_an保存距离各自锚点最近的异类标签。这类样本都是难分的。这其实是个hardest triplet loss。

        # Compute ranking hinge loss
        y = torch.ones_like(dist_an)
        return self.ranking_loss(dist_an, dist_ap, y)

这个ranking_loss的公式为 loss(x1,x2,y)=max⁡(0,−y∗(x1−x2)+margin)\text{loss}(x1, x2, y) = \max(0, -y * (x1 - x2) + \text{margin})loss(x1,x2,y)=max(0,y(x1x2)+margin)
代入进去就是:loss(dist_an,dist_ap,1)=max⁡(0,dist_ap−dist_an+margin)\text{loss}(dist\_{an}, dist\_{ap}, \bold{1}) = \max(0, dist\_ap-dist\_an + \text{margin})loss(dist_an,dist_ap,1)=max(0,dist_apdist_an+margin)

更通用的实现

那如果我们想实现一个可以选择LpL_pLp 范数也不是硬编码为2范数的triplet loss如何实现呢?

__author__ = 'dk'
import torch
import torch as th
from torch.nn import functional as F
from torch import nn

class TripletLoss(nn.Module):
    """Triplet loss with hard positive/negative mining.

    Reference:
        Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.

    Imported from `<https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py>`_.

    Args:
        margin (float, optional): margin for triplet. Default is 0.3.
    """

    def __init__(self, margin=0.3, batch_size=128, view_num=3, p=2):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.p = p
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)
        self.targets =  torch.cat([torch.arange(batch_size) for i in range(view_num)], dim=0)

    def forward(self, inputs,targets = None):
        """
        Args:
            inputs (torch.Tensor): feature matrix with shape (batch_size, feat_dim).
            targets (torch.LongTensor): ground truth labels with shape (num_classes).
        """
        if targets == None:
            targets = self.targets
        n = inputs.size(0)

        # Compute pairwise distance, replace by the official when merged
        dist = []
        for i in range(n):
            dist.append(inputs[i] - inputs)
        dist = torch.stack(dist)
        dist = torch.linalg.norm(dist,ord=self.p,dim=2)

        # For each anchor, find the hardest positive and negative
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())
        dist_ap, dist_an = [], []
        for i in range(n):
            dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
            dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
        dist_ap = torch.cat(dist_ap)
        dist_an = torch.cat(dist_an)

        # Compute ranking hinge loss
        y = torch.ones_like(dist_an)
        return self.ranking_loss(dist_an, dist_ap, y)

两种方法用时测试:


if __name__ == '__main__':
    import time,tqdm
    inputs = torch.randn(128,256)
    step = 1000
    tripletloss1 = TripletLoss(batch_size=64,view_num=2)
    loss1 = 0
    from SSL.triplet_loss import TripletLoss as TripletLoss2
    tripletloss2 = TripletLoss2(batch_size=64,view_num=2)
    loss2 = 0
    s = time.time()
    for i in tqdm.trange(step):
        loss1 += tripletloss1(inputs)
    e1=time.time()
    for i in tqdm.trange(step):
        loss2 += tripletloss2(inputs)
    e2 =time.time()
    print('1: {0}s, result:{1}'.format(e1-s, loss1/ step))
    print('1: {0}s, result:{1}'.format(e2-e1, loss2/ step))

tripletloss1就是网上的方法,tripletloss2是我们实现的。结果:

方法1: 52.5317645072937s, result:2.367180824279785
方法2: 15.61170482635498s, result:2.367180824279785

可以发现我们的方法更快,究其原因,方法1里面矩阵的乘法及其耗时。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐