变分自编码器(VAE)学习
VAE笔记
变分自编码器(VAE)学习
一、概率建模与数学推导
1. 隐变量模型与变分推断
VAE假设观测数据 x x x由隐变量 z z z生成,联合分布为:
p θ ( x , z ) = p θ ( x ∣ z ) p ( z ) p_\theta(x, z) = p_\theta(x|z)p(z) pθ(x,z)=pθ(x∣z)p(z)
其中 p ( z ) = N ( 0 , I ) p(z) = \mathcal{N}(0, I) p(z)=N(0,I)为隐变量先验分布, p θ ( x ∣ z ) p_\theta(x|z) pθ(x∣z)为解码器定义的条件分布。目标为最大化观测数据的对数似然:
log p θ ( x ) = log ∫ p θ ( x ∣ z ) p ( z ) d z \log p_\theta(x) = \log \int p_\theta(x|z)p(z)dz logpθ(x)=log∫pθ(x∣z)p(z)dz
由于积分难以直接计算,引入变分分布 q ϕ ( z ∣ x ) q_\phi(z|x) qϕ(z∣x)近似真实后验 p θ ( z ∣ x ) p_\theta(z|x) pθ(z∣x),通过优化变分下界(ELBO)替代:
log p θ ( x ) ≥ ELBO = E q ϕ ( z ∣ x ) [ log p θ ( x ∣ z ) ] − KL ( q ϕ ( z ∣ x ) ∥ p ( z ) ) \log p_\theta(x) \geq \text{ELBO} = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - \text{KL}(q_\phi(z|x) \| p(z)) logpθ(x)≥ELBO=Eqϕ(z∣x)[logpθ(x∣z)]−KL(qϕ(z∣x)∥p(z))
推导过程:
-
全概率公式展开:
log p ( x ) = log ∫ p ( x , z ) d z = log ∫ q ( z ∣ x ) p ( x , z ) q ( z ∣ x ) d z \log p(x) = \log \int p(x,z)dz = \log \int q(z|x) \frac{p(x,z)}{q(z|x)} dz logp(x)=log∫p(x,z)dz=log∫q(z∣x)q(z∣x)p(x,z)dz -
Jensen不等式应用:
对于凹函数 log ( ⋅ ) \log(\cdot) log(⋅),有:
log E q ( z ∣ x ) [ p ( x , z ) q ( z ∣ x ) ] ≥ E q ( z ∣ x ) [ log p ( x , z ) q ( z ∣ x ) ] \log \mathbb{E}_{q(z|x)} \left[ \frac{p(x,z)}{q(z|x)} \right] \geq \mathbb{E}_{q(z|x)} \left[ \log \frac{p(x,z)}{q(z|x)} \right] logEq(z∣x)[q(z∣x)p(x,z)]≥Eq(z∣x)[logq(z∣x)p(x,z)] -
分解联合分布:
p ( x , z ) q ( z ∣ x ) = p ( x ∣ z ) p ( z ) q ( z ∣ x ) ⟹ log p ( x , z ) q ( z ∣ x ) = log p ( x ∣ z ) + log p ( z ) q ( z ∣ x ) \frac{p(x,z)}{q(z|x)} = \frac{p(x|z)p(z)}{q(z|x)} \implies \log \frac{p(x,z)}{q(z|x)} = \log p(x|z) + \log \frac{p(z)}{q(z|x)} q(z∣x)p(x,z)=q(z∣x)p(x∣z)p(z)⟹logq(z∣x)p(x,z)=logp(x∣z)+logq(z∣x)p(z) -
分解ELBO:
ELBO = E q ( z ∣ x ) [ log p ( x ∣ z ) ] ⏟ 重构项 − KL ( q ( z ∣ x ) ∥ p ( z ) ) ⏟ 正则项 \text{ELBO} = \underbrace{\mathbb{E}_{q(z|x)}[\log p(x|z)]}_{\text{重构项}} - \underbrace{\text{KL}(q(z|x) \| p(z))}_{\text{正则项}} ELBO=重构项 Eq(z∣x)[logp(x∣z)]−正则项 KL(q(z∣x)∥p(z))
重构项要求生成数据接近输入,KL项约束潜在分布对齐先验。
2. KL散度的解析计算
假设 q ϕ ( z ∣ x ) = N ( μ , σ 2 I ) q_\phi(z|x) = \mathcal{N}(\mu, \sigma^2 I) qϕ(z∣x)=N(μ,σ2I),KL散度可解析为:
KL = 1 2 ∑ i = 1 d ( μ i 2 + σ i 2 − 1 − ln σ i 2 ) \text{KL} = \frac{1}{2} \sum_{i=1}^d \left( \mu_i^2 + \sigma_i^2 - 1 - \ln \sigma_i^2 \right) KL=21i=1∑d(μi2+σi2−1−lnσi2)
其中 d d d为潜在空间维度。推导基于高斯分布KL公式:
KL ( N ( μ , σ 2 ) ∥ N ( 0 , 1 ) ) = 1 2 ( μ 2 + σ 2 − ln σ 2 − 1 ) \text{KL}(\mathcal{N}(\mu, \sigma^2) \| \mathcal{N}(0,1)) = \frac{1}{2}(\mu^2 + \sigma^2 - \ln \sigma^2 -1) KL(N(μ,σ2)∥N(0,1))=21(μ2+σ2−lnσ2−1)
二、网络结构与实现细节
1. 编码器(Encoder)
• 输入:数据 x x x。
• 输出:潜在分布的均值 μ \mu μ和对数方差 log σ 2 \log \sigma^2 logσ2。
• 网络设计(简化):
class Encoder(nn.Module):
def __init__(self, input_dim, hidden_dim, latent_dim):
super().__init__()
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc_mu = nn.Linear(hidden_dim, latent_dim)
self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
def forward(self, x):
h = F.relu(self.fc1(x)) # 非线性激活
return self.fc_mu(h), self.fc_logvar(h)
技术细节:
• 使用ReLU激活增强非线性表达。
• 输出层无激活函数,允许 μ \mu μ和 log σ 2 \log \sigma^2 logσ2自由取值。
2. 解码器(Decoder)
• 输入:潜在变量 z z z(通过重参数化采样)。
• 输出:重构数据 x ^ \hat{x} x^。
• 网络设计(简化):
class Decoder(nn.Module):
def __init__(self, latent_dim, hidden_dim, output_dim):
super().__init__()
self.fc = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, output_dim),
nn.Sigmoid() # 输出归一化至[0,1]
)
def forward(self, z):
return self.fc(z)
激活函数选择:
• 二值数据使用Sigmoid,对应伯努利分布。
• 连续数据(如RGB图像)使用Tanh,对应高斯分布。
3. 重参数化技巧
将随机采样转化为确定性计算:
z = μ + ϵ ⊙ σ , ϵ ∼ N ( 0 , I ) z = \mu + \epsilon \odot \sigma, \quad \epsilon \sim \mathcal{N}(0, I) z=μ+ϵ⊙σ,ϵ∼N(0,I)
代码实现:
def reparameterize(mu, logvar):
std = torch.exp(0.5 * logvar) # 计算标准差
eps = torch.randn_like(std) # 从标准正态分布采样
return mu + eps * std
作用:分离随机性与确定性参数,使梯度可传至编码器。
三、损失函数与优化
- 总损失为负ELBO:
L = BCE ( x ^ , x ) + β ⋅ KL ( q ( z ∣ x ) ∥ p ( z ) ) \mathcal{L} = \text{BCE}(\hat{x}, x) + \beta \cdot \text{KL}(q(z|x) \| p(z)) L=BCE(x^,x)+β⋅KL(q(z∣x)∥p(z)) - 重构损失(BCE或MSE):
损失函数分解
伯努利分布假设(如MNIST的二值像素数据),使用二元交叉熵(BCE)计算重构损失;
高斯分布假设(如连续RGB图像),使用均方误差(MSE)作为重构损失;
BCE = − ∑ [ x ln x ^ + ( 1 − x ) ln ( 1 − x ^ ) ] \text{BCE} = -\sum \left[ x \ln \hat{x} + (1-x)\ln(1-\hat{x}) \right] BCE=−∑[xlnx^+(1−x)ln(1−x^)]
- KL散度项:
KL = − 1 2 ∑ ( 1 + log σ 2 − μ 2 − σ 2 ) \text{KL} = -\frac{1}{2} \sum (1 + \log \sigma^2 - \mu^2 - \sigma^2) KL=−21∑(1+logσ2−μ2−σ2)
代码实现:
def vae_loss(recon_x, x, mu, logvar, beta=1.0):
BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + beta * KLD
超参数调整:
• β \beta β-VAE( β \beta β>1)增强潜在空间解耦性, β \beta β<1提升生成质量。
四、扩展模型与前沿改进
1. 条件VAE(CVAE)
• 改进点:在编码器/解码器输入中拼接条件信息(如类别标签)。
• 数学形式:
p θ ( x ∣ z , y ) = Decoder ( z , y ) , q ϕ ( z ∣ x , y ) = Encoder ( x , y ) p_\theta(x|z,y) = \text{Decoder}(z,y), \quad q_\phi(z|x,y) = \text{Encoder}(x,y) pθ(x∣z,y)=Decoder(z,y),qϕ(z∣x,y)=Encoder(x,y)
• 应用:可控生成(如生成指定数字的手写体)。
2. VQ-VAE
• 核心思想:使用离散码本替代连续高斯分布。
• 损失函数:
L = 重构损失 + ∥ sg ( z e ) − e k ∥ 2 + β ∥ z e − sg ( e k ) ∥ 2 \mathcal{L} = \text{重构损失} + \| \text{sg}(z_e) - e_k \|^2 + \beta \| z_e - \text{sg}(e_k) \|^2 L=重构损失+∥sg(ze)−ek∥2+β∥ze−sg(ek)∥2
(sg表示停止梯度, e k e_k ek为码本向量)。
3. 层级VAE(HVAE)
• 结构:引入多层潜在变量 z 1 , z 2 , … , z L z_1, z_2, \dots, z_L z1,z2,…,zL,生成过程为马尔可夫链。
• ELBO扩展:
ELBO = ∑ l = 1 L E q ( z 1 : L ∣ x ) [ log p ( x ∣ z 1 : L ) ] − KL ( q ( z l ∣ x ) ∥ p ( z l ) ) \text{ELBO} = \sum_{l=1}^L \mathbb{E}_{q(z_{1:L}|x)}[\log p(x|z_{1:L})] - \text{KL}(q(z_l|x) \| p(z_l)) ELBO=l=1∑LEq(z1:L∣x)[logp(x∣z1:L)]−KL(q(zl∣x)∥p(zl))
用于高分辨率图像生成。
结论
VAE的原理与代码实现呈现“理论复杂、代码简洁”的鲜明对比。深度学习框架对数学细节的封装以及概率模型的模块化设计:
- 代码可见:https://github.com/AntixK/PyTorch-VAE/
- 变分下界ELBO的构建: M S E + K L MSE+KL MSE+KL
- 通过网络层输出均值 μ \mu μ和对数方差 log σ 2 \log \sigma^2 logσ2
更多推荐
所有评论(0)