ControlNet训练代码解析
ControlNet训练代码解析
1 训练代码
tutorial_train.py
# First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
model = create_model('./models/cldm_v15.yaml').cpu() #创建模型,这里会调用函数如1.1.1所示
model.load_state_dict(load_state_dict(resume_path, location='cpu')) #加载参数
#输入几个参数变量,具体修改的是创建model内部参数
model.learning_rate = learning_rate
model.sd_locked = sd_locked
model.only_mid_control = only_mid_control
# Misc
#数据集dataset
dataset = MyDataset() #创建dataset,例如1.1.2所示
dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
logger = ImageLogger(batch_frequency=logger_freq)
#由于利用了pytorch lightning框架,可以使用trainner来快速训练模型
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger])
# Train!
#直接传入model和dataloader即可
trainer.fit(model, dataloader)
1.1 创建模型
1.1.1 create_model.py
cldm/model.py中的create_model.py
def create_model(config_path):
config = OmegaConf.load(config_path)#参数预处理
model = instantiate_from_config(config.model).cpu()#依据'./models/cldm_v15.yaml'配置文件来创建模型
print(f'Loaded model config from [{config_path}]')
return model
instantiate_from_config()函数
该函数能够根据配置文件来创建出指定路径下模型及其参数的模型。这个函数是非常常用的。其中config字典类似于:
Config = { 'target': path1.path2.module_1_name,
'params': { 'para_1 ': value_a,
'para_2 ': value_b,
'module_2':{ 'target': path1.path2.module_2_name,
'params': { 'para_3 ': value_c,
'module_3':{ 'target': path1.path2.module_3_name,
'params' : {'para_4': value_d }
}}}}}
- 在 instantiate_from_config 返回对应的类的实例,
- 返回的实例是以params对应的值初始化的
- params对应的值是同等格式的字典。
也就是说,config中可以像上面的例子一样,设置好嵌套的各个模块,并且在模块实例化时读取传入的config,在模块的__init__中继续调用instantiate_from_config就可以实现各个模块嵌套式的实例化。具体的例子可以看第三篇。
1.1.2 dataset.py
import json
import cv2
import numpy as np
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self):
self.data = []
with open('./training/fill50k/prompt.json', 'rt') as f:
for line in f:
self.data.append(json.loads(line))
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
source_filename = item['source']
target_filename = item['target']
prompt = item['prompt']
source = cv2.imread('./training/fill50k/' + source_filename)
target = cv2.imread('./training/fill50k/' + target_filename)
# Do not forget that OpenCV read images in BGR order.
source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)
target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)
# Normalize source images to [0, 1].
source = source.astype(np.float32) / 255.0
# Normalize target images to [-1, 1].
target = (target.astype(np.float32) / 127.5) - 1.0
return dict(jpg=target, txt=prompt, hint=source)
2 模型具体流程
由于使用了pytorch-lightning框架,所以创建好的模型能够自动训练,只需要重构几个重要函数即可。具体可以查看pytorch-lightning文档。
2.1 配置文件
model:
target: cldm.cldm.ControlLDM
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
control_key: "hint"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
only_mid_control: False
control_stage_config:
target: cldm.cldm.ControlNet
params:
image_size: 32 # unused
in_channels: 4
hint_channels: 3
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
unet_config:
target: cldm.cldm.ControlledUnetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
从配置文件我们可以看出,我们首先创建的是ControlLDM文件。另外通过下面读代码我们可以看出具体的继承关系如下:
ControlLDM( LatentDiffusion( DDPM( pl.LightningModule ) ) ) controlLDM模型使用到的内容均在cldm/cldm.py当中,我们接下连介绍主要围绕文件中几个模型类来了解。
2.2 ControlLDM类
class ControlLDM(LatentDiffusion):
2.2.1 初始化
这地方使用了super().init证明继承的LatenDiffusion当中的也随同初始化。多出的几个初始化内容就是关于control所用到的参数。
def __init__(self, control_stage_config, control_key, only_mid_control, *args, **kwargs):
super().__init__(*args, **kwargs)
#多加入了Control_model
self.control_model = instantiate_from_config(control_stage_config) #control_model就是我们这里的controlNet模型类,我们2.3节讲解
self.control_key = control_key
self.only_mid_control = only_mid_control
self.control_scales = [1.0] * 13
2.2.2 get_input()
- DDPM中函数这个get_input函数用于处理批次中的数据,将其转换为适合模型输入的格式。
- Laten Diffusion中该方法用于预处理输入数据并提取第一阶段编码,以及在需要时提取条件信息。
- ControlNet中该函数获得latent z , condition的feature map, control的信息[B, C, H, W]
具体的区分可以看: controlNet的get_input函数详解
@torch.no_grad()
def get_input(self, batch, k, bs=None, *args, **kwargs):
#这里的batch与latent diffusion中batch差别就是加入了hint信息,具体传入是dataset中内容
'''
获得latent z , condition的feature map, control的信息[B, C, H, W]
'''
# 通过使用super()调用父类LatentDiffusion的get_input方法,
# ControlLDM类可以在继承LatentDiffusion类的基本功能的基础上,扩展和定制自己的功能。
# x: 即latent z
# c: 将condition编码为feature map, 特征映射
x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs) # first_stage_key=jpg
control = batch[self.control_key] # control_key=hint
if bs is not None:
control = control[:bs]
control = control.to(self.device)
control = einops.rearrange(control, 'b h w c -> b c h w')
# 这里的control还是[B,C,H,C],而不是feature map
control = control.to(memory_format=torch.contiguous_format).float()
return x, dict(c_crossattn=[c], c_concat=[control])
# { "c_crossattn": [c], "c_concat": [control] };这个字典的形式使得模型在后续处理中可以灵活地处理不同的条件信息。
# 例如,c_crossattn 可能用于跨注意力机制,而 c_concat 可能用于连接操作。将这些信息存储在一个字典里,使得它们可以方便地在后续操作中被提取和使用。
# 返回包含输入张量 x 和一个字典的元组,
# 该字典包含条件信息 c(用于交叉注意力)和控制信息 control(用于连接操作)。
2.2.3 apply_model
这里的apply_model主要用于处理步骤,和模型forward有相似之处。DDPM和LatenDiffusion中主要调用model.apply这一方法来进行图像噪声的预测。apply_model的主要用处是在继承的DDPM和LatenDiffusion中用到,具体的原理内容可以查看: Stable Diffusion 代码 (三)
在这里由于预测噪声需要加入control信息,所以需要修改具体流程。
def apply_model(self, x_noisy, t, cond, *args, **kwargs):
assert isinstance(cond, dict)
diffusion_model = self.model.diffusion_model
'''
这里的cond就是get_input中返回的的dict(c_crossattn=[c], c_concat=[control])
self.model = DiffusionWrapper(pl.LightningModule)
self.model.diffusion_model = 配置文件中创建的uet_config,也就是cldm.cldm.ControlledUnetModel
'''
cond_txt = torch.cat(cond['c_crossattn'], 1)
if cond['c_concat'] is None:
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
else:
# 先获取control列表,即获取input_blocks和middle_block的每个zero_conv的输出并保存在control中
control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
control = [c * scale for c, scale in zip(control, self.control_scales)]
# 预测noise
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
return eps
2.4 其余函数
def get_unconditional_conditioning(self, N):
return self.get_learned_conditioning([""] * N)
def log_images():
return log
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
return samples, intermediates
def configure_optimizers(self):
return opt
def low_vram_shift(self, is_diffusing):
2.3 ControlNet
class ControlNet(nn.Module):
这个类主要是给出如何输出control信息,这些control信息存在于每个zero_conv之后,要传入到原始locked的SD模型中。
def forward(self, x, hint, timesteps, context, **kwargs):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
# 原始Condition需要经过卷积,并进行编码.然后再经过零卷积,得到guided_hit。结果可以与x直接相加
guided_hint = self.input_hint_block(hint, emb, context)
outs = []
#将原始输入x改变类型
h = x.type(self.dtype)
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
#在这个module里面放的不同的block
if guided_hint is not None:
#这是第一个加入condition的操作
#这部分与论文结构图不太一样,这部分是对输入x进行了一个卷积才进行相加
h = module(h, emb, context)
h += guided_hint
guided_hint = None
else:
h = module(h, emb, context)
#outs中存放每经过一个block和zero_conv的结果
outs.append(zero_conv(h, emb, context))
#outs最终经过中间层的block和zero_conv
h = self.middle_block(h, emb, context)
outs.append(self.middle_block_out(h, emb, context))
return outs
2.4 ControlledUnetModel(UNetModel)
这个类继承了UNetModel类,在DDPM原始的类中主要作用是向前传播,一个大型U-Net模型
2.4.1 DDPM中的UNetModel
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param context: conditioning plugged in via crossattn
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)
h = self.middle_block(h, emb, context)
for module in self.output_blocks:
h = th.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)
else:
return self.out(h)
2.4.2 ControlledUnetModel
由于controlNet模型加入了一些control信息,需要重新设计U-Net模型,将control信息加入U-Net中
#实际被控制的部分。继承UNetModel类
class ControlledUnetModel(UNetModel):
def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
hs = []
#lock不更新梯度
with torch.no_grad():
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
h = x.type(self.dtype)
#这个循环是lock_block的编码block阶段
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)
#再经过中间层
h = self.middle_block(h, emb, context)
#接下来就逐步pop弹出trainable copy的输出结果,先弹出中间层
if control is not None:
h += control.pop()
# 实际开始control,
# 实际control方式:将经过middle block层的输出与(相应encoder block + control_outs)进行concat
for i, module in enumerate(self.output_blocks):
#这里给出一个判断,是否只是用一个zero_conv,保证了低配置的训练
if only_mid_control or control is None:
h = torch.cat([h, hs.pop()], dim=1)
else:
h = torch.cat([h, hs.pop() + control.pop()], dim=1)
h = module(h, emb, context)
h = h.type(x.dtype)
return self.out(h)
参考链接
Stable Diffusion 代码(一)
stable Diffusion 代码(二)
stable Diffusion 代码(三)
Stable Diffusion 代码(四)
controlNet的get_input函数详解
ControlNet, ControlledUnetModel, ControlLDM三者之间的关系
Diffusion扩散模型学习2——Stable Diffusion结构解析-以文本生成图像(文生图,txt2img)为例
万字长文解读Stable Diffusion的核心插件—ControlNet
controlnet前向代码解析
AIGC专栏2——Stable Diffusion结构解析-以文本生成图像(文生图,txt2img)为例
更多推荐
所有评论(0)