论文网址:ICLR Poster Rethinking Classifier Re-Training in Long-Tailed Recognition: Label Over-Smooth Can Balance

目录

1. 心得

2. 论文逐段精读

2.1. Abstract

2.2. Introduction

2.3. Related Work

2.4. Rethinking Classifier Re-Training

2.4.1. Revisiting Classifier Re-Training methods

2.4.2. Logits Magnitude

2.5. Method

2.5.1. Deep Dive into Logits Magnitude

2.5.2. Logits Retargeting

2.6. Experiments

2.6.1. Datasets and Implementation Details

2.6.2. Benchmark Results

2.6.3. Ablation Study

2.7. Conclusion

1. 心得

(1)如果没有数学基础要看很久,比较硬货

2. 论文逐段精读

2.1. Abstract

        ①先前的长尾工作将训练和分类分开

2.2. Introduction

        ①对于长尾问题,通常使用解耦方法。即先学习表征,再重学习分类

        ②VanillaCross-Entropy(CE)和作者提出的logits retargeting approach (LORT)在CIFAR 100-LT对于不同基数类别的正负例logits分布:

discernibility  n.辨别能力;分辨率;鉴别力

2.3. Related Work

        ①现有的长尾解决办法:重采样、损失加权、表征学习、分类器设计、解耦训练、数据增强

        ②举例一些解耦办法

2.4. Rethinking Classifier Re-Training

2.4.1. Revisiting Classifier Re-Training methods

        ①使用LTWB作为主干,在训练分类器的时候冻结主干

        ②对于K个类别,第i个类别有n_i个样本

        ③图像\mathbf{x}加工后的特征表示为f\left ( \mathbf{x} \right )

        ④分类器权重:\mathbf{W} = \left[ \mathbf{w}_1, \dots, \mathbf{w}_K \right]

        ⑤联合公式:

\mathcal{L}(\mathbf{W}; f(\mathbf{x}), y) = -r_w(y) \cdot \log\left( \frac{e^{z_y}}{\sum_i e^{z_i}} \right),

其中z_i = g_i\left( \mathbf{w}_i, f(\mathbf{x}) \right)是分类器计算出的logits,r_w(y)表示当前标签的重分配权重因子

        ⑥CIFAR100-LT数据集下不同分类方法的测试:

2.4.2. Logits Magnitude

        ①先把权重r_w(y)换为独热编码:

\mathcal{L}(\mathbf{W}, \mathbf{b}; f(\mathbf{x}), y) = \sum_i -y_i \cdot \log\left( \frac{e^{z_i}}{\sum_j e^{z_j}} \right)

然后让\mathbf{z} = \mathbf{W}^\top f(\mathbf{x}) + \mathbf{b}\mathbf{s}设为最终预测概率s_i = \frac{e^{z_i}}{\sum_j e^{z_j}}

        ②计算损失关于偏置\mathbf{b}的Hessian矩阵,并记为\mathbf{H}

\mathbf{H}_{ij} = \begin{cases} s_i(1 - s_i) & \text{if } i = j \\ -s_i s_j & \text{if } i \neq j \end{cases}

其中\mathbf{H}是个半正定矩阵,因为对于任意\mathbf{x}都有\mathbf{x}\mathbf{H}\mathbf{x}^\top \geq 0

        ③设正样本和负样本的logits分别为\mathbf{z}_{Pi}\mathbf{z}_{Ni},定义Logits Magnitude\mathbf{L}\in \mathbb{R}^{K}为正样本和负样本在样本i上的均值:

L_{\mathbf{i}} = \mathbb{E}\left[ z_{P_{\mathbf{i}}} \right] - \mathbb{E}\left[ z_{N_{\mathbf{i}}} \right]

        ④多数类样本和少数类样本的Logits均值和方差:

        ⑤当参数\mathbf{b}\mathbf{w}优化最优\mathbf{b}^*\mathbf{w}^*时,可以有一系列的收敛点\left( \mathbf{W}', \mathbf{b}^* \right),让\mathbf{W}'_i = \mathbf{W}^*_i + \varepsilon(这个替代需要证明,就是哪怕是\mathbf{W}'_i = \mathbf{W}^*_i + \varepsilon,求出来最后的s也是一样的):

s'_i = \frac{e^{z'_i}}{\sum_j e^{z'_j}} = \frac{1}{\sum_j e^{\mathbf{W}_j f(\mathbf{x}) + b^*_j - \mathbf{W}_i f(\mathbf{x}) - b^*_i}} = \frac{1}{\sum_j e^{z_j - z_i}} =s_i,

        ⑥将随机量\varepsilon的期望设置为0\mathbb{E}\left[ \varepsilon \right] = 0

\mathbb{E}\left[ \|\mathbf{W}'_i\|_2^2 \right] = \|\mathbf{W}^*_i\|_2^2 + \mathbb{E}\left[ \|\varepsilon\|_2^2 \right]\\ \mathbb{E}\left[ L'_{\mathbf{i}} \right] = L^*_{\mathbf{i}}

其中\|\mathbf{W}^*_i\|_2^2是确定数字所以期望等于自己,然后说明随意值\mathbf{W}'_i范数期望可以比原始值\mathbf{W}^*_i更大。

下面那个期望是通过:

\mathbf{L}_i = \mathbb{E}\left[ z_{P,i} \right] - \mathbb{E}\left[ z_{N,i} \right]

z'_{P,i} = z^*_{P,i} + c, \quad z'_{N,i} = z^*_{N,i} + c

\mathbf{L}'_i = \left( z^*_{P,i} + c \right) - \left( z^*_{N,i} + c \right) = z^*_{P,i} - z^*_{N,i} = \mathbf{L}^*_i

得到的

因此作者觉得通过Weight Norm改变模长可能是无效的

        ⑦模型对头类、中类、尾类的判别能力(Logits Magnitude)越平衡,整体准确率越高。(根据图1来看,我认为作者言下之意是对于每个大类的类,σ越小越好)

        ⑧logits magnitude对标准差的正则化:

\mathbf{r}_i = \frac{\sigma(z_i)}{\mathbf{L}_i} \in \mathbb{R}^{K}

2.5. Method

2.5.1. Deep Dive into Logits Magnitude

        ①作者发现每个类都得到了类似的logits magnitude:

L_1 \approx L_2 \approx \dots \approx L_M

        ②相对类别logits波动性:

\sigma(z_i) = \mathbf{r}_i L_i

        ③考虑一个独立影响每个类\mathbf{z}_i的类随机扰动\Delta _i,这个扰动和logits的标准差有关,如z'_i = z_i + \Delta_i\Delta_i \sim \xi_{\mathbf{r}, L_i},其中\xi是随机变量且期望为0

        ④扰动下概率s'_i的期望:

\mathbb{E}\left[ s'_i \right] = \mathbb{E}\left[ \frac{1}{\sum_j e^{z'_j - z'_i}} \right] = \mathbb{E}\left[ \frac{1}{\sum_j e^{(\Delta_j - \Delta_i)} e^{(z_j - z_i)}} \right]

在类别平衡的数据集里,正则化的标准差\mathbf{r}_1 \approx \mathbf{r}_2 \approx \dots \approx \mathbf{r}_M都几乎相似,且不同类别不正则化的标准差(\mathbf{r}_j L_j - \mathbf{r}_i L_i)接近0

        ⑤去掉了扰动的求和,将扰动单独提出来

\mathbb{E}\left[ s'_i \right] \approx \frac{1}{\mathbb{E}\left[ \exp(\Delta_j - \Delta_i) \right]} \mathbb{E}\left[ \frac{1}{\sum_j e^{z_j - z_i}} \right] = \frac{\mathbb{E}\left[ s_i \right]}{\mathbb{E}\left[ \exp(\Delta_j - \Delta_i) \right]}

其中\mathbb{E}\left[ \exp(\Delta_j - \Delta_i) \right]可以被视为一个常数\mathcal{E},对于均衡的数据集,扰动对所有类别的影响恒定:

\mathbb{E}\left[ s'_i \right] / \mathbb{E}\left[ s_i \right] \approx 1/\xi

对于类别不均衡的数据集,每个类j\exp(\Delta_j - \Delta_i)可能完全不一样。如果r_j < r_k(类k的正则化标准差大),\mathbb{E}\left[ \exp(\Delta_j - \Delta_i) \right] < \mathbb{E}\left[ \exp(\Delta_k - \Delta_i) \right](类k的扰动更大)。

        ⑥当\xi遵从正态分布,有:

\mathbb{E}\left[ \exp(\Delta_j - \Delta_i) \right] = \exp(\left( \sigma(\Delta_j - \Delta_i))^2 / 2 \right)

        ⑦作者需要改变logits分布,所以提出了logits retargeting approach (LORT)

2.5.2. Logits Retargeting

        ①LORT:

\mathcal{L}(\mathbf{w}, \mathbf{b}; f(\mathbf{x}), \mathbf{y}) = \sum_{i=1}^K -\hat{\mathbf{y}}_i \cdot \log\left( \frac{e^{z_i}}{\sum_j e^{z_j}} \right)

\hat{\mathbf{y}}_i = \begin{cases} 1-\delta + \delta/K & \text{if } i = y \\ \delta/K & \text{if } i \neq y \end{cases} \quad \text{and} \quad z_i = \mathbf{w}_i^\top f(\mathbf{x}) + b

其中\delta \in [0,1)是控制负类概率的常数,可以被称为标签平滑超参数

作者设计的LORT相当于一个标签概率平滑,如\delta=0时:

\hat{\mathbf{y}}_i = \begin{cases} 1 & \text{if } i = y \\ 0 & \text{if } i \neq y \end{cases}

其实就是正常的交叉熵;

\delta=0.2时:

\hat{\mathbf{y}}_i = \begin{cases} 1-0.2 + 0.2/K & \text{if } i = y \\ 0.2/K & \text{if } i \neq y \end{cases}

上面i=y的真标签就给自己分配稍低一点的概率,给所有别的不属于这个标签的类加一点概率;

如作者认为这个\delta可以大一点比如\delta=0.99

\hat{\mathbf{y}}_i = \begin{cases} 1-0.9 + 0.99/K & \text{if } i = y \\ 0.99/K & \text{if } i \neq y \end{cases}

这样真实类别就会为自己分配极少的概率

2.6. Experiments

       ①长尾图像分类数据集:CIFAR100-LT,ImageNet-LT,iNaturalist2018

       ②设备:GeForce RTX 3090 (24GB)和A100-PCIE (40GB)

2.6.1. Datasets and Implementation Details

        ①imbalanced ratio (IR)计算公式:

IR=n_{max}/n_{min}

其中n都只算训练集的

2.6.2. Benchmark Results

        ①实验:

        ②不同类别的Logit Magnitude和归一化标准导数:

2.6.3. Ablation Study

        ①标签平滑值实验:

        ②在不同学习率和weight decay上的不敏感性:

        ③对不同主干的提升:

reassuringly  adv.令人放心的;安慰地;鼓励地

malleable  adj.可塑的;有延展性的;可锻造的;易受影响(或改变)的;易成型的;可轧压的

2.7. Conclusion

        ~

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐