在Replicate上训练和部署DreamBooth模型

发布于2022年11月21日

2024年8月更新:实验性DreamBooth API已不再可用。请查看FLUX.1微调博客文章,了解效果更好的替代方案。

2023年8月更新:已为SDXL(Stable Diffusion最新版本)添加微调支持。下文描述的DreamBooth API仍然有效,但使用SDXL可以在更高分辨率下获得更好的效果。请查看SDXL微调博客文章开始使用,或继续阅读以使用旧的DreamBooth API。

生成式AI领域因DreamBooth而热闹非凡。这是一种训练特定对象或风格的Stable Diffusion的方法,可以创建生成这些对象或风格的个性化模型版本。只需三张图像即可训练模型,训练过程不到半小时。

值得注意的是,DreamBooth适用于人物,因此可以制作能够生成自己图像的Stable Diffusion版本。

人们已经使用DreamBooth制作了一些神奇的产品,例如Avatar AI和ProfilePicture.AI。

现在,也可以使用DreamBooth创建自己的项目。已经构建了一个API,允许在云中训练DreamBooth模型并对其运行预测。

只需三张训练图像,大约需要20分钟(取决于使用的迭代次数)。训练模型成本约为2.50美元。

训练DreamBooth模型

首先,获取API令牌并在终端中设置:

export REPLICATE_API_TOKEN=...

接下来,将训练数据收集为名为data/的目录中的一组JPEG文件并压缩:

zip -r data.zip data

将此zip文件放在可通过HTTP访问的位置。如果需要,可以使用API上传文件。运行以下三个命令:

RESPONSE=$(curl -X POST -H "Authorization: Bearer $REPLICATE_API_TOKEN" https://dreambooth-api-experimental.replicate.com/v1/upload/data.zip)

curl -X PUT -H "Content-Type: application/zip" --upload-file data.zip "$(jq -r ".upload_url" <<< "$RESPONSE")"

SERVING_URL=$(jq -r ".serving_url" <<< $RESPONSE)

然后,启动训练任务:

curl -X POST \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer $REPLICATE_API_TOKEN" \
  -d '{
        "input": {
          "instance_prompt": "a photo of a cjw person",
          "class_prompt": "a photo of a person",
          "instance_data": "'"$SERVING_URL"'",
          "max_train_steps": 2000
        },
        "model": "yourusername/yourmodel",
        "trainer_version": "cd3f925f7ab21afaef7d45224790eedbb837eeac40d22e8fefe015489ab644aa",
        "webhook_completed": "https://example.com/dreambooth-webhook"
      }' \
  https://dreambooth-api-experimental.replicate.com/v1/trainings

需要设置以下参数:

  • instance_prompt:用于描述训练图像的提示,格式为a [identifier] [class noun],其中identifier是某个稀有标记。在上面的示例中,使用了cjw,但可以使用任何字符串。为获得最佳结果,使用包含三个Unicode字符且无空格的标识符。
  • class_prompt:正在训练的图像的更广泛类别提示,格式为a [class noun]。这用于生成类似训练数据的其他图像,以避免过拟合。
  • instance_data:训练数据的URL。
  • max_train_steps:要运行的训练步骤数。较少的步骤使其运行更快但通常质量较差,反之亦然。
  • model:在Replicate上给模型的名称,格式为username/modelname。例如,bfirsh/bfirshbooth。如果模型尚不存在,Replicate会自动创建。
  • trainer_version:要使用的DreamBooth和Stable Diffusion版本。有关更多详细信息,请参阅下面的"版本"部分。
  • webhook_completed:作业完成时调用的webhook。(可选。)

在后台,这运行replicate/dreambooth模型。该模型的任何输入都可以在input对象中传递。

API响应如下对象:

{
  "id": "rrr4z55ocneqzikepnug6xezpe",
  "input": {
    "instance_prompt": "photo of a cjw person",
    "class_prompt": "photo of a person",
    "instance_data": "https://replicate.delivery/pbxt/HoUeWsrtTTCJEpKGdLKqIYTfo8nbUTSNs565MkGxEstjfwKt/data.zip",
    "max_train_steps": 2000
  },
  "model": "yourusername/yourmodel",
  "status": "starting",
  "trainer_version": "cd3f925f7ab21afaef7d45224790eedbb837eeac40d22e8fefe015489ab644aa",
  "webhook_completed": "https://example.com/dreambooth-webhook"
}

可以通过调用GET /v1/trainings/<id>获取训练作业的状态:

curl -H "Authorization: Bearer $REPLICATE_API_TOKEN" \
  https://dreambooth-api-experimental.replicate.com/v1/trainings/rrr4z55ocneqzikepnug6xezpe

它响应相同的对象:

{
  "id": "rrr4z55ocneqzikepnug6xezpe",
  "input": {
    "instance_prompt": "photo of a cjw person",
    "class_prompt": "photo of a person",
    "instance_data": "https://replicate.delivery/pbxt/HoUeWsrtTTCJEpKGdLKqIYTfo8nbUTSNs565MkGxEstjfwKt/data.zip",
    "max_train_steps": 2000
  },
  "model": "yourusername/yourmodel",
  "status": "succeeded",
  "trainer_version": "cd3f925f7ab21afaef7d45224790eedbb837eeac40d22e8fefe015489ab644aa",
  "webhook_completed": "https://example.com/dreambooth-webhook",
  "version": "8abccf52e7cba9f6e82317253f4a3549082e966db5584e92c808ece132037776"
}

这是发送到webhook的相同对象。

运行训练好的模型

当训练过程成功完成后,它将模型推送到Replicate。

可以像使用Replicate上的任何其他模型一样运行该模型,使用网站或API。

要在网站上运行,请转到仪表板,然后单击"models"。

新模型默认是私有的,仅对您可见。如果希望任何人都能查看和运行模型,可以在模型页面的"Settings"选项卡中将其公开。

要作为API运行模型,首先需要获取版本ID。这可以在模型页面的"API"选项卡上找到,或者训练API响应中的version字段。

然后,可以进行API调用:

curl -X POST \
  -H "Authorization: Bearer $REPLICATE_API_TOKEN" \
  -d '{
        "input": {
          "prompt": "painting of cjw by andy warhol",
        },
        "version": "8abccf52e7cba9f6e82317253f4a3549082e966db5584e92c808ece132037776",
      }' \
  https://api.replicate.com/v1/predictions

或者,使用Python:

import replicate
replicate.run(
  "yourusername/yourmodel:8abccf52e7cba9f6e82317253f4a3549082e966db5584e92c808ece132037776",
  input={"prompt": "painting of cjw by andy warhol"},
)

要了解有关在Replicate上运行模型的更多信息,请查看Python入门指南或HTTP API参考。

版本

默认情况下,DreamBooth训练Stable Diffusion 1.5模型。该模型往往对DreamBooth效果更好,因为它包含更多不同的风格。

如果想使用其他版本,可以使用trainer_version选项选择不同的版本。以下是支持的版本:

  • Stable Diffusion 1.5cd3f925f7ab21afaef7d45224790eedbb837eeac40d22e8fefe015489ab644aa
  • 自定义检查点9c41656f8ae2e3d2af4c1b46913d7467cd891f2c1c5f3d97f1142e876e63ed7a
  • Stable Diffusion 2.1-based5e058608f43886b9620a8fbb1501853b8cbae4f45c857a014011c86ee614ffb

要查找其他可用版本,请查看DreamBooth训练器的发布说明。

后续步骤

如果对此有任何疑问,请加入Discord中的#dreambooth频道。

训练愉快!🚂
更多精彩内容 请关注我的个人公众号 公众号(办公AI智能小助手)或者 我的个人博客 https://blog.qife122.com/
对网络安全、黑客技术感兴趣的朋友可以关注我的安全公众号(网络安全技术点滴分享)

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐