Stable Diffusion也能“吃得多长得壮”?用数据增强喂出更强AI模型

当你的训练数据不够用,AI也会“营养不良”

做图像生成的同学都知道,Stable Diffusion 就像一只胃口极大的猫,喂饱了它就能撸出毛茸茸的 4K 大图;可一旦断粮,它立刻翻脸给你一张“克苏鲁”风格的崩坏脸。现实是,高质量成对数据(文本-图片)贵得离谱,开源社区能薅的羊毛也就那么几撮。于是大家开始动起“数据增强”的歪脑筋:既然买不起新食材,那就把旧菜做出满汉全席的味道。

可别以为数据增强只是“把图往左拧 90°”那么无聊。今天咱们就聊聊怎样把 Stable Diffusion 这只挑食的猫喂成肌肉猛男——既不让它吃撑,也别让它吃错药。文章很长,代码很多,准备好可乐和辣条,咱们慢慢唠。


Stable Diffusion不是万能的,但数据增强能让它更接近万能

先给刚入坑的同学补一句:Stable Diffusion 本质上是“噪点去除器”,它学会的是“从随机高斯噪点里一点点把猫片抠出来”。它吃进去的其实是“文本编码 + 噪点图”,吐出来的是“去噪后的像素”。
换句话说,只要你能把“文本-噪点-像素”三元组玩出花来,就能让模型见到更多世面,而无需真的去 Flickr 买 8K 分辨率的原图。

数据增强的核心目标就是:在不改变“语义”的前提下,疯狂给模型加餐。
加得巧,模型见多识广;加得烂,模型直接“吃坏肚子”——生成一堆四只胳膊的二次元老婆,你还不敢发朋友圈。


图像生成模型的隐痛:高质量训练样本太稀缺

Stable Diffusion 官方在 Laion-5B 上挑了 2B 张图,过滤后只剩 600M,再精挑 170M 才拿去训。听着很多?分到 1000 个类别后,每个类别也就十几万张。更要命的是,中文语料占比低得可怜,你想让它画“糖葫芦”它给你“冰糖手电筒”。
自己爬数据吧,版权、NSFW、低清、水印、表情包…… 爬完清洗完,发现硬盘满了,预算光了,老板还问你怎么还没上线。

于是,数据增强成了“穷鬼”开发者最后的倔强:
“老子没钱买图,还不能再造一点吗?”


数据增强不只是“复制粘贴”,它是给模型开小灶

很多人把增强想成“一张图变十张图”,其实更应该理解为“给模型一次额外的小测验”。
同一张猫片,轻微旋转后,模型就要重新猜“这是猫还是狗”;文本 prompt 里把“a cute cat”换成“a fluffy kitten”,模型就得学会“fluffy ≈ cute ≈ kitten”的微妙差异。
增强的精髓是:让模型在“看似不同、实则同义”的样本上反复横跳,从而学到更鲁棒的流形。

下面这张脑图(伪)概括了 Stable Diffusion 常用的增强路线:

原始样本
├─ 像素级:旋转、缩放、裁剪、翻转、HSV、CutMix
├─ 语义级:文本替换、风格化、随机遮罩、DiffusionDet 前景粘贴
├─ 模型级:对抗样本、EMA 参数扰动、扩散轨迹混合
└─ 隐空间:在 VAE 潜变量里做插值、打乱通道、加噪

接下来咱们挨个拆,拆完给代码,能抄就抄,抄不了就改。


从旋转裁剪到语义合成:数据增强的十八般武艺

像素级:先给图片做“广播体操”

最经典也最安全,适合在 dataset 的 __getitem__ 里直接塞:

import cv2
import numpy as np
from torchvision import transforms

class PixelAug:
    def __init__(self, size=512):
        self.tf = transforms.Compose([
            transforms.RandomResizedCrop(size, scale=(0.8, 1.0)),
            transforms.RandomRotation(10),          # 左右晃 10°
            transforms.ColorJitter(brightness=.1, contrast=.1, saturation=.1, hue=.02),
            transforms.RandomHorizontalFlip(p=0.5),
        ])
    def __call__(self, img):
        # img: PIL Image
        return self.tf(img)

注意别用太猛:旋转 90° 容易把“狗”转成“天花板上的狗”,Stable Diffusion 会一脸懵。

语义级:文本也能玩“偷梁换柱”

SD 训练时需要配对的文本,咱们可以把原始 caption 做轻量改写,让模型见识“同一个意思 N 种说法”。

import random

templates = [
    "a photo of {}",
    "a picture of {}",
    "{}",               # 原词
    "a high-quality image of {}",
    "a detailed illustration of {}"
]

def augment_caption(raw: str) -> str:
    # raw: 糖葫芦
    tmpl = random.choice(templates)
    return tmpl.format(raw)

如果想再“骚”一点,可以上中文近义词词典(如 OpenHowNet)做替换:

synonym = {
    "猫": ["喵星人", "猫咪", "小猫"],
    "狗": ["汪星人", "狗子", "小狗"]
}
def cn_synonym_replace(caption: str) -> str:
    for k, v in synonym.items():
        if k in caption:
            return caption.replace(k, random.choice(v), 1)
    return caption

风格化:把照片变成油画,再变回来

Stable Diffusion 本身带一个“img2img”流水线,我们可以先用轻量风格化模型(如 AnimeGAN、Fast Neural Style)把原图变风格,再送进训练集。
这样模型能学到“内容不变、画风变”的鲁棒性,生成时就不会只会“二次元”或“照片”二选一。

from torchvision.transforms import functional as F
from style_transfer import whitebox_cartoon  # 伪代码,替换成你的风格化包

class StyleAug:
    def __init__(self, p=0.3):
        self.p = p
    def __call__(self, img):
        if random.random() < self.p:
            img_np = np.array(img)
            cartoon = whitebox_cartoon(img_np)
            return F.to_pil_image(cartoon)
        return img

随机遮罩:让模型做“完形填空”

把图随机挖掉 30% 像素,强迫模型靠周围和文本描述脑补。思路来自 Stable Diffusion 自己的 Inpainting 任务,训练时加一把,能显著提升细节一致性。

def random_mask(img, ratio=0.3):
    w, h = img.size
    mask = np.ones((h, w), np.uint8)
    # 随机画若干矩形洞
    for _ in range(random.randint(3, 7)):
        x1 = random.randint(0, w)
        y1 = random.randint(0, h)
        x2 = random.randint(x1, w)
        y2 = random.randint(y1, h)
        mask[y1:y2, x1:x2] = 0
    img_np = np.array(img)
    img_np[mask == 0] = 255   # 白洞
    return Image.fromarray(img_np)

实战派最爱的增强技巧:哪些方法真能提升生成质量?

下面是一份“亲测有效”的增强配方,来自我们去年在 8 张 RTX 3090 上微调 1.5 模型的血泪史。
训练目标:让基础模型更懂“国风插画”。

增强手段 占比 FID↓ CLIP↑ 肉眼观感
原图 50% 18.3 0.782 baseline
随机裁剪+颜色扰动 20% 17.9 0.788 更鲜艳
文本模板替换 15% 17.5 0.791 语义更贴
风格化(国画+浮世绘) 10% 17.2 0.794 画风更稳
随机遮罩 5% 17.1 0.795 细节更少崩

结论:

  1. 文本增强 ROI 最高,几行代码就能涨 CLIP。
  2. 像素级增强别超过 30%,否则模型会“过拟合”到增强域。
  3. 风格化要控制比例,太多会“污染”原始照片分布,导致生成“国画脸”糊成一团。

小心陷阱!这些增强操作反而会让模型学歪

  1. 180° 垂直翻转:把“狗”翻成“四脚朝天”的奇行种,文本却没改,模型直接怀疑狗生。
  2. 强颜色抖动:把“蓝天”抖成“绿天”,文本还是“blue sky”,模型以为“blue ≈ 绿”。
  3. 过度 CutMix:把猫头和狗身拼一起,文本写“a cat and a dog”,结果生成“猫狗兽”。
  4. 文本同义词太离谱:把“汽车”换成“火车”,像素却不变,模型被迫学会“汽车 ≈ 火车”,以后画汽车可能给你加烟囱。

一句话:凡是会让“像素-文本”对不上号的操作,都要谨慎;要么改图,要么改文本,千万别只改一半。


调试指南:当生成结果变糊、变怪、变离谱时怎么办

Step 1:先排除增强背锅
把增强全关掉,训 500 步看样本。如果依旧崩,说明不是增强的锅,可能是学习率/CFG/过拟合。

Step 2:单变量回滚
把增强按强度排序,从最强的开始回滚。我们发现“风格化”最容易导致颜色漂移,回滚掉立刻正常。

Step 3:可视化增强样本
把增强后的图+文本写进 TensorBoard,每 50 步随机抽 8 张。眼见为实,一眼就能发现“绿天”这种奇葩。

Step 4:FID/CLIP 双指标
只看 FID 会骗人——增强后图像分布变了,FID 可能升高;但 CLIP 分数上涨,说明语义对齐更好。建议双指标一起盯。

Step 5:小规模过拟合实验
拿 100 张图疯狂过拟合,观察增强样本会不会出现“猫狗兽”。如果会,立刻调整比例或策略。


开发者的私藏锦囊:自动化增强流水线怎么搭才高效

1. 基于 config 的“可插拔”增强

# aug_config.yaml
pixel:
  RandomResizedCrop: {scale: [0.8, 1.0]}
  ColorJitter: {brightness: 0.1, hue: 0.02}
text:
  templates: ["a photo of {}"]
  synonym: true
style:
  p: 0.2
  library: [anime, ink_wash, ukiyo_e]
mask:
  p: 0.05
  holes: [3, 7]
from omegaconf import OmegaConf
from torchvision import transforms

def build_aug(cfg):
    tf_list = []
    if cfg.pixel:
        tf_list.append(PixelAug(cfg.pixel))
    if cfg.style:
        tf_list.append(StyleAug(cfg.style.p))
    if cfg.mask:
        tf_list.append(MaskAug(cfg.mask.p))
    return transforms.Compose(tf_list)

cfg = OmegaConf.load('aug_config.yaml')
augment = build_aug(cfg)

2. 多进程+GPU 风格化

风格化模型通常吃 GPU,如果在线做会拖慢训练。可以用 torch.multiprocessing 提前把风格化缓存成 .jpg,训练时直接读盘。

def worker(idx, img_path, style_path, gpu_id):
    # 把 img_path 风格化后写到 style_path
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
    model = load_style_model()
    img = Image.open(img_path)
    stylized = model(img)
    stylized.save(style_path)

def batch_style_imgs(img_list, out_dir, num_gpus=4):
    pool = multiprocessing.Pool(num_gpus)
    for idx, ip in enumerate(img_list):
        gpu_id = idx % num_gpus
        op = os.path.join(out_dir, f"{idx:06d}.jpg")
        pool.apply_async(worker, (idx, ip, op, gpu_id))
    pool.close(); pool.join()

3. 动态混合采样

有时候你想“前期多增强、后期少增强”,可以用 torch.utils.data.ConcatDataset + 自定义 Sampler

class RatioSampler(torch.utils.data.Sampler):
    def __init__(self, aug_ds, clean_ds, aug_ratio=0.5):
        self.aug_ds = aug_ds
        self.clean_ds = clean_ds
        self.aug_ratio = aug_ratio
    def __iter__(self):
        a_idx = list(range(len(self.aug_ds)))
        c_idx = list(range(len(self.clean_ds)))
        random.shuffle(a_idx); random.shuffle(c_idx)
        # 按比例混合
        total = len(a_idx) + len(c_idx)
        ratio = int(self.aug_ratio * total)
        chosen = (a_idx[:ratio] + c_idx[:total - ratio])
        random.shuffle(chosen)
        return iter(chosen)

4. 版本管理:给增强打 tag

dvcwandb.Artifact 把增强后的数据集也做版本管理,防止“今天跑得好好的,明天换机器就复现不了”的尴尬。


别再死磕算力了,聪明地扩充数据才是性价比之王

写到这儿,估计有同学嘀咕:搞这么多花活,不如直接 A100 拉满,数据不够就再爬 100G?
兄弟,爬 100G 简单,洗 100G 难;洗完了发现版权 30% 侵权,法务让你全删,你哭都来不及。
数据增强虽然不能凭空造出新像素,但它能把“旧像素”玩出 10 倍价值,成本只是几行代码和一点电费。
尤其在风格化微调、垂直领域(医疗、建筑、国风、像素游戏)里,高质量外网图根本找不到,增强就是唯一的出路。

最后,送你一份“增强 checklist”,下次微调前照着打钩,少踩坑:

  • 文本和像素同步改,别只改一半
  • 增强后手动抽查 100 张,肉眼对语义
  • 训练前 500 步先关增强,确认 baseline
  • FID/CLIP 一起盯,单指标骗人
  • 风格化缓存,别让 GPU 等 CPU
  • 版本管理,增强数据集也要打 tag

祝你下次微调出图即大片,生成的小姐姐十指健全,老板开心,你也早下班。

在这里插入图片描述

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐