FFaceNeRF模块

论文《FFaceNeRF: Few-shot Face Editing in Neural Radiance Fields》

深度交流Q裙:1051849847
全网同名 【大嘴带你水论文】 B站定时发布详细讲解视频

视频地址,点击查看论文详细讲解,每日更新:
https://b23.tv/zdapaC3

详细代码见文章最后

在这里插入图片描述

1、作用

FFaceNeRF旨在解决现有基于NeRF的3D人脸编辑方法严重依赖固定布局的预训练分割蒙版、导致用户控制能力有限的问题。它使用户能够根据特定的编辑需求(如虚拟试妆、医疗整形预览等)自由定义和使用新的蒙版布局,而无需收集和标注大规模数据集,极大地提升了3D人脸编辑的灵活性和实用性。

在这里插入图片描述

2、机制

  1. 几何适配器 (Geometry Adapter) : 在预训练的几何解码器(用于生成固定布局的分割图)之后,添加一个轻量级的MLP网络作为几何适配器。该适配器负责将固定布局的输出调整为用户期望的、任意布局的分割蒙版。
  2. 特征注入 (Feature Injection) : 为了在适应新蒙版时保留丰富的几何细节,模型将预训练模型中的三平面特征(tri-plane feature)和视角方向(view direction)直接注入到几何适配器中,弥补了预训练解码器可能丢失的信息。
  3. 三平面增强的潜在混合 (LMTA) : 这是一种为小样本学习设计的数据增强策略。通过在生成器的潜在空间中混合不同层的潜在编码,可以在保持核心语义信息(如人脸结构)不变的同时,生成多样化的训练样本(如改变色调、饱和度),有效避免了在仅有10个样本的情况下发生的过拟合。
  4. 基于重叠的优化 (Overlap-based Optimization) : 在训练和推理过程中,除了使用传统的交叉熵损失外,还引入了基于DICE系数的重叠损失(overlap loss)。这种损失函数对小区域的变化更敏感,确保了即使在编辑精细区域(如瞳孔、鼻翼)时也能实现精确对齐和稳定生成。

3、独特优势

  1. 小样本高效学习 : 最显著的优势是其小样本(Few-shot)能力,仅需约10个带有自定义蒙版的样本即可完成训练,快速适应新的编辑任务,极大降低了数据和时间成本。
  2. 高度的灵活性和控制力 : 用户不再受限于固定的分割类别,可以为任何感兴趣的面部区域创建蒙版并进行编辑,实现了前所未有的控制自由度。
  3. 精细区域的精准编辑 : 通过基于重叠的优化策略,模型能够精确地处理和编辑微小面部特征,解决了传统方法在小区域编辑上容易失败的痛点。
  4. 保持身份和非编辑区域 : 在编辑特定区域时,能够很好地保持人脸的身份特征以及未编辑区域的图像内容,生成结果自然且保真度高。

4、代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class MockTriPlaneGenerator(nn.Module):
    """ NeRF三平面生成器,用于生成基础的分割图。 """
    def __init__(self, num_classes=10):
        super().__init__()
        self.num_classes = num_classes
        # 一个简单的卷积层,模拟从三平面特征到分割图的解码过程
        self.decoder = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, triplane_features, view_direction):
        # 解码过程,返回一个随机的分割图
        batch_size = triplane_features.shape[0]
        # 返回一个随机的、固定布局的分割图
        return torch.randn(batch_size, self.num_classes, 64, 64)

class MockDataset(torch.utils.data.Dataset):
    def __init__(self, num_samples=10, num_classes=12):
        self.num_samples = num_samples
        self.num_classes = num_classes

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # 生成模拟数据
        triplane_features = torch.randn(1, 32, 256, 256) # 模拟三平面特征
        view_direction = torch.randn(1, 3) # 模拟视角方向
        # 模拟用户自定义的目标蒙版
        target_mask = torch.randint(0, self.num_classes, (1, 64, 64), dtype=torch.long)
        return triplane_features, view_direction, target_mask

# ----------------------------------------------------------------------------
# 核心代码实现 (Core Implementation)
# ----------------------------------------------------------------------------

class GeometryAdapter(nn.Module):
    """ 几何适配器,将固定布局的分割图调整为用户自定义的布局。 """
    def __init__(self, input_channels, output_channels):
        super().__init__()
        # 一个轻量级的MLP,用于适配几何特征
        self.adapter = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(64, output_channels, kernel_size=1)
        )

    def forward(self, x):
        return self.adapter(x)

class FFaceNeRF(nn.Module):
    """ FFaceNeRF 模块,集成了预训练生成器和几何适配器。 """
    def __init__(self, pretrained_generator, num_custom_classes):
        super().__init__()
        self.pretrained_generator = pretrained_generator
        # 几何适配器的输入通道数等于预训练生成器的输出类别数
        self.adapter = GeometryAdapter(pretrained_generator.num_classes, num_custom_classes)

    def forward(self, triplane_features, view_direction):
        # 首先,使用预训练生成器获取固定布局的分割图
        fixed_layout_seg = self.pretrained_generator(triplane_features, view_direction)
        # 然后,通过几何适配器将其调整为自定义布局
        custom_layout_seg = self.adapter(fixed_layout_seg.detach()) # detach以冻结预训练部分
        return custom_layout_seg

def dice_loss(pred, target, smooth=1e-5):
    """ 计算DICE损失,对小区域优化更有效。 """
    pred = F.softmax(pred, dim=1)
    # 将target转换为one-hot编码
    target_one_hot = F.one_hot(target.squeeze(1), num_classes=pred.shape[1]).permute(0, 3, 1, 2).float()
    
    intersection = torch.sum(pred * target_one_hot, dim=(2, 3))
    union = torch.sum(pred, dim=(2, 3)) + torch.sum(target_one_hot, dim=(2, 3))
    
    dice = (2. * intersection + smooth) / (union + smooth)
    return 1 - dice.mean()

# ----------------------------------------------------------------------------
# 运行示例 (Runnable Example)
# ----------------------------------------------------------------------------

if __name__ == '__main__':
    # --- 1. 初始化参数和模型 ---
    num_fixed_classes = 10  # 预训练模型支持的固定类别数
    num_custom_classes = 12 # 用户自定义的类别数(例如,眉毛、上唇、下唇等)
    num_samples = 10        # 小样本数量
    epochs = 50             # 训练轮次

    # 初始化模拟的预训练生成器和FFaceNeRF模型
    pretrained_generator = MockTriPlaneGenerator(num_classes=num_fixed_classes)
    fface_nerf = FFaceNeRF(pretrained_generator, num_custom_classes)
    optimizer = torch.optim.Adam(fface_nerf.adapter.parameters(), lr=0.001)

    # --- 2. 创建模拟数据集 ---
    dataset = MockDataset(num_samples=num_samples, num_classes=num_custom_classes)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)

    # --- 3. 训练循环 ---
    for epoch in range(epochs):
        total_loss = 0
        for triplane_features, view_direction, target_mask in dataloader:
            optimizer.zero_grad()

            # 获取模型预测的自定义分割图
            predicted_seg = fface_nerf(triplane_features, view_direction)

            # 计算损失(交叉熵 + DICE损失)
            loss_ce = F.cross_entropy(predicted_seg, target_mask.squeeze(1))
            loss_dice = dice_loss(predicted_seg, target_mask)
            loss = loss_ce + 0.5 * loss_dice # 组合损失

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], 平均损失: {total_loss / len(dataloader):.4f}")

    # --- 4. 模拟推理 ---
    print("\n进行一次模拟推理...")
    with torch.no_grad():
        # 取一个样本进行测试
        test_features, test_view_dir, ground_truth_mask = next(iter(dataloader))
        predicted_mask_logits = fface_nerf(test_features, test_view_dir)
        predicted_mask = torch.argmax(predicted_mask_logits, dim=1)

        print(f"输入特征尺寸: {test_features.shape}")
        print(f"预测蒙版尺寸: {predicted_mask.shape}")
        print(f"预测蒙版中的类别: {torch.unique(predicted_mask)}")

详细代码 gitcode地址:https://gitcode.com/2301_80107842/research

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐