Gumbel-Softmax 是一种用于离散变量的连续近似方法。它可以将离散的分类变量(如 one-hot 编码)转换为连续的概率分布,同时保留可微性,适用于神经网络的反向传播。

函数签名

torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, dim=-1)

参数说明

  1. logits:

    • 类型:Tensor
    • 描述:输入的未归一化分数(logits),通常是模型的输出。
    • 形状:任意形状的张量,通常为 [batch_size, num_classes]
  2. tau:

    • 类型:float
    • 描述:温度参数,控制分布的平滑程度。
      • 较大的 tau 会使分布更平滑,类别之间的概率差异较小。
      • 较小的 tau 会使分布更尖锐,接近 one-hot 编码。
    • 默认值:1.0
  3. hard:

    • 类型:bool
    • 描述:是否返回离散的 one-hot 编码。
      • 如果为 True,输出是离散的 one-hot 编码,但梯度仍然通过连续的 softmax 输出计算(直通估计器)。
      • 如果为 False,输出是连续的概率分布。
    • 默认值:False
  4. dim:

    • 类型:int
    • 描述:指定 softmax 操作的维度,通常是类别所在的维度。
    • 默认值:-1(最后一个维度)。

返回值

  • 返回一个与 logits 形状相同的张量。
    • 如果 hard=False,输出是连续的概率分布。
    • 如果 hard=True,输出是离散的 one-hot 编码。

使用示例

1. 基本用法
import torch
import torch.nn.functional as F

# 输入 logits
logits = torch.tensor([[2.0, 0.5, 0.1], [0.3, 1.2, 0.8]])

# 使用 Gumbel-Softmax
output = F.gumbel_softmax(logits, tau=1, hard=False, dim=-1)

print("连续采样结果:")
print(output)
2. 使用不同的温度参数
# 高温度(更平滑的分布)
output_high_tau = F.gumbel_softmax(logits, tau=5, hard=False, dim=-1)

# 低温度(更接近 one-hot 的分布)
output_low_tau = F.gumbel_softmax(logits, tau=0.5, hard=False, dim=-1)

print("高温度采样结果:")
print(output_high_tau)

print("低温度采样结果:")
print(output_low_tau)
3. 离散采样(hard=True)
# 离散采样(输出 one-hot 编码)
output_hard = F.gumbel_softmax(logits, tau=1, hard=True, dim=-1)

print("离散采样结果(one-hot 编码):")
print(output_hard)

Gumbel-Softmax 的工作原理

  1. Gumbel 噪声

    • 为了模拟离散采样的随机性,Gumbel-Softmax 在 logits 上添加了 Gumbel 噪声。
    • Gumbel 噪声的公式为:
      g = − log ⁡ ( − log ⁡ ( U ) ) g = -\log(-\log(U)) g=log(log(U))
      其中 ( U ) 是从均匀分布 ( U(0, 1) ) 中采样的随机数。
  2. Softmax 转换

    • 添加 Gumbel 噪声后,通过 softmax 函数将 logits 转换为概率分布:
      y i = exp ⁡ ( ( log ⁡ ( π i ) + g i ) / τ ) ∑ j exp ⁡ ( ( log ⁡ ( π j ) + g j ) / τ ) y_i = \frac{\exp((\log(\pi_i) + g_i) / \tau)}{\sum_j \exp((\log(\pi_j) + g_j) / \tau)} yi=jexp((log(πj)+gj)/τ)exp((log(πi)+gi)/τ)
  3. 温度参数 ( \tau )

    • 控制分布的平滑程度:
      • 当 ( ( τ ) ( \tau) (τ) → \to 0 ) 时,分布接近离散的 one-hot 编码。
      • 当 ( ( τ ) ( \tau) (τ) → \to ∞ \infty ) 时,分布趋于均匀。
  4. 直通估计器(Straight-Through Estimator)

    • hard=True 时,前向传播中输出离散的 one-hot 编码,但反向传播时仍然使用连续的 softmax 输出计算梯度。

注意事项

  1. 梯度连续性

    • Gumbel-Softmax 的输出是连续的概率分布,因此可以通过 PyTorch 的自动微分机制计算梯度。
    • 即使 hard=True,梯度仍然通过连续的 softmax 输出计算。
  2. 温度选择

    • 温度 ( τ ) ( \tau) (τ) 是一个超参数,需要根据任务进行调整。
    • 较小的 ( τ ) ( \tau) (τ) 会导致分布更尖锐,但可能导致梯度消失问题。
  3. 随机性

    • 每次调用 Gumbel-Softmax 都会引入随机性(来自 Gumbel 噪声),因此结果可能不同。

应用场景

  1. 生成模型
    • 在变分自编码器(VAE)中,用于对离散变量进行连续化近似。
  2. 强化学习
    • 用于对策略分布进行连续采样。
  3. 分类任务
    • 在需要对分类变量进行连续化处理的场景中使用。

参考文档

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐