一、环境准备

1.软硬件环境要求

预装有PyTorch,支持GPU的Jupyter Notebook

2.准备运行环境

2.1 Colab 在线Jupyter Notebook运行环境

Colab地址 https://colab.research.google.com/#

2.1 新建你的第一个Notebook,准备开启PyTorch之旅

二、第二个PyTorch模型

1. CycleGAN

1.1 基于ResNet模型库,初始化定义ResNetGenerator类

netG = ResNetGenerator()

解读

  • netG模型已创建,包含的是随机权重

1.2 NetG模型实例,加载模型参数权重pth文件

model_path = '/content/horse2zebra_0.4.0.pth'
model_data = torch.load(model_path)
netG.load_state_dict(model_data)

解读

  • 运行一个在horse2zebra数据集上预训练生成器模型

  • 该数据集包含2个集合,分别为1068张马的照片,和1335张斑马的照片

知识点

① 模型权重数据
  • 模型权重保存文件,xxx.pth文件,该文件是一个模型张量参数的pickle文件

  • torch.load(model_path),可加载出model_data

② 模型数据加载
  • 使用模型的load_state_dict()方法,可讲权重加载到ResNetGenerator中

  • 加载后,netG模型已获得在训练中需要的所有知识!!

1.3 运行NetG模型,设置eval()模式

netG.eval()

知识点

③ 运行模型eval()

1.4 导入图像

1.4.1 导入PIL(Python Imaging Library,Python图像库)
from PIL import Image
from torchvision import transforms

知识点

④ PIL Python图像库

PIL图像库

⑤ Torch格式转化 transforms

transforms函数,Torch库中用于格式转化的函数

1.4.2 图像Tensor化,预处理管理处理
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.ToTensor()
    ]
)

解读

  • 定义输入变换,确保数据已正确的状态和大小进入网络

知识点

⑥ 图片输入变换preprocess(基于transforms函数)
  • transforms.Resize(256),图片格式变换

  • transforms.ToTensor(),图片格式Tensor向量化

2.4.3 加载图片
img = Image.open("/content/horse.jpg")
img

1.5 图像转化

1.5.1 图片数据转换及预处理
img_t = preprocess(img)
batch_t = torch.unsqueeze(img_t,0)

知识点

⑦ 图片格式处理
  • 将预处理管道的图片数据,传入torch.unsqueeze()归一化处理

1.5.2 模型转换图片
batch_out = netG(batch_t)

1.5.3 图片输出
out_t = (batch_out.data.squeeze() + 1.0) / 2.0
out_img = transforms.ToPILImage()(out_t)
out_img

知识点

⑧ 将Tensor向量数据,转换成图片
  • transforms.ToPILImage()

最终生成的图片,把马渲染成了斑马,看着像那么回事儿👀

三、小结

至此,CycleGAN网络模型,通过生成器与鉴别器的对弈,持续生成预测结果。

从输出结果来看,成功的把一匹马转换成斑马,虽然生成效果还不算完美,但确实做到了~

CycleGAN模型背后的故事

1. GAN游戏

GAN是生成式对抗网络的缩写(Generative Adversarial Network),引入两个网络概念,分别是

  • 生成器网络(generator network),扮演画家,负责从任意输入开始生成逼真的图像。
  • 判别器网络(discriminiator network),扮演艺术史学工作者,负责判断给定的图像是由生成器生成的,还是一幅真实的图像。

两个网络都是基于彼此网络的结果进行训练的,并推动彼此对网络参数进行优化。通过两个网络的博弈,最终生成逼真的图像。

2. CycleGAN网络

CycleGAN是指循环生成式对抗网络,可以将一个领域的图像转换成另一个领域的图像,而不需要在训练集中显式提供匹配对。这种方式实现了从有监督学习,到无监督学习的转变。

CycleGAN推动模型训练无监督学习的发展,为后来的生成式AI,AIGC等领域发展奠定了基础。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐