最近在折腾图像生成模型时,我发现了一个有趣的现象:同样的文本提示词,有些模型生成的图片就是比别人更符合人类审美。

这背后的秘密是什么?答案是强化学习

今天我想分享一次真实的技术探索经历——如何将DanceGRPO这个多模态强化学习框架应用到FLUX图像生成模型上,让AI学会生成更符合人类偏好的图片。

1 什么是DanceGRPO?为什么它这么特别

在讲具体实践之前,我们先理解一个核心概念:GRPO(Group Relative Policy Optimization)。简单来说,它是一种强化学习算法,最初用于优化大语言模型,让模型的回答更符合人类偏好。

DanceGRPO的创新之处在于,它是首个将GRPO应用到视觉生成领域的框架。它支持多种生成范式(扩散模型和矫正流)、多种任务(文生图、文生视频、图生视频),并且可以无缝对接Stable Diffusion、FLUX等主流模型。

用个通俗的比喻:如果说传统训练方法是"填鸭式教育",那么强化学习就是"因材施教"。模型会生成多个版本的图片,系统给每个版本打分,然后告诉模型:"第3张图片最好,你应该多朝这个方向努力。"经过反复训练,模型就学会了生成更符合人类审美的图片。

2 核心原理三个阶段环环相扣

DanceGRPO的训练流程可以分为三个关键阶段,我用一个具体例子来说明:

阶段一:推理阶段——生成候选图片

假设我们输入提示词:"一只戴着盔甲的猫"。在这个阶段:

1. Policy模型(也就是FLUX生成模型)会从纯噪声开始,经过16步迭代去噪。

2. 每个提示词生成12张图片(默认配置),这12张图片整体相似但细节不同。

3. 为什么是12张?因为我们需要在一组内进行对比学习。

这里有个有趣的技术细节,去噪的步长是动态变化的。

一开始步长大(快速勾勒轮廓),后面步长小(精细刻画细节),就像画家先打草稿再细描。

推理阶段核心代码:

for step in range(16):  # 16步迭代去噪

    # 预测当前噪声成分

    noise_pred = policy_model(latents, timestep, prompt_embeds)

    # 去噪:当前状态 - 预测噪声

    latents = latents - step_size * noise_pred

    # 加入随机扰动,增加多样性

    if step < 15:

        latents += torch.randn_like(latents) * noise_scale

阶段二:Reward阶段——给图片打分

生成12张图片后,系统需要评判哪张更好,这里用到Reward模型(基于CLIP):

1. 把图片和文本分别编码成特征向量。

2. 计算两个向量的相似度,得到奖励分数。

3. 计算组内平均分,得出每张图的相对优势值(advantage)。

举个例子:假设12张图的分数分别是[0.8, 0.85, 0.75, ...],平均分是0.8,那么第二张图的advantage就是0.85-0.8=0.05,说明它比平均水平好。

Reward计算示例:

with torch.no_grad():

    # 提取图像和文本特征

    image_features = clip_model.encode_image(images)

    text_features = clip_model.encode_text(texts)

    # 计算相似度得分

    rewards = (image_features @ text_features.T).diagonal()

    # 计算相对优势

    advantage = rewards - rewards.mean()

阶段三:训练阶段——更新策略

这是最核心的部分,系统会对比"新策略"和"旧策略"的差异:

- 旧策略:生成这12张图时模型的参数状态。

- 新策略:更新梯度后模型的参数状态。

通过计算ratio(新旧策略概率比),结合advantage,就能算出loss值。

如果某个动作(比如"给猫加金色盔甲")获得了高分,系统会提升这个动作的概率;反之则抑制。

第一次计算loss时,新旧策略其实是一样的(因为还没更新),这时loss主要依赖advantage值,但处理4个样本后,策略开始分化,训练真正开始起作用。

3 实战挑战精度对齐的血泪史

理论很美好,实际很骨感。把GPU代码迁移到NPU上,最大的挑战是精度对齐

(1)固定随机性,让结果可复现

强化学习涉及大量随机过程:噪声初始化、去噪扰动、数据采样...要对齐精度,第一步是固定所有随机性:

# 固定全局随机种子

from msprobe.pytorch import seed_all

seed_all(mode=True)

# 固定通信随机性

export HCCL_DETERMINISTIC=TRUE

# 固定数据加载顺序

sampler = DistributedSampler(train_dataset, shuffle=False, seed=42)

# 固定初始噪声(在CPU生成,避免NPU/GPU差异)

input_latents = torch.randn((1, 16, 64, 64), 

                           dtype=torch.bfloat16, 

                           device="cpu").to(device)

(2)分阶段对齐,各个击破

我采用的策略是:先单独对齐推理、Reward阶段,再端到端对齐全流程。

推理阶段对齐:保存每个step生成的12张图片,肉眼对比GPU和NPU的差异。

一开始NPU生成的图片有明显色块(俗称"花屏"),后来发现是RoPE位置编码中的repeat_interleave算子问题。

Reward阶段对齐:模拟1000张图片的打分过程,对比GPU和NPU的reward值。最终绝对误差控制在0.015%以内。

训练阶段对齐:关注loss曲线和长稳reward scores,训练200步后,误差在5%以内算合格。

这里有个坑,某次替换融合算子后,loss和reward看起来都正常,但推理图片出现了花屏。所以永远不要只看数值,一定要看实际效果。

4 性能优化从419秒到315秒

精度对齐后,下一个挑战是性能。初始开箱性能是419秒/步,经过一系列优化,最终达到315秒/步,提升了约25%。

优化1:算子层面——repeat_interleave的坑

FLUX模型的RoPE位置编码会调用repeat_interleave算子。NPU的这个算子有个特性:首根轴repeat效率高,非首轴会触发额外的Transpose操作,耗时飙升。

解决方案很简单:调整repeat的维度顺序。

优化前:

freqs = freqs.repeat_interleave(2, dim=1)  # 在dim=1上repeat

优化后:

freqs = freqs.transpose(0, 1).repeat_interleave(2, dim=0).transpose(0, 1)

这一改动在A+X平台上带来84秒的收益(419s→335s)。

优化2:通信层面——增大带宽

分布式训练中,通信开销占比很高,通过调整HCCL_BUFFSIZE参数,增大通信缓存区:

export HCCL_BUFFSIZE=800  # 默认200M,调整为800M

这个优化在A+K平台收益13秒(365s→352s)。

优化3:调度层面——异步保存与多图推理

训练过程中需要保存图片用于观察效果,但同步保存会让NPU空闲等待,改为异步保存后,NPU可以继续执行Reward阶段。

我有一个小发现,推理阶段batch size设为1效率不高(NPU擅长大kernel计算)。

将batch size从1增加到4,一次前向生成4张图,性能提升明显(325s→315s)。

异步保存示例:

import threading

def save_images_async(images, paths):

    def _save():

        for img, path in zip(images, paths):

            img.save(path)

    thread = threading.Thread(target=_save)

    thread.start()

5 实战效果数据说话

经过200步训练,模型在多个维度都有提升:

- Reward分数:从0.75提升到0.82(提升约9%)。

- 主观效果:早期生成的浴室图片光线暗淡、细节模糊;训练后瓷砖纹理清晰、光影自然。

- 文本对齐:对于"wearing a metallic helmet"这样的细节描述,训练后模型能准确生成金属质感。

值得一提的是,NPU和GPU训练出的模型在下游任务推理时效果基本一致,证明精度对齐是成功的。

这次DanceGRPO实战让我深刻体会到,技术的魅力不仅在于理论的优雅,更在于解决实际问题的成就感。

当看到训练200步后生成的图片从模糊变清晰、从违和到自然,那种兴奋感是无法用语言描述的。

记住:遇到问题不要慌,固定随机性、分阶段验证、保存中间结果——这三板斧能解决大部分问题。

技术的道路没有捷径,但每一次实践都会让你更接近真相。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐