【学习笔记】Gumbel softmax 学习文档
Gumbel-Softmax 是一种用于离散变量的连续近似方法。它可以将离散的分类变量(如 one-hot 编码)转换为连续的概率分布,同时保留可微性,适用于神经网络的反向传播。
·
Gumbel-Softmax 是一种用于离散变量的连续近似方法。它可以将离散的分类变量(如 one-hot 编码)转换为连续的概率分布,同时保留可微性,适用于神经网络的反向传播。
函数签名
torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, dim=-1)
参数说明
-
logits:- 类型:
Tensor - 描述:输入的未归一化分数(logits),通常是模型的输出。
- 形状:任意形状的张量,通常为
[batch_size, num_classes]。
- 类型:
-
tau:- 类型:
float - 描述:温度参数,控制分布的平滑程度。
- 较大的
tau会使分布更平滑,类别之间的概率差异较小。 - 较小的
tau会使分布更尖锐,接近 one-hot 编码。
- 较大的
- 默认值:
1.0
- 类型:
-
hard:- 类型:
bool - 描述:是否返回离散的 one-hot 编码。
- 如果为
True,输出是离散的 one-hot 编码,但梯度仍然通过连续的 softmax 输出计算(直通估计器)。 - 如果为
False,输出是连续的概率分布。
- 如果为
- 默认值:
False
- 类型:
-
dim:- 类型:
int - 描述:指定 softmax 操作的维度,通常是类别所在的维度。
- 默认值:
-1(最后一个维度)。
- 类型:
返回值
- 返回一个与
logits形状相同的张量。- 如果
hard=False,输出是连续的概率分布。 - 如果
hard=True,输出是离散的 one-hot 编码。
- 如果
使用示例
1. 基本用法
import torch
import torch.nn.functional as F
# 输入 logits
logits = torch.tensor([[2.0, 0.5, 0.1], [0.3, 1.2, 0.8]])
# 使用 Gumbel-Softmax
output = F.gumbel_softmax(logits, tau=1, hard=False, dim=-1)
print("连续采样结果:")
print(output)
2. 使用不同的温度参数
# 高温度(更平滑的分布)
output_high_tau = F.gumbel_softmax(logits, tau=5, hard=False, dim=-1)
# 低温度(更接近 one-hot 的分布)
output_low_tau = F.gumbel_softmax(logits, tau=0.5, hard=False, dim=-1)
print("高温度采样结果:")
print(output_high_tau)
print("低温度采样结果:")
print(output_low_tau)
3. 离散采样(hard=True)
# 离散采样(输出 one-hot 编码)
output_hard = F.gumbel_softmax(logits, tau=1, hard=True, dim=-1)
print("离散采样结果(one-hot 编码):")
print(output_hard)
Gumbel-Softmax 的工作原理
-
Gumbel 噪声:
- 为了模拟离散采样的随机性,Gumbel-Softmax 在 logits 上添加了 Gumbel 噪声。
- Gumbel 噪声的公式为:
g = − log ( − log ( U ) ) g = -\log(-\log(U)) g=−log(−log(U))
其中 ( U ) 是从均匀分布 ( U(0, 1) ) 中采样的随机数。
-
Softmax 转换:
- 添加 Gumbel 噪声后,通过 softmax 函数将 logits 转换为概率分布:
y i = exp ( ( log ( π i ) + g i ) / τ ) ∑ j exp ( ( log ( π j ) + g j ) / τ ) y_i = \frac{\exp((\log(\pi_i) + g_i) / \tau)}{\sum_j \exp((\log(\pi_j) + g_j) / \tau)} yi=∑jexp((log(πj)+gj)/τ)exp((log(πi)+gi)/τ)
- 添加 Gumbel 噪声后,通过 softmax 函数将 logits 转换为概率分布:
-
温度参数 ( \tau ):
- 控制分布的平滑程度:
- 当 ( ( τ ) ( \tau) (τ) → \to → 0 ) 时,分布接近离散的 one-hot 编码。
- 当 ( ( τ ) ( \tau) (τ) → \to → ∞ \infty ∞ ) 时,分布趋于均匀。
- 控制分布的平滑程度:
-
直通估计器(Straight-Through Estimator):
- 当
hard=True时,前向传播中输出离散的 one-hot 编码,但反向传播时仍然使用连续的 softmax 输出计算梯度。
- 当
注意事项
-
梯度连续性:
- Gumbel-Softmax 的输出是连续的概率分布,因此可以通过 PyTorch 的自动微分机制计算梯度。
- 即使
hard=True,梯度仍然通过连续的 softmax 输出计算。
-
温度选择:
- 温度 ( τ ) ( \tau) (τ) 是一个超参数,需要根据任务进行调整。
- 较小的 ( τ ) ( \tau) (τ) 会导致分布更尖锐,但可能导致梯度消失问题。
-
随机性:
- 每次调用 Gumbel-Softmax 都会引入随机性(来自 Gumbel 噪声),因此结果可能不同。
应用场景
- 生成模型:
- 在变分自编码器(VAE)中,用于对离散变量进行连续化近似。
- 强化学习:
- 用于对策略分布进行连续采样。
- 分类任务:
- 在需要对分类变量进行连续化处理的场景中使用。
参考文档
更多推荐

所有评论(0)