Week G6: CycleGAN in Practice
- 🍨 This post is a study log from the 🔗365-Day Deep Learning Training Camp
- 🍖 Original author: K同学啊
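1. Introduction to CycleGAN
CycleGAN (Cycle-Consistent Adversarial Networks) is a deep learning model for unpaired image-to-image translation.
- Key breakthrough: traditional image translation methods (such as Pix2Pix) need paired training data (for example, daytime and nighttime photos of the same location). CycleGAN needs no one-to-one pairs: given just two datasets from different domains (for example, a collection of Monet paintings and a collection of landscape photos), it learns to translate the style of one into the other.
- Applications: style transfer (photo to oil painting), season transfer (summer to winter), object transfiguration (horse to zebra), and so on.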
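2. Core Principles
CycleGAN trains two generators and two discriminators at the same time and adds a cycle consistency loss so that each translated image stays faithful to its input.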
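2.1 Model Architecture
- Two generators:
  - $G_{AB}$: translates domain A (e.g., Monet paintings) into domain B (e.g., photos).
  - $G_{BA}$: translates domain B (e.g., photos) into domain A (e.g., Monet paintings).
- Two discriminators:
  - $D_A$: decides whether an image is a real A or a fake A generated by $G_{BA}$.
  - $D_B$: decides whether an image is a real B or a fake B generated by $G_{AB}$.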
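2.2 Key Loss Functions
- Adversarial loss:
  Pushes each generator to produce images that fool the corresponding discriminator, so the outputs look realistic. In cyclegan.py it is computed with a least-squares (MSE) criterion.
- Cycle consistency loss:
  - Core idea: if an image is translated to the other domain and then back again, it should come back unchanged.
  - Formula: $A \rightarrow G_{AB}(A) \rightarrow G_{BA}(G_{AB}(A)) \approx A$, and symmetrically $B \rightarrow G_{BA}(B) \rightarrow G_{AB}(G_{BA}(B)) \approx B$
  - Purpose: prevents the model from arbitrarily changing the geometry and content of the image (for example, when a photo is turned into an oil painting, the positions and shapes of the trees must not change).
- Identity loss (optional):
  - If an image from domain A is fed to $G_{BA}$ (the generator meant to translate B into A), the output should stay unchanged, i.e. $G_{BA}(A) \approx A$ and likewise $G_{AB}(B) \approx B$. This helps keep the color composition stable.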
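The three losses above are combined into a single generator objective in cyclegan.py: an MSE adversarial term, plus L1 cycle and identity terms weighted by lambda_cyc = 10.0 and lambda_id = 5.0. The sketch below reproduces only that arithmetic with NumPy on toy arrays. It assumes hypothetical stand-in generators g_ab / g_ba (a simple rescaling and its inverse) and hypothetical discriminator scores; it is not the project's training code.

```python
import numpy as np

def mse(pred, target):
    # Least-squares GAN criterion, same role as torch.nn.MSELoss in cyclegan.py
    return float(np.mean((pred - target) ** 2))

def l1(pred, target):
    # Mean absolute error, same role as torch.nn.L1Loss for the cycle/identity terms
    return float(np.mean(np.abs(pred - target)))

# Hypothetical toy generators: G_AB doubles pixel values and G_BA halves them,
# so each one exactly undoes the other.
def g_ab(x):
    return x * 2.0

def g_ba(x):
    return x / 2.0

def generator_loss(real_a, real_b, d_b_on_fake_b, d_a_on_fake_a,
                   lambda_cyc=10.0, lambda_id=5.0):
    """d_b_on_fake_b / d_a_on_fake_a: hypothetical patch scores that D_B and D_A
    give to G_AB(real_a) and G_BA(real_b)."""
    valid = np.ones_like(d_b_on_fake_b)  # generators want their fakes scored as real (1)
    loss_gan = (mse(d_b_on_fake_b, valid) + mse(d_a_on_fake_a, valid)) / 2
    # Cycle: A -> G_AB -> G_BA should give back A, and B -> G_BA -> G_AB should give back B
    loss_cycle = (l1(g_ba(g_ab(real_a)), real_a) + l1(g_ab(g_ba(real_b)), real_b)) / 2
    # Identity: G_BA on an A image (and G_AB on a B image) should change nothing
    loss_id = (l1(g_ba(real_a), real_a) + l1(g_ab(real_b), real_b)) / 2
    total = loss_gan + lambda_cyc * loss_cycle + lambda_id * loss_id
    return total, loss_gan, loss_cycle, loss_id

real_a = np.full((2, 2), 0.4)
real_b = np.full((2, 2), -0.2)
scores = np.full((4, 4), 0.5)  # an undecided discriminator: 0.5 on every patch
total, loss_gan, loss_cycle, loss_id = generator_loss(real_a, real_b, scores, scores)
print(total, loss_gan, loss_cycle, loss_id)
```

Here the cycle term is zero because the toy generators invert each other, yet the identity term is not: each generator still rescales an image that already belongs to its target domain, which is exactly the kind of unnecessary change the identity loss penalizes.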
3. Code Structure of This Project
The project lives in the cyclegan/ directory and is implemented in PyTorch.
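📂 cyclegan.py (main training script)
The project's entry point, which controls the training loop.
- Arguments:
  - --n_epochs: total number of training epochs (default 200).
  - --batch_size: batch size (CycleGAN is usually trained with 1).
  - --lambda_cyc: weight of the cycle consistency loss (default 10.0; the large weight shows how much structure preservation matters).
  - --lambda_id: weight of the identity loss (default 5.0).
- Optimizers: Adam (lr = 0.0002, betas = (0.5, 0.999)), one shared by both generators and one for each discriminator.
- Learning-rate schedule: constant for the first 100 epochs, then linearly decayed toward 0 over the last 100 epochs.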
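The schedule is implemented by passing LambdaLR.step from utils.py to torch.optim.lr_scheduler.LambdaLR, which multiplies the base rate by the returned factor once per epoch. The pure-Python sketch below copies that formula into a hypothetical lr_factor function so the schedule can be checked without PyTorch, using the script's defaults (lr = 0.0002, n_epochs = 200, decay_epoch = 100).

```python
def lr_factor(epoch, n_epochs=200, offset=0, decay_start_epoch=100):
    """Same formula as LambdaLR.step in utils.py: 1.0 until decay_start_epoch,
    then a linear ramp that would reach 0 at n_epochs."""
    return 1.0 - max(0, epoch + offset - decay_start_epoch) / (n_epochs - decay_start_epoch)

base_lr = 0.0002
for epoch in (0, 99, 100, 150, 199):
    # Learning rate used while training this (0-based) epoch
    print(epoch, base_lr * lr_factor(epoch))
```

Because the scheduler is stepped at the end of each epoch, the final epoch (199) runs at 1% of the base rate; the factor would only reach exactly 0 at epoch 200, which the loop never executes. The offset argument matters when resuming with --epoch N: PyTorch's scheduler counter restarts at 0, and adding N keeps the factor aligned with the absolute epoch.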
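📂 models.py (network models)
- Generator (GeneratorResNet):
  - Based on a ResNet (residual network) architecture.
  - Consists of downsampling layers (two stride-2 convolutions), 9 residual blocks, and upsampling layers.
  - Compared with a U-Net, a ResNet-style generator tends to preserve deep features better in this kind of unpaired translation task.
- Discriminator:
  - Uses the PatchGAN architecture.
  - The output is not a single real/fake value but an $N \times N$ grid of scores, each judging one local patch of the image, which yields sharper texture detail. The classic PatchGAN judges 70×70 patches; with the layers in this models.py (four stride-2 blocks plus a final stride-1 convolution), a 256×256 input gives a 16×16 output and each score sees a 94×94 region.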
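The sizes quoted above follow from the layer settings in models.py. The pure-Python sketch below (all function names are hypothetical helpers, not part of the project) applies the standard convolution output-size formula to trace the generator's feature-map sizes and the discriminator's output size, and walks the layers backwards to get the PatchGAN receptive field. For comparison it also evaluates the commonly cited 70×70 PatchGAN layout, in which the fourth block uses stride 1.

```python
def conv_out(size, kernel, stride, padding):
    # Standard convolution output size (floor division)
    return (size + 2 * padding - kernel) // stride + 1

def generator_sizes(size=256):
    """Spatial size after each stage of GeneratorResNet (default settings)."""
    sizes = [size]  # 7x7 conv with 3-pixel reflection padding keeps the size
    for _ in range(2):  # two 3x3 stride-2 convs (padding=1) halve it
        size = conv_out(size, 3, 2, 1)
        sizes.append(size)
    # Residual blocks keep the size; each Upsample(x2) + 3x3 conv (padding=1) doubles it
    for _ in range(2):
        size = conv_out(size * 2, 3, 1, 1)
        sizes.append(size)
    return sizes

def discriminator_output_size(size):
    """Output height/width of Discriminator in models.py for a square input."""
    for _ in range(4):  # four 4x4 stride-2 blocks with padding=1
        size = conv_out(size, 4, 2, 1)
    size += 1  # ZeroPad2d((1, 0, 1, 0)) adds one row and one column
    return conv_out(size, 4, 1, 1)  # final 4x4 conv, stride 1, padding=1

def receptive_field(layers):
    """layers: (kernel, stride) pairs from input to output.
    Walks backwards with r_in = r_out * stride + (kernel - stride)."""
    r = 1
    for kernel, stride in reversed(layers):
        r = r * stride + (kernel - stride)
    return r

# Padding layers do not change the receptive field, so only convolutions are listed
this_repo = [(4, 2)] * 4 + [(4, 1)]           # Discriminator in models.py
classic_70 = [(4, 2)] * 3 + [(4, 1), (4, 1)]  # commonly cited 70x70 PatchGAN layout
print(generator_sizes(256))            # [256, 128, 64, 128, 256]
print(discriminator_output_size(256))  # 16
print(receptive_field(this_repo), receptive_field(classic_70))  # 94 70
```

The 16 matches `self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)` in the Discriminator, which cyclegan.py uses to build the `valid` and `fake` target grids.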
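📂 datasets.py (data processing)
- ImageDataset class:
  - Unaligned loading: with unaligned=True, each sample pairs an image from the A folder with a randomly chosen image from the B folder, so no pairing is needed.
  - Data augmentation: Resize (enlarging the shorter side to 1.12 × 256 ≈ 286 pixels), RandomCrop (back to 256×256), and RandomHorizontalFlip, which make the model more robust.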
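The sampling rule in ImageDataset.__getitem__ can be examined without any images. The sketch below (the class name UnalignedIndexer is hypothetical) keeps only the index logic from datasets.py, works on lists of file names, and uses a seeded random.Random instead of the global random module so the result is reproducible.

```python
import random

class UnalignedIndexer:
    """Index logic of ImageDataset in datasets.py, on file names only."""

    def __init__(self, files_a, files_b, unaligned=True, seed=None):
        self.files_a = sorted(files_a)
        self.files_b = sorted(files_b)
        self.unaligned = unaligned
        self.rng = random.Random(seed)

    def __len__(self):
        # One epoch covers the larger of the two domains
        return max(len(self.files_a), len(self.files_b))

    def __getitem__(self, index):
        # Domain A cycles deterministically through its files
        a = self.files_a[index % len(self.files_a)]
        if self.unaligned:
            # Domain B is drawn at random, so A and B are never fixed pairs
            b = self.files_b[self.rng.randint(0, len(self.files_b) - 1)]
        else:
            b = self.files_b[index % len(self.files_b)]
        return a, b

ds = UnalignedIndexer(["monet_%d.jpg" % i for i in range(3)],
                      ["photo_%d.jpg" % i for i in range(5)], seed=0)
print(len(ds))                            # 5
print([ds[i][0] for i in range(len(ds))])
```

Domain A is walked by index (the DataLoader's shuffle=True permutes those indices), domain B is drawn independently, and an epoch lasts max(len(A), len(B)) samples, so the smaller domain is revisited within each epoch.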
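📂 utils.py (helper utilities)
- ReplayBuffer (experience replay):
  - Maintains a buffer of up to 50 previously generated fake images.
  - The discriminators are not trained only on the fakes from the current step: once the buffer is full, each new fake is swapped, with 50% probability, for a randomly chosen older fake. This reduces oscillation between the generators and discriminators and makes training more stable.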
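The swap rule can be seen on plain Python values. The sketch below mirrors the push_and_pop logic in utils.py; the class name ListReplayBuffer is hypothetical, and unlike the real class it works on arbitrary items instead of image tensors, does not clone, and uses a seeded random.Random.

```python
import random

class ListReplayBuffer:
    """push_and_pop logic from utils.py, applied to a list of items."""

    def __init__(self, max_size=50, seed=None):
        assert max_size > 0
        self.max_size = max_size
        self.data = []
        self.rng = random.Random(seed)

    def push_and_pop(self, batch):
        to_return = []
        for element in batch:
            if len(self.data) < self.max_size:
                # Buffer not full yet: store the new fake and also return it
                self.data.append(element)
                to_return.append(element)
            elif self.rng.uniform(0, 1) > 0.5:
                # 50%: return a random older fake and store the new one in its slot
                i = self.rng.randint(0, self.max_size - 1)
                to_return.append(self.data[i])
                self.data[i] = element
            else:
                # 50%: return the new fake and leave the buffer unchanged
                to_return.append(element)
        return to_return

buf = ListReplayBuffer(max_size=3, seed=0)
print(buf.push_and_pop([1, 2, 3]))  # [1, 2, 3] while the buffer fills
for step in range(4, 10):
    print(step, buf.push_and_pop([step]))
```

With batch_size = 1, as in this project, each call handles a single image, so after the first 50 iterations roughly half of the fakes each discriminator sees come from earlier versions of the generator.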
4. How to Run
In a terminal, change into the cyclegan directory and run the script:
cd d:\my_project\python\365\Advanced_Camp\G6_CycleGAN_practice\cyclegan
python cyclegan.py
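- Training: the script continuously prints D loss (discriminator loss) and G loss (generator loss), together with the adversarial, cycle and identity components and an estimated time remaining.
- Viewing results:
  - Generated sample images are saved under images/monet2photo/ in the project root (the parent folder of cyclegan/), every 100 batches by default.
  - Model weights are saved under saved_models/monet2photo/, after every epoch by default.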
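If training is interrupted, it can be resumed with --epoch N, which loads G_AB_N.pth, G_BA_N.pth, D_A_N.pth and D_B_N.pth from saved_models/<dataset_name>/. The helper below (its name and regular expression are assumptions, written for the file names cyclegan.py produces) finds the largest epoch for which all four files exist.

```python
import re

# File names written by cyclegan.py: G_AB_<epoch>.pth, G_BA_<epoch>.pth, D_A_<epoch>.pth, D_B_<epoch>.pth
CKPT_RE = re.compile(r"^(G_AB|G_BA|D_A|D_B)_(\d+)\.pth$")

def latest_complete_epoch(filenames):
    """Largest epoch with all four checkpoints present, or None."""
    nets_per_epoch = {}
    for name in filenames:
        m = CKPT_RE.match(name)
        if m:
            nets_per_epoch.setdefault(int(m.group(2)), set()).add(m.group(1))
    complete = [epoch for epoch, nets in nets_per_epoch.items() if len(nets) == 4]
    return max(complete) if complete else None

# Usage (path assumed relative to the project root):
# import os
# print(latest_complete_epoch(os.listdir("saved_models/monet2photo")))
```

Because the loop runs `range(opt.epoch, opt.n_epochs)`, resuming with --epoch N trains epoch N once more before moving on.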
My file structure is as follows:
[Figure: project directory structure]
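The layout the scripts expect can be read off the paths in the code: root_path is the parent of the folder holding cyclegan.py, training and test images are read from data/<dataset_name>/trainA, trainB, testA and testB, and outputs go to images/ and saved_models/. The sketch below (function names are hypothetical) lists those paths and reports which data folders are missing before a run.

```python
from pathlib import Path

def expected_dirs(project_root, dataset_name="monet2photo"):
    """Directories implied by the paths in cyclegan.py and datasets.py."""
    root = Path(project_root)
    data = root / "data" / dataset_name
    return {
        "scripts": root / "cyclegan",  # cyclegan.py, datasets.py, models.py, utils.py
        "trainA": data / "trainA",     # ImageDataset(mode="train") globs "%sA" % mode
        "trainB": data / "trainB",
        "testA": data / "testA",       # read by val_dataloader (mode="test")
        "testB": data / "testB",
        "samples": root / "images" / dataset_name,        # created by os.makedirs at startup
        "checkpoints": root / "saved_models" / dataset_name,
    }

def missing_dirs(project_root, dataset_name="monet2photo"):
    # Only the data folders must exist beforehand; images/ and saved_models/ are created automatically
    dirs = expected_dirs(project_root, dataset_name)
    return [str(dirs[k]) for k in ("trainA", "trainB", "testA", "testB") if not dirs[k].is_dir()]

# Usage: print(missing_dirs(r"d:\my_project\python\365\Advanced_Camp\G6_CycleGAN_practice"))
```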
The cyclegan.py program is as follows:
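import argparse
import datetime
import itertools
import os
import sys
import time
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.utils import save_image, make_grid
# Project modules (models.py, datasets.py, utils.py in the same folder)
from models import *
from datasets import *
from utils import *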
parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--dataset_name", type=str, default="monet2photo", help="name of the dataset")
parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--img_height", type=int, default=256, help="size of image height")
parser.add_argument("--img_width", type=int, default=256, help="size of image width")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=100, help="interval between saving generator outputs")
parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between saving model checkpoints")
parser.add_argument("--n_residual_blocks", type=int, default=9, help="number of residual blocks in generator")
parser.add_argument("--lambda_cyc", type=float, default=10.0, help="cycle loss weight")
parser.add_argument("--lambda_id", type=float, default=5.0, help="identity loss weight")
opt = parser.parse_args()
print(opt)
# Get root path
root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Create sample and checkpoint directories
os.makedirs(os.path.join(root_path, "images/%s" % opt.dataset_name), exist_ok=True)
os.makedirs(os.path.join(root_path, "saved_models/%s" % opt.dataset_name), exist_ok=True)
# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()
cuda = torch.cuda.is_available()
input_shape = (opt.channels, opt.img_height, opt.img_width)
# Initialize generators and discriminators
G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks)
G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks)
D_A = Discriminator(input_shape)
D_B = Discriminator(input_shape)
if cuda:
    G_AB = G_AB.cuda()
    G_BA = G_BA.cuda()
    D_A = D_A.cuda()
    D_B = D_B.cuda()
    criterion_GAN.cuda()
    criterion_cycle.cuda()
    criterion_identity.cuda()
if opt.epoch != 0:
    # Load the checkpoints saved at epoch opt.epoch
    # (map_location lets GPU-saved weights load on a CPU-only machine)
    map_location = "cuda" if cuda else "cpu"
    G_AB.load_state_dict(torch.load(os.path.join(root_path, "saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, opt.epoch)), map_location=map_location))
    G_BA.load_state_dict(torch.load(os.path.join(root_path, "saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, opt.epoch)), map_location=map_location))
    D_A.load_state_dict(torch.load(os.path.join(root_path, "saved_models/%s/D_A_%d.pth" % (opt.dataset_name, opt.epoch)), map_location=map_location))
    D_B.load_state_dict(torch.load(os.path.join(root_path, "saved_models/%s/D_B_%d.pth" % (opt.dataset_name, opt.epoch)), map_location=map_location))
else:
    # Initialize weights
    G_AB.apply(weights_init_normal)
    G_BA.apply(weights_init_normal)
    D_A.apply(weights_init_normal)
    D_B.apply(weights_init_normal)
# Optimizers
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
# Learning rate update schedulers
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor
# Buffers of previously generated samples
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()
# Image transformations
transforms_ = [
    # Enlarge the shorter side to 1.12x the target size, then random-crop back to 256x256
    transforms.Resize(int(opt.img_height * 1.12), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.RandomCrop((opt.img_height, opt.img_width)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # Map [0, 1] to [-1, 1] to match the generator's Tanh output
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
# Training data loader
dataloader = DataLoader(
    ImageDataset(os.path.join(root_path, "data/%s/" % opt.dataset_name), transforms_=transforms_, unaligned=True),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.n_cpu,
)
# Test data loader
val_dataloader = DataLoader(
    ImageDataset(os.path.join(root_path, "data/%s/" % opt.dataset_name), transforms_=transforms_, unaligned=True, mode="test"),
    batch_size=5,
    shuffle=True,
    num_workers=1,
)
def sample_images(batches_done):
    """Saves a generated sample from the test set"""
    imgs = next(iter(val_dataloader))
    G_AB.eval()
    G_BA.eval()
    with torch.no_grad():  # no gradients are needed for visualization
        real_A = Variable(imgs["A"].type(Tensor))
        fake_B = G_AB(real_A)
        real_B = Variable(imgs["B"].type(Tensor))
        fake_A = G_BA(real_B)
    # Arrange images along x-axis
    real_A = make_grid(real_A, nrow=5, normalize=True)
    real_B = make_grid(real_B, nrow=5, normalize=True)
    fake_A = make_grid(fake_A, nrow=5, normalize=True)
    fake_B = make_grid(fake_B, nrow=5, normalize=True)
    # Arrange images along y-axis
    image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
    save_image(image_grid, os.path.join(root_path, "images/%s/%s.png" % (opt.dataset_name, batches_done)), normalize=False)
# ----------
# Training
# ----------
if __name__ == '__main__':
    prev_time = time.time()
    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):
            # Set model input
            real_A = Variable(batch["A"].type(Tensor))
            real_B = Variable(batch["B"].type(Tensor))
            # Adversarial ground truths
            valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)
            fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)
            # ------------------
            # Train Generators
            # ------------------
            G_AB.train()
            G_BA.train()
            optimizer_G.zero_grad()
            # Identity loss
            loss_id_A = criterion_identity(G_BA(real_A), real_A)
            loss_id_B = criterion_identity(G_AB(real_B), real_B)
            loss_identity = (loss_id_A + loss_id_B) / 2
            # GAN loss
            fake_B = G_AB(real_A)
            loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
            fake_A = G_BA(real_B)
            loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
            # Cycle loss
            recov_A = G_BA(fake_B)
            loss_cycle_A = criterion_cycle(recov_A, real_A)
            recov_B = G_AB(fake_A)
            loss_cycle_B = criterion_cycle(recov_B, real_B)
            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
            # Total loss
            loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity
            loss_G.backward()
            optimizer_G.step()
            # -----------------------
            # Train Discriminator A
            # -----------------------
            optimizer_D_A.zero_grad()
            # Real loss
            loss_real = criterion_GAN(D_A(real_A), valid)
            # Fake loss (on batch of previously generated samples)
            fake_A_ = fake_A_buffer.push_and_pop(fake_A)
            loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
            # Total loss
            loss_D_A = (loss_real + loss_fake) / 2
            loss_D_A.backward()
            optimizer_D_A.step()
            # -----------------------
            # Train Discriminator B
            # -----------------------
            optimizer_D_B.zero_grad()
            # Real loss
            loss_real = criterion_GAN(D_B(real_B), valid)
            # Fake loss (on batch of previously generated samples)
            fake_B_ = fake_B_buffer.push_and_pop(fake_B)
            loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
            # Total loss
            loss_D_B = (loss_real + loss_fake) / 2
            loss_D_B.backward()
            optimizer_D_B.step()
            loss_D = (loss_D_A + loss_D_B) / 2
            # --------------
            # Log Progress
            # --------------
            # Determine approximate time left
            batches_done = epoch * len(dataloader) + i
            batches_left = opt.n_epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()
            # Print log
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
                % (
                    epoch,
                    opt.n_epochs,
                    i,
                    len(dataloader),
                    loss_D.item(),
                    loss_G.item(),
                    loss_GAN.item(),
                    loss_cycle.item(),
                    loss_identity.item(),
                    time_left,
                )
            )
            # If at sample interval save image
            if batches_done % opt.sample_interval == 0:
                sample_images(batches_done)
        # Update learning rates (once per epoch)
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()
        if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
            # Save model checkpoints
            torch.save(G_AB.state_dict(), os.path.join(root_path, "saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, epoch)))
            torch.save(G_BA.state_dict(), os.path.join(root_path, "saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, epoch)))
            torch.save(D_A.state_dict(), os.path.join(root_path, "saved_models/%s/D_A_%d.pth" % (opt.dataset_name, epoch)))
            torch.save(D_B.state_dict(), os.path.join(root_path, "saved_models/%s/D_B_%d.pth" % (opt.dataset_name, epoch)))
The datasets.py program is as follows:
import glob
import random
import os
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
def to_rgb(image):
    rgb_image = Image.new("RGB", image.size)
    rgb_image.paste(image)
    return rgb_image
class ImageDataset(Dataset):
    def __init__(self, root, transforms_=None, unaligned=False, mode="train"):
        self.transform = transforms.Compose(transforms_)
        self.unaligned = unaligned
        self.files_A = sorted(glob.glob(os.path.join(root, "%sA" % mode) + "/*.*"))
        self.files_B = sorted(glob.glob(os.path.join(root, "%sB" % mode) + "/*.*"))
    def __getitem__(self, index):
        image_A = Image.open(self.files_A[index % len(self.files_A)])
        if self.unaligned:
            image_B = Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)])
        else:
            image_B = Image.open(self.files_B[index % len(self.files_B)])
        # Convert grayscale images to rgb
        if image_A.mode != "RGB":
            image_A = to_rgb(image_A)
        if image_B.mode != "RGB":
            image_B = to_rgb(image_B)
        item_A = self.transform(image_A)
        item_B = self.transform(image_B)
        return {"A": item_A, "B": item_B}
    def __len__(self):
        return max(len(self.files_A), len(self.files_B))
The models.py program is as follows:
import torch.nn as nn
import torch.nn.functional as F
import torch
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        if hasattr(m, "bias") and m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
##############################
# RESNET
##############################
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
        )
    def forward(self, x):
        # Skip connection: the block only learns a residual on top of x
        return x + self.block(x)
class GeneratorResNet(nn.Module):
    def __init__(self, input_shape, num_residual_blocks):
        super(GeneratorResNet, self).__init__()
        channels = input_shape[0]
        # Initial convolution block
        out_features = 64
        model = [
            nn.ReflectionPad2d(3),  # 3-pixel padding keeps the size unchanged for the 7x7 conv
            nn.Conv2d(channels, out_features, 7),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features
        # Downsampling
        for _ in range(2):
            out_features *= 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
        # Residual blocks
        for _ in range(num_residual_blocks):
            model += [ResidualBlock(out_features)]
        # Upsampling
        for _ in range(2):
            out_features //= 2
            model += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
        # Output layer
        model += [nn.ReflectionPad2d(3), nn.Conv2d(out_features, channels, 7), nn.Tanh()]
        self.model = nn.Sequential(*model)
    def forward(self, x):
        return self.model(x)
##############################
# Discriminator
##############################
class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        channels, height, width = input_shape
        # Calculate output shape of image discriminator (PatchGAN)
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)
        def discriminator_block(in_filters, out_filters, normalize=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        self.model = nn.Sequential(
            *discriminator_block(channels, 64, normalize=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),  # pad left and top by one pixel so the output is height // 16
            nn.Conv2d(512, 1, 4, padding=1)
        )
    def forward(self, img):
        return self.model(img)
The utils.py program is as follows:
import random
import time
import datetime
import sys
from torch.autograd import Variable
import torch
import numpy as np
from torchvision.utils import save_image
class ReplayBuffer:
    """Stores up to max_size past fakes and mixes them into discriminator batches."""
    def __init__(self, max_size=50):
        assert max_size > 0, "Empty buffer or trying to create a black hole. Be careful."
        self.max_size = max_size
        self.data = []
    def push_and_pop(self, data):
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                # Buffer not full yet: store the new fake and return it
                self.data.append(element)
                to_return.append(element)
            else:
                if random.uniform(0, 1) > 0.5:
                    # Return a random older fake and put the new one in its slot
                    i = random.randint(0, self.max_size - 1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    to_return.append(element)
        return Variable(torch.cat(to_return))
class LambdaLR:
    """Learning-rate factor: 1.0 until decay_start_epoch, then linear decay toward 0."""
    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert (n_epochs - decay_start_epoch) > 0, "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch
    def step(self, epoch):
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)