变分自编码器(VAE)学习

一、概率建模与数学推导

1. 隐变量模型与变分推断

VAE假设观测数据 x x x由隐变量 z z z生成,联合分布为:
p θ ( x , z ) = p θ ( x ∣ z ) p ( z ) p_\theta(x, z) = p_\theta(x|z)p(z) pθ(x,z)=pθ(xz)p(z)
其中 p ( z ) = N ( 0 , I ) p(z) = \mathcal{N}(0, I) p(z)=N(0,I)为隐变量先验分布, p θ ( x ∣ z ) p_\theta(x|z) pθ(xz)为解码器定义的条件分布。目标为最大化观测数据的对数似然:
log ⁡ p θ ( x ) = log ⁡ ∫ p θ ( x ∣ z ) p ( z ) d z \log p_\theta(x) = \log \int p_\theta(x|z)p(z)dz logpθ(x)=logpθ(xz)p(z)dz
由于积分难以直接计算,引入变分分布 q ϕ ( z ∣ x ) q_\phi(z|x) qϕ(zx)近似真实后验 p θ ( z ∣ x ) p_\theta(z|x) pθ(zx),通过优化变分下界(ELBO)替代:
log ⁡ p θ ( x ) ≥ ELBO = E q ϕ ( z ∣ x ) [ log ⁡ p θ ( x ∣ z ) ] − KL ( q ϕ ( z ∣ x ) ∥ p ( z ) ) \log p_\theta(x) \geq \text{ELBO} = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - \text{KL}(q_\phi(z|x) \| p(z)) logpθ(x)ELBO=Eqϕ(zx)[logpθ(xz)]KL(qϕ(zx)p(z))
推导过程:

  1. 全概率公式展开:
    log ⁡ p ( x ) = log ⁡ ∫ p ( x , z ) d z = log ⁡ ∫ q ( z ∣ x ) p ( x , z ) q ( z ∣ x ) d z \log p(x) = \log \int p(x,z)dz = \log \int q(z|x) \frac{p(x,z)}{q(z|x)} dz logp(x)=logp(x,z)dz=logq(zx)q(zx)p(x,z)dz

  2. Jensen不等式应用:
    对于凹函数 log ⁡ ( ⋅ ) \log(\cdot) log(),有:
    log ⁡ E q ( z ∣ x ) [ p ( x , z ) q ( z ∣ x ) ] ≥ E q ( z ∣ x ) [ log ⁡ p ( x , z ) q ( z ∣ x ) ] \log \mathbb{E}_{q(z|x)} \left[ \frac{p(x,z)}{q(z|x)} \right] \geq \mathbb{E}_{q(z|x)} \left[ \log \frac{p(x,z)}{q(z|x)} \right] logEq(zx)[q(zx)p(x,z)]Eq(zx)[logq(zx)p(x,z)]

  3. 分解联合分布:
    p ( x , z ) q ( z ∣ x ) = p ( x ∣ z ) p ( z ) q ( z ∣ x )    ⟹    log ⁡ p ( x , z ) q ( z ∣ x ) = log ⁡ p ( x ∣ z ) + log ⁡ p ( z ) q ( z ∣ x ) \frac{p(x,z)}{q(z|x)} = \frac{p(x|z)p(z)}{q(z|x)} \implies \log \frac{p(x,z)}{q(z|x)} = \log p(x|z) + \log \frac{p(z)}{q(z|x)} q(zx)p(x,z)=q(zx)p(xz)p(z)logq(zx)p(x,z)=logp(xz)+logq(zx)p(z)

  4. 分解ELBO:
    ELBO = E q ( z ∣ x ) [ log ⁡ p ( x ∣ z ) ] ⏟ 重构项 − KL ( q ( z ∣ x ) ∥ p ( z ) ) ⏟ 正则项 \text{ELBO} = \underbrace{\mathbb{E}_{q(z|x)}[\log p(x|z)]}_{\text{重构项}} - \underbrace{\text{KL}(q(z|x) \| p(z))}_{\text{正则项}} ELBO=重构项 Eq(zx)[logp(xz)]正则项 KL(q(zx)p(z))
    重构项要求生成数据接近输入,KL项约束潜在分布对齐先验。

2. KL散度的解析计算

假设 q ϕ ( z ∣ x ) = N ( μ , σ 2 I ) q_\phi(z|x) = \mathcal{N}(\mu, \sigma^2 I) qϕ(zx)=N(μ,σ2I),KL散度可解析为:
KL = 1 2 ∑ i = 1 d ( μ i 2 + σ i 2 − 1 − ln ⁡ σ i 2 ) \text{KL} = \frac{1}{2} \sum_{i=1}^d \left( \mu_i^2 + \sigma_i^2 - 1 - \ln \sigma_i^2 \right) KL=21i=1d(μi2+σi21lnσi2)
其中 d d d为潜在空间维度。推导基于高斯分布KL公式:
KL ( N ( μ , σ 2 ) ∥ N ( 0 , 1 ) ) = 1 2 ( μ 2 + σ 2 − ln ⁡ σ 2 − 1 ) \text{KL}(\mathcal{N}(\mu, \sigma^2) \| \mathcal{N}(0,1)) = \frac{1}{2}(\mu^2 + \sigma^2 - \ln \sigma^2 -1) KL(N(μ,σ2)N(0,1))=21(μ2+σ2lnσ21)


二、网络结构与实现细节

1. 编码器(Encoder)

• 输入:数据 x x x

• 输出:潜在分布的均值 μ \mu μ和对数方差 log ⁡ σ 2 \log \sigma^2 logσ2

• 网络设计(简化):

class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
    
    def forward(self, x):
        h = F.relu(self.fc1(x))  # 非线性激活
        return self.fc_mu(h), self.fc_logvar(h)

技术细节:

• 使用ReLU激活增强非线性表达。

• 输出层无激活函数,允许 μ \mu μ log ⁡ σ 2 \log \sigma^2 logσ2自由取值。

2. 解码器(Decoder)

• 输入:潜在变量 z z z(通过重参数化采样)。

• 输出:重构数据 x ^ \hat{x} x^

• 网络设计(简化):

class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Sigmoid()  # 输出归一化至[0,1]
        )
    
    def forward(self, z):
        return self.fc(z)

激活函数选择:

• 二值数据使用Sigmoid,对应伯努利分布。

• 连续数据(如RGB图像)使用Tanh,对应高斯分布。

3. 重参数化技巧

将随机采样转化为确定性计算:
z = μ + ϵ ⊙ σ , ϵ ∼ N ( 0 , I ) z = \mu + \epsilon \odot \sigma, \quad \epsilon \sim \mathcal{N}(0, I) z=μ+ϵσ,ϵN(0,I)
代码实现:

def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)  # 计算标准差
    eps = torch.randn_like(std)     # 从标准正态分布采样
    return mu + eps * std

作用:分离随机性与确定性参数,使梯度可传至编码器。


三、损失函数与优化

  • 总损失为负ELBO:
    L = BCE ( x ^ , x ) + β ⋅ KL ( q ( z ∣ x ) ∥ p ( z ) ) \mathcal{L} = \text{BCE}(\hat{x}, x) + \beta \cdot \text{KL}(q(z|x) \| p(z)) L=BCE(x^,x)+βKL(q(zx)p(z))
  • 重构损失(BCE或MSE):

损失函数分解
​​伯努利分布假设​​(如MNIST的二值像素数据),使用​​二元交叉熵(BCE)​​计算重构损失;
​​高斯分布假设​​(如连续RGB图像),使用​​均方误差(MSE)​​作为重构损失;

BCE = − ∑ [ x ln ⁡ x ^ + ( 1 − x ) ln ⁡ ( 1 − x ^ ) ] \text{BCE} = -\sum \left[ x \ln \hat{x} + (1-x)\ln(1-\hat{x}) \right] BCE=[xlnx^+(1x)ln(1x^)]

  • KL散度项:

KL = − 1 2 ∑ ( 1 + log ⁡ σ 2 − μ 2 − σ 2 ) \text{KL} = -\frac{1}{2} \sum (1 + \log \sigma^2 - \mu^2 - \sigma^2) KL=21(1+logσ2μ2σ2)
代码实现:

def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + beta * KLD

超参数调整:
β \beta β-VAE( β \beta β>1)增强潜在空间解耦性, β \beta β<1提升生成质量。


四、扩展模型与前沿改进

1. 条件VAE(CVAE)

• 改进点:在编码器/解码器输入中拼接条件信息(如类别标签)。

• 数学形式:

p θ ( x ∣ z , y ) = Decoder ( z , y ) , q ϕ ( z ∣ x , y ) = Encoder ( x , y ) p_\theta(x|z,y) = \text{Decoder}(z,y), \quad q_\phi(z|x,y) = \text{Encoder}(x,y) pθ(xz,y)=Decoder(z,y),qϕ(zx,y)=Encoder(x,y)
• 应用:可控生成(如生成指定数字的手写体)。

2. VQ-VAE

• 核心思想:使用离散码本替代连续高斯分布。

• 损失函数:

L = 重构损失 + ∥ sg ( z e ) − e k ∥ 2 + β ∥ z e − sg ( e k ) ∥ 2 \mathcal{L} = \text{重构损失} + \| \text{sg}(z_e) - e_k \|^2 + \beta \| z_e - \text{sg}(e_k) \|^2 L=重构损失+sg(ze)ek2+βzesg(ek)2
(sg表示停止梯度, e k e_k ek为码本向量)。

3. 层级VAE(HVAE)

• 结构:引入多层潜在变量 z 1 , z 2 , … , z L z_1, z_2, \dots, z_L z1,z2,,zL,生成过程为马尔可夫链。

• ELBO扩展:

ELBO = ∑ l = 1 L E q ( z 1 : L ∣ x ) [ log ⁡ p ( x ∣ z 1 : L ) ] − KL ( q ( z l ∣ x ) ∥ p ( z l ) ) \text{ELBO} = \sum_{l=1}^L \mathbb{E}_{q(z_{1:L}|x)}[\log p(x|z_{1:L})] - \text{KL}(q(z_l|x) \| p(z_l)) ELBO=l=1LEq(z1:Lx)[logp(xz1:L)]KL(q(zlx)p(zl))
用于高分辨率图像生成。

结论

VAE的原理与代码实现呈现“理论复杂、代码简洁”的鲜明对比。深度学习框架对数学细节的封装以及概率模型的模块化设计:

  • 代码可见:https://github.com/AntixK/PyTorch-VAE/
  • 变分下界ELBO的构建: M S E + K L MSE+KL MSE+KL
  • 通过网络层输出均值 μ \mu μ和对数方差 log ⁡ σ 2 \log \sigma^2 logσ2
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐