《从公式到代码:AIGC 实战里 BicycleGAN 的数学原理与实现》

数学原理

BicycleGAN 是一种多模态图像生成模型,核心思想是通过双向映射建立图像空间与隐空间的循环一致性。关键数学概念包括:

  1. 隐空间建模
    设输入图像为 $x$,输出图像为 $y$,隐变量 $z \sim \mathcal{N}(0,I)$。生成器定义为: $$ G: (x,z) \mapsto y $$ 编码器定义为: $$ E: y \mapsto z $$

  2. 双向循环一致性

    • 前向循环:$z \to y \to \hat{z}$
      $$ \mathcal{L}{\text{cyc-z}} = \mathbb{E}{z}[|E(G(x,z)) - z|_1] $$
    • 反向循环:$y \to z \to \hat{y}$
      $$ \mathcal{L}{\text{cyc-y}} = \mathbb{E}{y}[|G(x,E(y)) - y|_1] $$
  3. 对抗训练
    判别器 $D$ 的目标函数: $$ \min_G \max_D \mathcal{L}{\text{GAN}} = \mathbb{E}{y}[\log D(y)] + \mathbb{E}_{z}[\log(1 - D(G(x,z)))] $$

  4. 总损失函数
    $$ \mathcal{L}{\text{total}} = \mathcal{L}{\text{GAN}} + \lambda_1 \mathcal{L}{\text{cyc-z}} + \lambda_2 \mathcal{L}{\text{cyc-y}} $$ 其中 $\lambda_1,\lambda_2$ 为权重系数。


代码实现(PyTorch 框架)
import torch
import torch.nn as nn

# 生成器网络
class Generator(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(input_dim, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, output_dim, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, x, z):
        xz = torch.cat([x, z], dim=1)  # 拼接图像和隐变量
        return self.main(xz)

# 编码器网络
class Encoder(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(input_dim, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.Flatten(),
            nn.Linear(256*8*8, latent_dim)  # 输出隐变量
        )

    def forward(self, y):
        return self.main(y)

# 训练循环(伪代码)
def train():
    for real_x, real_y in dataloader:
        # 采样隐变量
        z = torch.randn(batch_size, latent_dim)
        
        # 生成图像
        fake_y = generator(real_x, z)
        
        # 双向重建
        z_recon = encoder(fake_y)  # 前向循环
        y_recon = generator(real_x, encoder(real_y))  # 反向循环
        
        # 计算损失
        loss_gan = adversarial_loss(discriminator(fake_y), real=True)
        loss_cyc_z = torch.mean(torch.abs(z_recon - z))  # $\mathcal{L}_{\text{cyc-z}}$
        loss_cyc_y = torch.mean(torch.abs(y_recon - real_y))  # $\mathcal{L}_{\text{cyc-y}}$
        total_loss = loss_gan + lambda1 * loss_cyc_z + lambda2 * loss_cyc_y
        
        # 反向传播更新
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()


关键实现细节
  1. 隐变量拼接
    生成器输入需将图像 $x$ 与隐变量 $z$ 在通道维度拼接,确保 $z$ 参与空间特征生成。

  2. 循环一致性权重
    经验值 $\lambda_1=10.0$, $\lambda_2=5.0$,需根据数据集调整:

    lambda1, lambda2 = 10.0, 5.0  # 循环损失权重
    

  3. 模式崩溃预防
    通过隐空间高斯采样保证输出多样性:

    z = torch.randn(batch_size, latent_dim, 1, 1)  # 4D张量对齐卷积维度
    

  4. 判别器设计
    推荐使用 PatchGAN 结构,对图像局部区域进行真伪判别,提升细节生成质量。

此实现完整呈现了从数学公式到代码的转化过程,通过双向循环约束确保隐变量与生成图像的双射关系,实现多模态图像生成。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐