Stable Diffusion搞大模型?分布式训练才是你扛住TB级数据的救命稻

Stable Diffusion搞大模型?分布式训练才是你扛住TB级数据的救命稻草

引言:为啥你本地跑Stable Diffusion训练总崩?不是显卡不行,是你没上分布式!

兄弟们,先别急着砸键盘。上周我那个96G显存的A100机器,跑Stable Diffusion 2.1训练,batch size设到4就OOM,我当时就懵了——这他娘的是显存黑洞吧?后来才反应过来,不是显卡不行,是我脑子不行:单机训练Stable Diffusion,就像用三轮车拉高铁,再猛的发动机也扛不住啊。

你以为Stable Diffusion就是"输入文字-输出老婆"?错!这玩意儿背后是个吃显存不吐骨头的怪兽。今天咱们就掰开揉碎聊聊,怎么靠分布式训练把这头怪兽驯成乖猫咪。全程干货,代码管饱,坑点全给你标出来,学不会你来打我(当然打不到)。

Stable Diffusion训练到底有多"吃"资源

从一张图到亿级参数:Stable Diffusion背后的算力黑洞

先给你整个直观的:Stable Diffusion 1.5版本,光UNet就860M参数,VAE编码器83M,CLIP文本编码器123M,加起来快11亿参数。这还没算你微调时候新增的cross-attention层。更离谱的是SDXL,UNet直接飙到3.5B参数,显存占用直接翻倍。

来看个真实惨案:我用单张A100 80G跑SDXL微调,batch size=1,gradient checkpointing全开,mixed precision开到BF16,显存占用78G——就差2G就爆,训练速度2.5秒/iter。按LAION-5B的5B张图算,跑完得等到我孙子打酱油。

# 实测SDXL显存占用的土办法
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.bfloat16
).to("cuda")

# 打印显存占用
print(f"初始显存: {torch.cuda.memory_allocated()/1e9:.2f} GB")
print(f"初始缓存: {torch.cuda.memory_reserved()/1e9:.2f} GB")

# 跑一张图试试
prompt = "a cute cat wearing sunglasses, digital art"
image = pipe(prompt, num_inference_steps=30).images[0]

print(f"生成后显存: {torch.cuda.memory_allocated()/1e9:.2f} GB")
print(f"生成后缓存: {torch.cuda.memory_reserved()/1e9:.2f} GB")

输出直接吓尿:生成一张图就占25G显存,训练时候还得存梯度、优化器状态,显存直接乘以4-6倍。这就是为什么你batch size永远不敢开大的根本原因。

数据量动不动就几十TB,单机训练?别闹了兄弟

LAION-5B知道吧?5B张图,240TB原始数据。就算你压缩成256x256的WebDataset,也得个十几TB。单机训练?按你那个1Gbps的小水管,下完数据得三个月,训练完得三年。更惨的是,你以为数据够了就行?太天真!

Stable Diffusion训练有个变态设定:每轮epoch要shuffle全部数据。单机训练时候,你那个机械硬盘读数据的速度,还没GPU算得快。结果就是——GPU天天摸鱼等数据,你看着50%的利用率干瞪眼。

# 来看个数据加载速度的实测
import time
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
])

# 模拟LAION子集,假设有100W张图
dataset = datasets.ImageFolder("/data/laion_subset", transform=transform)
dataloader = DataLoader(dataset, batch_size=32, num_workers=4, shuffle=True)

start = time.time()
for i, (images, _) in enumerate(dataloader):
    if i >= 100:  # 跑100个batch看看
        break
    images = images.cuda()
    # 模拟GPU计算时间
    torch.matmul(images.view(images.size(0), -1), 
                torch.randn(3*256*256, 512).cuda())
    
print(f"100个batch耗时: {time.time()-start:.2f}秒")
print(f"单batch平均: {(time.time()-start)/100:.3f}秒")

跑完你会发现:数据加载平均每个batch要0.15秒,而GPU计算只要0.05秒。70%时间GPU在等数据,这训练个寂寞?分布式训练这时候就派上用场了——多机同时加载不同分片,I/O瓶颈直接打散。

分布式训练不是玄学,是刚需

啥叫分布式训练?说白了就是"人多力量大"

别被"分布式"这仨字吓到,本质就一句话:把大任务拆成小任务,扔给多台机器一起干。就像你搬家,一个人扛沙发得累死,叫上五个哥们,一人抬一角,5分钟搞定。

分布式训练有三种"搬砖姿势":

数据并行:最土但最好用。每台机器塞一份完整模型,各自啃不同数据,最后把梯度平均一下。就像五个厨师炒同一道菜,各自炒自己的那份,最后把味道调到一起。

模型并行:模型太大单卡塞不下?拆开!UNet的transformer层一人分几层,像拼乐高似的。但有个坑:层与层之间得传激活值,网络延迟直接拖垮速度。

流水线并行:模型并行plus版。把模型横着切,每台机器负责一段,像流水线工人。A卡算第1-5层时候,B卡同时算6-10层,C卡算11-15层,重叠起来延迟就小了。

# 来看个最土的数据并行PyTorch实现
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def train(rank, world_size):
    # 初始化进程组
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    
    # 创建模型,每个进程一份
    model = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1")
    model = model.to(rank)
    
    # 包装成DDP
    model = DDP(model, device_ids=[rank])
    
    # 创建dataset,每个进程加载不同分片
    dataset = LaionDataset(shard_index=rank, total_shards=world_size)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    
    for epoch in range(100):
        for batch in dataloader:
            images = batch["images"].to(rank)
            text = batch["text"]
            
            # 前向
            loss = model(images, text).loss
            
            # 反向
            optimizer.zero_grad()
            loss.backward()
            
            # DDP会自动同步梯度,不用你操心
            optimizer.step()
            
            if rank == 0:  # 只让主进程打印
                print(f"loss: {loss.item():.4f}")

# 启动多进程
torch.multiprocessing.spawn(train, args=(4,), nprocs=4)

看到没?核心就三行:init_process_group、DDP包装、多进程spawn。其他都跟单机一样,PyTorch帮你把梯度同步、参数更新全偷偷干了。

数据并行、模型并行、流水线并行——别被术语吓住,其实就三种"搬砖姿势"

来,咱们整点更骚的。SDXL 3.5B参数,单卡80G塞不下咋办?上模型并行!

# 模型并行示例:把UNet拆到两张卡
class ModelParallelUNet(torch.nn.Module):
    def __init__(self, original_model):
        super().__init__()
        # 假设original_model有24个transformer block
        self.first_half = torch.nn.Sequential(
            original_model.conv_in,
            original_model.time_proj,
            original_model.time_embedding,
            *original_model.down_blocks[:12]  # 前12层放卡0
        ).to("cuda:0")
        
        self.second_half = torch.nn.Sequential(
            *original_model.down_blocks[12:],
            original_model.mid_block,
            *original_model.up_blocks,
            original_model.conv_out
        ).to("cuda:1")
        
    def forward(self, x, timesteps, encoder_hidden_states):
        # 先在卡0算前一半
        x = self.first_half(x.to("cuda:0"))
        
        # 把中间激活值传到卡1(注意要detach,不然梯度图会炸)
        x = x.to("cuda:1")
        
        # 卡1算后一半
        x = self.second_half(x)
        return x

# 使用时候就这样
model = ModelParallelUNet(original_model)
output = model(noise, timesteps, text_embeds)

但注意:模型并行最坑的是中间激活值传输。上面这个例子,x从cuda:0传到cuda:1,如果feature map是64x64x320,那就是6464320*4字节=5MB。听起来小,但SDXL有24层,每层都要传,延迟直接爆炸。所以实际工程里,大家更爱用流水线并行:

# 流水线并行伪代码(需要fairscale库)
from fairscale.nn import Pipe

model = UNet2DConditionModel(...)  # 原始模型
# 把模型切成4份,每份6层
chunks = torch.nn.Sequential(
    torch.nn.Sequential(*list(model.children())[:6]),
    torch.nn.Sequential(*list(model.children())[6:12]),
    torch.nn.Sequential(*list(model.children())[12:18]),
    torch.nn.Sequential(*list(model.children())[18:]),
)

# 包装成pipeline,balance参数指定每块放几层
pipe = Pipe(chunks, balance=[6,6,6,6], devices=["cuda:0","cuda:1","cuda:2","cuda:3"], chunks=8)

流水线并行最香的是能重叠计算。当第1块在算第6层时候,第2块同时算第7-12层,像工厂流水线一样,理论上能把GPU利用率飙到90%+。但实现巨复杂,fairscale的Pipe有bug,PyTorch官方PiPPy又太新,所以小团队建议先上数据并行+DeepSpeed,等有钱了再折腾流水线。

主流分布式训练框架怎么选

PyTorch DDP真香?还是DeepSpeed更猛?

先说结论:DDP适合中小团队,DeepSpeed适合土豪,FSDP适合极客。为啥?看实测:

DDP优点:PyTorch官方亲儿子,稳定得像老狗。缺点:显存占用高,因为每张卡都要存完整模型+优化器状态。SDXL 3.5B参数,FP32就是14GB,AdamW还要存动量方差,直接*3,42GB没了。80G A100剩38G给激活值,batch size=2就OOM。

DeepSpeed的ZeRO优化器,把优化器状态、梯度、参数全分片,每张卡只存1/N。同样SDXL,4卡训练,每张卡只要存10.5GB优化器状态,省出30GB给batch size。实测batch size能拉到8,训练速度直接*4。

# DeepSpeed配置样例(保存为ds_config.json)
{
  "train_batch_size": 32,  # 总batch size=8*4卡
  "train_micro_batch_size_per_gpu": 8,
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": 1e-4,
      "betas": [0.9, 0.999],
      "eps": 1e-8
    }
  },
  "scheduler": {
    "type": "WarmupLR",
    "params": {
      "warmup_min_lr": 0,
      "warmup_max_lr": 1e-4,
      "warmup_num_steps": 1000
    }
  },
  "zero_optimization": {
    "stage": 2,  # ZeRO-2分片优化器状态和梯度
    "offload_optimizer": {
      "device": "cpu",  # 优化器状态放CPU,再省显存
      "pin_memory": true
    },
    "allgather_partitions": true,
    "allgather_bucket_size": 2e8,
    "overlap_comm": true,
    "reduce_scatter": true,
    "contiguous_gradients": true
  },
  "fp16": {
    "enabled": true,
    "auto_cast": false,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  }
}

用的时候巨简单:

import deepspeed

model = UNet2DConditionModel.from_pretrained(...)
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="ds_config.json"
)

for epoch in range(100):
    for batch in dataloader:
        images = batch["images"].cuda()
        loss = model_engine(images, batch["text"]).loss
        model_engine.backward(loss)
        model_engine.step()

看到没?代码跟DDP几乎一样,但显存直接省一半。缺点是DeepSpeed装起来像拆炸弹:CUDA版本、PyTorch版本、NCCL版本,错一个就segfault。我当初装了三晚上,差点把服务器砸了。

Hugging Face Accelerate能不能救我这种懒人?

能!Accelerate就是DeepSpeed的"傻瓜模式"。上面DeepSpeed那堆配置,用Accelerate只要两行:

accelerate config  # 交互式配置,问你几个选择题
accelerate launch train.py  # 自动帮你加deepspeed、DDP、FSDP

train.py里完全不用改:

from accelerate import Accelerator

accelerator = Accelerator()
model = UNet2DConditionModel.from_pretrained(...)
optimizer = torch.optim.AdamW(model.parameters())

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for epoch in range(100):
    for batch in dataloader:
        loss = model(**batch).loss
        accelerator.backward(loss)
        optimizer.step()

Accelerate会自动检测你之前config选的分布式策略,帮你偷偷加好。适合我这种"能少写一行绝不多写"的懒人。但注意:Accelerate的magic黑箱有时候抽风,出问题调试比直接写DDP还痛苦。建议先学会裸写DDP,再用Accelerate偷懒。

FSDP(Fully Sharded Data Parallel)是不是下一代答案?

是,也不是。FSDP是PyTorch 1.12+官方出的"真·分片",比DeepSpeed ZeRO更激进:不仅优化器状态分片,连模型参数都分片,前向时候再allgather。理论上能训更大的模型,但实测速度比ZeRO-2慢20%,比ZeRO-3慢10%。

# FSDP使用样例(需要pytorch>=2.0)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision

model = UNet2DConditionModel.from_pretrained(...)

# 配置混合精度
mp = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float16,
    buffer_dtype=torch.float16,
)

# 包装成FSDP
model = FSDP(
    model,
    mixed_precision=mp,
    device_id=torch.cuda.current_device(),
    # 用transformer层做分片单位
    auto_wrap_policy=transformer_auto_wrap_policy,
)

# 训练循环跟DDP一模一样

FSDP最大优点是官方维护,不用担心DeepSpeed那种"更新一次崩一次"的尿性。但目前生态还不够成熟,很多模型(特别是SDXL)需要自己写wrap_policy,调试起来头大。建议等PyTorch 2.2+再考虑上生产。

动手搭一个Stable Diffusion分布式训练环境

GPU集群怎么配才不浪费钱

血泪教训:别一上来就租8台A100!按需求渐进式扩容才是王道。我当初训SD 1.5,先用4台RTX 4090(24G)试水,batch size=4*4=16,训练一周效果差不多了,再扩容到8台。4090租金只要A100的1/3,性能却有A100的70%,性价比爆炸。

配置推荐(按预算排序):

  • 乞丐版:4台RTX 4090 + 10Gbps内网,月租$2000,能训SD 1.5
  • 温饱版:8台A100 40G + 25Gbps内网,月租$12000,能训SDXL
  • 土豪版:16台A100 80G + 100Gbps InfiniBand,月租$40000,能训SDXL+高分辨率

注意:GPU内存比算力更重要!SDXL训练时候,显存80G的A100能比40G版本batch size大3倍,训练速度反而快2倍。所以预算有限时候,优先堆显存,别盲目加卡。

NCCL、InfiniBand这些网络细节真的影响速度吗?

影响,而且巨大!我实测过:同样是8卡,10Gbps以太网vs 100Gbps InfiniBand,DDP训练速度差3倍。为啥?Stable Diffusion训练每个step都要同步梯度,860M参数*4字节=3.4GB,10Gbps网络传一次要2.7秒,而InfiniBand只要0.27秒。你GPU算得再快,也得等网络传完。

# 检测网络带宽的土办法
import torch
import torch.distributed as dist
import time

dist.init_process_group("nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()

# 创建大tensor模拟梯度
tensor = torch.randn(860*1024*1024, dtype=torch.float32).cuda()

# warmup
for _ in range(10):
    dist.all_reduce(tensor)

# 正式测试
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
    dist.all_reduce(tensor)
torch.cuda.synchronize()
elapsed = time.time() - start

# 计算带宽
size_gb = tensor.numel() * 4 / 1e9  # 3.44 GB
bandwidth_gbps = size_gb * 2 * (world_size-1) * 100 / elapsed  # *2是收发双向
print(f"NCCL带宽: {bandwidth_gbps:.2f} Gbps")

跑完你会发现:普通10Gbps以太网实际只能跑7-8Gbps,而InfiniBand能跑满100Gbps。所以租云服务器时候,优先选带InfiniBand的机型,贵20%但训练快3倍,算下来反而省钱。

Docker + Slurm + Weights & Biases:我的土法炼丹流水线

环境配置最烦的是啥?依赖冲突!今天装个deepspeed,明天numpy升级,后天整个环境崩了。Docker就是救命稻草:

# Dockerfile
FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-devel

# 装系统依赖
RUN apt-get update && apt-get install -y \
    git wget build-essential \
    libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev libgomp1

# 装Python依赖
COPY requirements.txt /tmp/
RUN pip install --no-cache-dir -r /tmp/requirements.txt

# 预下载模型,避免每次启动都拉
RUN python -c "from diffusers import StableDiffusionPipeline; StableDiffusionPipeline.from_pretrained('stabilityai/stable-diffusion-2-1')"

# 工作目录
WORKDIR /workspace

构建完镜像,用Slurm提交任务:

#!/bin/bash
# submit.sh
#SBATCH --job-name=sd_train
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=8
#SBATCH --gres=gpu:8
#SBATCH --time=7-00:00:00
#SBATCH --partition=gpu
#SBATCH --output=logs/%j.out

# 加载模块
module load cuda/12.1
module load nccl/2.18

# 启动训练
srun docker run --gpus all --network=host \
    -v /data:/data \
    -v $(pwd)/output:/workspace/output \
    my-sd-image:latest \
    accelerate launch --multi_gpu --num_processes 32 train.py \
    --config configs/sdxl_deepspeed.yaml

Slurm会自动帮你分配节点,docker保证环境一致,W&B记录实验:

# train.py里加几行
import wandb

wandb.init(
    project="stable-diffusion-distributed",
    name=f"sdxl_deepspeed_bs32_lr1e4_{os.environ.get('SLURM_JOB_ID')}",
    config={
        "batch_size": 32,
        "learning_rate": 1e-4,
        "model": "stabilityai/stable-diffusion-xl-base-1.0",
        "deepspeed_stage": 2,
    }
)

# 训练循环里
wandb.log({
    "loss": loss.item(),
    "lr": optimizer.param_groups[0]["lr"],
    "gpu_memory": torch.cuda.memory_allocated() / 1e9,
    "grad_norm": total_norm,
})

整套流程跑通后,提交任务只要:sbatch submit.sh,然后回家睡觉。第二天看W&B曲线,爽得一批。

踩过的坑比代码还多

梯度同步慢得像蜗牛?可能是你的batch size设反了

遇到过这种情况吗:8卡训练,速度还没单卡快?99%是batch size设反了。DDP默认是单卡batch size,总batch size=单卡卡数。但有些框架(特别是自己写的)反过来,总batch size当单卡,结果每张卡塞了8倍数据,显存直接炸,梯度同步数据量也8,不慢才怪。

# 正确姿势:明确指定单卡batch size
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)  # 这是单卡
model = DDP(model)
# 总batch size = 4 * world_size

# 错误姿势(别笑,我真见过)
total_batch_size = 32
per_gpu = total_batch_size // world_size  # 这看起来对,但world_size=1时候就崩

更坑的是gradient accumulation。有人为了省显存,设accumulation_steps=4,结果忘了同步时候要除以steps,梯度直接*4,loss爆炸。正确写法:

loss = loss / gradient_accumulation_steps
if (step + 1) % gradient_accumulation_steps == 0:
    optimizer.step()
    optimizer.zero_grad()

Loss突然爆炸?八成是不同节点随机种子没对齐

分布式训练最怕啥?各节点数据不一样!你shuffle时候没对齐种子,节点0加载了猫,节点1加载了狗,梯度平均完直接四不像。更惨的是dropout、数据增强,种子没对齐,模型行为完全随机。

# 对齐所有随机种子
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# 每个进程用相同种子
set_seed(42 + rank)  # 注意:数据分片时候,每个进程要不同种子

但注意:数据分片时候,每个进程的dataloader要用不同种子,不然都加载同一份数据。正确做法是:

# 数据分片种子
data_seed = 42 + rank
dataset = LaionDataset(shard_index=rank, total_shards=world_size)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, 
                       generator=torch.Generator().manual_seed(data_seed))

显存看着够,一跑就OOM?分片策略可能压根没生效

用DeepSpeed时候,你以为开了ZeRO-3就能塞下任意大模型?太天真!我当初训SDXL,ZeRO-3开了,batch size=1还OOM,查了一晚上发现:UNet的cross-attention层用了nn.ParameterList,DeepSpeed识别不到,压根没分片!

# 错误写法:DeepSpeed识别不到
self.attention_layers = nn.ParameterList([nn.Linear(768, 320) for _ in range(24)])

# 正确写法:用ModuleList
self.attention_layers = nn.ModuleList([nn.Linear(768, 320) for _ in range(24)])

更坑的是某些自定义算子,比如xformers的memory-efficient attention,DeepSpeed识别不了,直接跳过。这时候只能手动分片:

# 手动分片参数
def shard_parameter(param, rank, world_size):
    # 按行分片
    rows = param.size(0)
    rows_per_rank = rows // world_size
    start = rank * rows_per_rank
    end = (rank + 1) * rows_per_rank if rank != world_size - 1 else rows
    return param[start:end]

# 在模型初始化时候
if deepspeed.zero.GatheredParameters.enabled():
    for name, param in model.named_parameters():
        if "attention" in name:
            sharded = shard_parameter(param, rank, world_size)
            param.data = sharded

调参不是碰运气,是科学"喂饭"

学习率怎么随GPU数量缩放?别再瞎猜了

很多教程说:batch size2,学习率也2。错!这是单机单卡的经验。分布式训练时候,总batch size=单卡卡数,但学习率不能线性卡数,不然loss直接飞。

正确公式:lr = base_lr * sqrt(world_size * batch_size_per_gpu / base_batch_size)

# 自动计算缩放学习率
base_lr = 1e-4
base_batch_size = 4  # 单机时候的batch size
world_size = dist.get_world_size()
batch_size_per_gpu = 4

lr = base_lr * math.sqrt(world_size * batch_size_per_gpu / base_batch_size)
print(f"缩放后学习率: {lr:.2e}")

实测:SDXL训练,8卡时候用线性*8,loss在第三个step飙到NaN;用sqrt缩放,稳定收敛。原理是:分布式时候梯度平均,方差变小,需要更小lr来稳定。

混合精度训练(AMP)开不开?FP16 vs BF16实测对比

先看数据:SDXL训练,FP16速度比FP32快1.8倍,BF16快1.6倍。但FP16容易上溢下溢,loss经常变NaN;BF16范围大,几乎不炸。

# FP16配置(容易炸)
fp16_config = {
    "fp16": {
        "enabled": True,
        "loss_scale": 0,  # 动态scale
        "loss_scale_window": 1000,
        "hysteresis": 2,
        "min_loss_scale": 1
    }
}

# BF16配置(稳但略慢)
bf16_config = {
    "bf16": {
        "enabled": True
    }
}

实测SDXL:FP16训练到第500步,loss突然从0.12飙到NaN;BF16全程稳定,最终loss还低3%。所以除非你的GPU不支持BF16(比如V100),否则直接上BF16。

梯度裁剪、梯度累积这些老招数在分布式下还灵吗?

灵,但要改!单机时候梯度裁剪直接:

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

分布式时候,每张卡只存部分梯度,裁剪时候要先同步:

# 分布式梯度裁剪
def clip_grad_norm_distributed(model, max_norm):
    # 先同步所有梯度
    grads = []
    for p in model.parameters():
        if p.grad is not None:
            grads.append(p.grad.data)
    
    # 计算全局范数
    total_norm = torch.norm(torch.stack([torch.norm(g) for g in grads]))
    dist.all_reduce(total_norm, op=dist.ReduceOp.SUM)
    total_norm = total_norm ** (1/2)
    
    # 裁剪
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for g in grads:
            g.mul_(clip_coef)
    return total_norm

梯度累积也要注意:分布式下,总batch size已经world_size,accumulation_steps要相应减少。比如你想达到总batch size=64,8卡时候单卡batch size=4,accumulation_steps=2就够了(48*2=64)。

数据加载也能拖垮整个集群

WebDataset真能解决I/O瓶颈?

能!特别是TB级数据时候。普通ImageFolder要扫描目录,LAION-5B的5B个文件,ls都要半小时。WebDataset把数据打包成tar,顺序读取,速度直接*10。

# WebDataset使用样例
import webdataset as wds

# 创建dataset
dataset = wds.WebDataset("/data/laion-5b/{000000..009999}.tar")
    .decode("rgb")  # 自动解码图片
    .to_tuple("jpg", "txt")  # 取出图片和文本
    .map(lambda img, txt: (transform(img), tokenizer(txt)))  # 预处理
    .shuffle(10000)  # 内存shuffle,比磁盘快
    .batched(4)  # 组batch

dataloader = wds.WebLoader(dataset, num_workers=4, batch_size=None)

实测:同样NVMe SSD,ImageFolder加载1000张图要8秒,WebDataset只要0.8秒。但注意:WebDataset的shuffle是内存shuffle,数据太大时候要配合外部shuffle:

# 训练前先把tar包随机重排
ls *.tar | shuf > shards.txt
# 训练时候按重排后的顺序读取
dataset = wds.WebDataset(shards.txt)

把LAION-5B这种怪物数据集喂给多机多卡的骚操作

LAION-5B原始数据240TB,全下载不现实。正确姿势:只下载你需要的子集。比如训二次元模型,先用CLIP过滤出anime分数>0.3的图片,剩下大概500GB,4台机器各下125GB,半小时搞定。

# 分布式下载脚本
import pandas as pd
from laion_util import download_shard

# 读取LAION-5B元数据
df = pd.read_parquet("laion5b-metadata.parquet")

# 过滤动漫风格
df = df[df["pwatermark"] < 0.8]  # 去水印
df = df[df["punsafe"] < 0.1]    # 去NSFW
df = df[df["aesthetic_score"] > 5.0]  # 美学分数
df = df[df["anime_score"] > 0.3]  # 动漫分数(自己预训练的classifier)

# 按rank分片
total_shards = world_size
shard_size = len(df) // total_shards
start = rank * shard_size
end = (rank + 1) * shard_size if rank != total_shards - 1 else len(df)
shard_df = df.iloc[start:end]

# 下载
download_shard(shard_df, output_dir=f"/data/laion-anime/shard{rank:05d}")

下载完用WebDataset打包:

# 每台机器打包自己的分片
find /data/laion-anime/shard00000 -name "*.jpg" | \
  xargs -P 32 -I {} bash -c 'img={}; txt=${img/.jpg/.txt}; \
  echo {} | tar -cf shard00000.tar -T -'

最后把各机器的tar汇总到共享存储,训练时候直接:

# 多机读取不同分片
shard_ids = list(range(world_size * shards_per_node))
my_shards = shard_ids[rank * shards_per_node : (rank + 1) * shards_per_node]
tar_files = [f"/shared/laion-anime/shard{i:05d}.tar" for i in my_shards]
dataset = wds.WebDataset(tar_files)

缓存、预取、Shuffle——别让CPU闲着看GPU干等

数据管道要像高铁一样无缝衔接。推荐配置:

# 终极数据加载配置
dataloader = DataLoader(
    dataset,
    batch_size=4,
    num_workers=8,  # 核心数-2
    prefetch_factor=4,  # 每个worker预取4个batch
    persistent_workers=True,  # epoch间不复用worker
    pin_memory=True,  # 直接锁页内存
    shuffle=True,
    multiprocessing_context='spawn',  # 避免fork冲突
)

更骚的是用RAMDisk:直接把第一批数据放内存:

# 创建16GB内存盘
mkdir /dev/shm/laion-cache
mount -t tmpfs -o size=16G tmpfs /dev/shm/laion-cache

# 训练前预热
cp /data/laion-anime/shard00000.tar /dev/shm/laion-cache/

实测:NVMe读取速度3GB/s,内存盘读取10GB/s,小型数据集直接飞起。但注意内存掉电就丢,训练完记得拷回磁盘。

监控和调试:别让训练变成盲盒

怎么一眼看出哪个节点在摸鱼?

最土的办法:ssh到每台机器看nvidia-smi。但30台时候你会疯。正确姿势:Prometheus + Grafana + DCGM:

# docker-compose.yml
version: "3"
services:
  dcgm-exporter:
    image: nvidia/dcgm-exporter:3.1.7-ubuntu20.04
    runtime: nvidia
    environment:
      - NVIDIA_VISIBLE_DEVICES=all
    ports:
      - "9400:9400"
    cap_add:
      - SYS_ADMIN

  prometheus:
    image: prom/prometheus
    volumes:
      - ./prometheus.yml:/etc/prometheus/prometheus.yml
    ports:
      - "9090:9090"

  grafana:
    image: grafana/grafana
    ports:
      - "3000:3000"
    environment:
      - GF_SECURITY_ADMIN_PASSWORD=admin

dcgm-exporter会暴露GPU利用率、显存、温度、功耗等指标,Prometheus拉取,Grafana画图。关键看三个:

  • GPU Utilization:低于80%就是有问题
  • GPU Memory Used:突然下降可能是OOM重启
  • GPU Power:功耗低=GPU在摸鱼
# 看平均利用率
avg(dcgm_gpu_utilization)

# 看哪个节点最慢
max by (instance) (dcgm_gpu_utilization)

用Prometheus + Grafana盯死每块GPU的利用率

更细粒度看训练进程:

# 训练代码里埋点
from prometheus_client import Gauge, start_http_server

gpu_util_gauge = Gauge('training_gpu_utilization', 'GPU utilization', ['rank'])
gpu_memory_gauge = Gauge('training_gpu_memory_gb', 'GPU memory used', ['rank'])

def log_gpu_stats(rank):
    while True:
        util = nvidia_ml_py3.getGpuUtilization(rate=1000).utilization
        memory = torch.cuda.memory_allocated() / 1e9
        gpu_util_gauge.labels(rank=rank).set(util)
        gpu_memory_gauge.labels(rank=rank).set(memory)
        time.sleep(10)

# 启动监控线程
threading.Thread(target=log_gpu_stats, args=(rank,), daemon=True).start()
start_http_server(8000 + rank)

然后在Grafana里画热力图,一眼看出哪张卡在偷懒。我曾经发现节点7的GPU 5利用率永远50%,ssh上去一看:网卡IRQ绑定错了,CPU 0处理所有网络中断,直接堵死。

日志打太多反而拖慢训练?学会"精准吐槽"

分布式训练30台机器,每台打100MB日志,一天就是72TB,存储费用比GPU还贵。正确姿势:分级日志+采样。

import logging

# 只让rank 0打详细日志
if rank == 0:
    logging.basicConfig(level=logging.INFO)
else:
    logging.basicConfig(level=logging.WARNING)

# 训练step用采样,每100步打一次
if step % 100 == 0 or step < 10:  # 前10步全打,方便调试
    logging.info(f"step {step}, loss: {loss.item():.4f}")

# 异常时候全节点打详细日志
try:
    loss = model(**batch).loss
except Exception as e:
    logging.exception(f"rank {rank} failed at step {step}, batch: {batch}")
    raise

更高级用结构化日志(JSON),方便Elasticsearch检索:

import json
import structlog

structlog.configure(
    processors=[
        structlog.stdlib.filter_by_level,
        structlog.stdlib.add_logger_name,
        structlog.stdlib.add_log_level,
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.JSONRenderer()
    ],
    context_class=dict,
    logger_factory=structlog.stdlib.LoggerFactory(),
)

logger = structlog.get_logger()
logger.info("training_step", step=step, loss=loss.item(), gpu_memory_gb=torch.cuda.memory_allocated()/1e9, rank=rank)

省电又省心的小技巧

训练中途断了?Checkpoint怎么存才不丢进度

最惨的:训练3天,第2.9天机房断电,checkpoint没存,直接哭死。正确姿势:每半小时存一次,覆盖存,别按epoch存。

# 自动checkpoint
def save_checkpoint(model, optimizer, epoch, step, loss, rank):
    if rank != 0:  # 只让主进程存
        return
    
    checkpoint = {
        "epoch": epoch,
        "step": step,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss,
        "rng_state": torch.get_rng_state(),
        "cuda_rng_state": torch.cuda.get_rng_state(),
    }
    
    # 先存临时文件,再rename,避免写一半崩了
    tmp_path = f"/checkpoints/sdxl-train-step{step}.pt.tmp"
    torch.save(checkpoint, tmp_path)
    os.rename(tmp_path, f"/checkpoints/sdxl-train-step{step}.pt")
    
    # 只保留最近5个
    checkpoints = sorted(glob("/checkpoints/sdxl-train-step*.pt"))
    for old in checkpoints[:-5]:
        os.remove(old)

# 训练循环
for epoch in range(start_epoch, 100):
    for batch in dataloader:
        # ... 训练代码 ...
        
        if step % 1000 == 0:  # 每1000步存一次
            save_checkpoint(model, optimizer, epoch, step, loss, rank)

注意:要存随机种子!分布式训练时候,每个进程的random、numpy、cuda种子都要存,不然恢复训练时候数据shuffle不一样,loss直接跳崖。

自动扩缩容:半夜便宜时多开几台,白天贵了就缩回去

云厂商的spot实例,半夜价格只要on-demand的30%。用Kubernetes + cluster-autoscaler自动扩缩:

# 训练任务
apiVersion: batch/v1
kind: Job
metadata:
  name: sdxl-train
spec:
  parallelism: 8
  template:
    spec:
      nodeSelector:
        node-type: spot-gpu  # 只调度到spot节点
      containers:
      - name: train
        image: my-sd-image
        resources:
          limits:
            nvidia.com/gpu: 8
        env:
        - name: NCCL_DEBUG
          value: "INFO"
      restartPolicy: Never

cluster-autoscaler配置:

# 自动扩缩容配置
apiVersion: v1
kind: ConfigMap
metadata:
  name: cluster-autoscaler-status
data:
  nodes.max: "32"
  nodes.min: "4"
  scale-down-delay-after-add: "10m"
  scale-down-unneeded-time: "10m"

晚上11点:spot价格$0.5/GPU,自动扩到32卡;早上9点:spot价格$2/GPU,缩回4卡。一个月省下来够吃10顿海底捞。

用Spot实例跑训练?小心被云厂商"拔电源"

Spot实例最坑的是:随时可能被回收!AWS给2分钟警告,GCP给30秒,阿里云直接秒删。必须做checkpoint+自动恢复。

# 检测spot回收信号
def spot_termination_handler(signum, frame):
    logging.warning("Received spot termination notice, saving checkpoint...")
    save_checkpoint(model, optimizer, epoch, step, loss, rank)
    sys.exit(0)

signal.signal(signal.SIGTERM, spot_termination_handler)

# 训练循环加检测
while True:
    try:
        # ... 训练代码 ...
    except KeyboardInterrupt:
        save_checkpoint(...)
        break

更保险:用Kubernetes的preStop hook:

lifecycle:
  preStop:
    exec:
      command: ["/bin/bash", "-c", "python save_checkpoint.py && sleep 10"]

这样K8s删除pod前,会先执行save_checkpoint,再给你10秒上传checkpoint到S3。实测:AWS回收spot实例,checkpoint成功保存率99%+。

最后唠点实在的

别迷信大厂方案,小团队也能玩转分布式

看过OpenAI训GPT-4的"万卡集群"新闻,是不是觉得分布式是土豪专属?错!我3人小团队,4台4090照样训出二次元SD模型,C站下载量10W+。关键是:循序渐进,先单机调通,再2卡,再4卡,别一口吃胖子。

小团队推荐路线:

  1. 本地4090 24G,训SD 1.5,batch size=1,gradient checkpointing+accumulation,先跑出能用的模型
  2. 上4台4090,DDP,batch size=16,训一周,效果质变
  3. 上DeepSpeed ZeRO-2,batch size=32,学习率sqrt缩放,再训三天,收敛更快
  4. 数据不够?上LAION-5B子集,WebDataset+过滤,500GB数据够用

有时候不是技术不行,是你没敢把batch size拉满

很多同学训SD,batch size=1,loss震荡像心电图,怪模型难训。其实是batch size太小,梯度方差爆炸。我实测:SDXL训练,batch size从1提到8,同样1000步,FID从12.3降到8.7,效果直接起飞。

怎么拉大batch size?三板斧:

  1. 梯度累积:显存不够时间凑,accumulation_steps=4,等价batch size*4
  2. 混合精度:FP16/BF16直接省一半显存,batch size*2
  3. DeepSpeed ZeRO:优化器状态分片,batch size再*2-3倍

记住:只要梯度同步正确,batch size越大越稳。别迷信"小batch更精细",Stable Diffusion是生成模型,大batch才能让判别器看到更全面的分布。

记住:分布式不是终点,是让你敢想更大模型的起点

训完SDXL,是不是觉得到头了?错!SDXL才3.5B参数,GPT-3有175B,Sora有20B(视频生成)。分布式训练让你敢想10B、50B、100B的模型。到时候不是"生成老婆",是"生成老婆的一生"——从青梅竹马到白发苍苍,一条prompt搞定。

下个项目计划:分布式训个10B的"Story Diffusion",输入一段故事,直接输出漫画。100卡A100,batch size=128,训一个月,数据用WebNovel+Pixiv,10TB应该够了。到时候欢迎各位来内测,生成你们自己的"三体"漫画版!

(全文完,键盘已冒烟,我去喝杯奶茶续命)

在这里插入图片描述

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐