PETALface代码链接:https://github.com/Kartik-3004/PETALface
PETALface论文链接:https://arxiv.org/abs/2412.07771
PETALface预训练权重链接:https://huggingface.co/kartiknarayan/PETALface

一、代码解读

1. 首先我们需要进行代码结构的解读

核心配置与入口文件

config.py: 项目核心配置文件,定义训练参数(如学习率、批处理大小)、数据集路径、模型超参数等,是整个项目的参数中枢。
train.py / train_iqa.py: 训练入口脚本。train.py 用于基础模型训练,train_iqa.py 专门针对集成图像质量评估(IQA)的模型训练,包含训练循环、损失计算、参数更新等逻辑。
losses.py: 实现损失函数,为模型训练提供分类损失计算支持。
lr_scheduler.py: 学习率调度器,控制训练过程中学习率的衰减策略,优化模型收敛效果。

骨干网络模块(backbones/)存放各类特征提取网络结构,支持不同模型架构的灵活切换,包含:

iresnet.py / iresnet2060.py: 基于残差网络(ResNet)的改进版本,适用于人脸特征提取。
swin_models.py / swin_models_iqa.py:Swin Transformer 模型实现,其中 swin_models_iqa.py 集成了图像质量评估相关的适配层。
vit.py / vit_iqa.py: 视觉 Transformer(ViT)模型,vit_iqa.py 针对低分辨率场景加入了质量感知模块。
lora_layers.py: 核心创新点实现,定义了 LoRA(Low-Rank Adaptation)低秩适配层,支持参数高效微调,仅更新少量参数即可适配低分辨率数据。
mobilefacenet.py: 轻量级 MobileFaceNet 模型,适合资源受限场景。

工具函数(utils/)

提供训练、评估、日志等辅助功能:
evaluate_utils.py: 评估工具函数。
utils_logging.py:日志管理工具,记录训练过程中的损失、精度等信息,支持可视化工具集成。
utils_distributed_sampler.py: 分布式训练采样器,支持多 GPU 环境下的数据分发。
utils_callbacks.py: 训练回调函数,如模型保存、早停策略等。
plot.py: 可视化工具,绘制训练曲线、特征分布等图表。

数据集处理(dataset/)

dataset.py: 数据集加载与预处理逻辑,支持读取 WebFace4M、TinyFace 等数据集,包含图像加载、增强、对齐等操作。
label_to_idx.py: 标签映射工具,将人脸身份标签转换为模型可识别的索引,处理多数据集的标签一致性。

头部网络(heads/)

partial_fc.py: 实现 Partial FC 头部,用于大规模人脸分类任务,优化类别不平衡问题,配合 margin 损失提升特征判别性。

验证模块(validation_*/)

针对不同数据集的验证逻辑,用于评估模型在高低分辨率场景下的性能:
validation_ijb/: 针对 IJB-B、IJB-C 等混合质量数据集的验证,包含 eval_ijb.py 等脚本,实现特征提取与比对。
validation_hq/: 高分辨率数据集(如 LFW、CFP)的验证,测试模型在高清场景下的泛化能力。
validation_lq/: 低分辨率数据集(如 TinyFace、BRIAR、IJB-S)的验证,核心文件包括:
PFE/ijbs.py: 解析 IJB-S 数据集协议,划分 gallery(图库)和 probe(探针),组织验证模板。
PFE/utils.py: 验证过程中的日志记录工具,支持 TensorBoard 等可视化 summary 写入。
脚本与文档(scripts/ / docs/)
scripts/: 存放训练、评估的脚本示例,简化命令行操作。
docs/: 项目文档,包括算法框架图、动机介绍、使用说明等,index.html 为项目网页,可视化展示核心思想与实验结果。
其他文件
environment.yml: conda 环境配置文件,定义项目依赖(如 PyTorch、OpenCV 等),一键搭建运行环境。
README.md: 项目说明文档,包含动机、贡献、安装步骤、数据集结构等关键信息。
LICENSE:MIT 许可证,明确代码的使用权限与限制。

整体逻辑总结

PETALface 以 “参数高效迁移学习” 为核心,通过 backbones/lora_layers.py 实现 LoRA 模块,结合 train_iqa.py 中的图像质量加权策略,在低分辨率数据集上微调时仅更新少量参数,避免灾难性遗忘。数据集处理、模型训练、多场景验证的模块化设计,使其能高效支持高低分辨率人脸识别任务的研究与部署。

2. 在清楚了代码结构之后,我们看一下这个项目是怎么进行训练的

由于我们的目标是低质量的人脸识别,所以主要分析的是train_iqa.py这个代码:
下面按「代码结构+核心功能+关键细节」的顺序,逐部分拆解解释:

一、整体代码结构

代码遵循「初始化→配置→数据加载→模型搭建→训练循环→保存验证」的标准深度学习训练流程,可分为 8 个核心模块:

  1. 依赖导入与分布式训练初始化
  2. 图像质量权重生成函数(generate_alpha
  3. 主函数入口(main):包含所有训练逻辑
    • 环境配置与日志初始化
    • 数据加载(训练集)
    • 模型搭建(骨干网络+LoRA+分类头)
    • 优化器与学习率调度器初始化
    • 训练循环(前向传播+反向传播+参数更新)
    • 验证与模型保存
  4. 脚本入口(if __name__ == "__main__"
二、逐模块详细解释
1. 依赖导入与分布式训练初始化
import argparse
import logging
import os
# ... 其他依赖导入(略)

assert torch.__version__ >= "1.12.0", "torch版本需≥1.12.0"

# 分布式训练参数(从环境变量读取,由torchrun或slurm启动时自动设置)
rank = int(os.environ["RANK"])  # 当前进程的全局排名(多GPU时区分不同GPU)
local_rank = int(os.environ["LOCAL_RANK"])  # 当前进程的本地排名(单节点内的GPU编号)
world_size = int(os.environ["WORLD_SIZE"])  # 总进程数(=总GPU数)
distributed.init_process_group("nccl")  # 初始化分布式训练(使用NCCL通信后端,GPU间高效通信)
  • 核心作用:搭建多GPU分布式训练环境,让多个GPU协同训练(提升速度、处理更大批量数据)。
  • 关键依赖torch.distributed 是PyTorch分布式训练的核心模块,nccl 是NVIDIA的GPU通信协议,效率最高。
2. 核心工具函数:generate_alpha(图像质量权重生成)
def generate_alpha(img, iqa, thresh):
    device = img.device  # 获取图像所在设备(GPU)
    BS, C, H, W = img.shape  # 批量大小、通道数、高、宽
    alpha = torch.zeros((BS, 1), dtype=torch.float32, device=device)  # 初始化质量权重(每个样本1个权重)

    score = iqa(img)  # 调用IQA模型计算每张图像的质量分数(0~1,分数越高质量越好)
    threshold = thresh  # 质量阈值(由参数--threshold指定)
    for i in range(BS):
        if score[i] == threshold:
            alpha[i] = 0.5  # 分数等于阈值时,权重为0.5
        elif score[i] < threshold:  # 低质量图像(分数<阈值)
            alpha[i] = 0.5 - (threshold - score[i])  # 权重随分数降低而减小(最低接近0)
        else:  # 高质量图像(分数>阈值)
            alpha[i] = 0.5 + (score[i] - threshold)  # 权重随分数升高而增大(最高接近1)
    return alpha
  • 核心目标:将图像质量分数转化为「模型可理解的权重 alpha」,实现「质量感知的特征提取」。
  • 逻辑本质:高质量图像给更高权重(模型更依赖其特征),低质量图像给更低权重(降低其对训练的干扰),最终缩小高低分辨率图像的域差异。
  • 依赖iqa 是通过 pyiqa 库加载的预训练质量评估模型(如BRISQUE、CNNIQA)。
3. 主函数 main:训练核心逻辑
(1)环境配置与日志初始化
def main(args):
    setup_seed(seed=args.seed, cuda_deterministic=False)  # 设置随机种子,保证实验可复现
    torch.cuda.set_device(local_rank)  # 将当前进程绑定到指定GPU(local_rank)
    os.makedirs(args.output, exist_ok=True)  # 创建模型保存目录
    init_logging(rank, args.output)  # 初始化日志(仅主进程rank=0打印日志,避免多GPU重复输出)

    # TensorBoard/WandB可视化(仅主进程启用,用于监控训练损失、精度)
    summary_writer = SummaryWriter(...) if rank == 0 else None
    wandb_logger = wandb.init(...) if args.using_wandb and rank == 0 else None
  • 关键细节setup_seed 固定CPU/GPU的随机种子,确保每次训练结果一致;rank==0 是主进程,负责日志、可视化、模型保存等核心操作,其他进程仅负责计算。
(2)数据加载
train_loader = get_dataloader(
    args.rec,  # 数据集路径
    local_rank,  # 本地GPU编号(分布式采样用)
    args.batch_size,  # 单GPU批次大小
    args.image_size,  # 图像尺寸(如120×120,低分辨率数据)
    args.dali,  # 是否使用DALI加速数据加载(NVIDIA的高效数据加载库)
    # ... 其他数据增强、workers参数
)
  • 核心作用:读取低分辨率人脸数据集(如TinyFace、BRIAR),并进行预处理(缩放、归一化、数据增强),返回批量数据迭代器。
  • 分布式适配get_dataloader 内部会创建分布式采样器,确保多GPU间数据不重复、不遗漏。
(3)模型搭建:骨干网络 + LoRA + 分类头
# 1. 搭建骨干网络(支持LoRA和IQA适配,如Swin Transformer/IQA版本)
backbone = get_model(
    args.network,  # 骨干网络类型(如swin_256new_iqa)
    dropout=0.0,
    fp16=args.fp16,  # 是否启用混合精度训练(加速训练、减少显存占用)
    num_features=args.embedding_size,  # 特征向量维度(默认1024)
    r=args.lora_rank,  # LoRA低秩矩阵的秩(如8)
    use_lora=args.use_lora  # 是否启用LoRA
).cuda()

# 2. 包装为分布式数据并行(DDP),支持多GPU训练
backbone = torch.nn.parallel.DistributedDataParallel(
    module=backbone,
    device_ids=[local_rank],
    find_unused_parameters=True  # 允许部分参数不参与计算(如LoRA冻结其他层时)
)
backbone.register_comm_hook(None, fp16_compress_hook)  # FP16通信压缩,加速GPU间数据传输

# 3. 搭建分类头(Partial FC)和损失函数(CombinedMarginLoss=ArcFace/CosFace)
margin_loss = CombinedMarginLoss(...)  # 带margin的损失函数,增强特征判别性
head = get_head(
    args.head,  # 分类头类型(默认partial_fc)
    margin_loss=margin_loss,
    embedding_size=args.embedding_size,  # 输入特征维度(1024)
    num_classes=args.num_classes  # 数据集类别数(如TinyFace为2570)
).cuda()
  • 核心创新点:骨干网络支持接收 alpha 质量权重(backbone(img, alpha)),结合LoRA层实现「质量自适应的参数高效微调」。
  • Partial FC:针对大规模数据集(如WebFace4M)的高效分类头,避免全连接层参数爆炸(如百万类别时全连接层参数过大)。
(4)LoRA微调配置(核心参数高效逻辑)
if args.use_lora:
    # 加载高分辨率预训练模型权重
    weights_path = os.path.join(args.load_pretrained, f"checkpoint_gpu_{rank}.pt")
    if os.path.exists(weights_path):
        dict_checkpoint = torch.load(weights_path)
        backbone.module.load_state_dict(dict_checkpoint["state_dict_backbone"], strict=False)
    else:
        dict_checkpoint = torch.load(os.path.join(args.load_pretrained, "model.pt"))
        backbone.module.load_state_dict(dict_checkpoint, strict=False)

    # 冻结骨干网络大部分参数(仅LoRA层可训练)
    for p in backbone.parameters():
        p.requires_grad = False  # 冻结所有参数
    for name, p in backbone.named_parameters():
        if 'trainable_lora' in name:  # 仅LoRA相关参数解冻(可训练)
            p.requires_grad = True

    # 分类头参数全部可训练
    for p in head.parameters():
        p.requires_grad = True
  • 核心逻辑:LoRA微调的核心是「冻结预训练骨干网络,仅训练少量低秩矩阵参数」,避免灾难性遗忘(忘记高分辨率图像的特征提取能力),同时适配低分辨率数据。
  • 参数占比:仅训练约0.48%的总参数(LoRA层的A、B矩阵),大幅降低计算成本和过拟合风险。
(5)优化器与学习率调度器
# 选择优化器(SGD或AdamW),仅优化可训练参数(LoRA层+分类头)
if args.optimizer == "adamw":
    opt = torch.optim.AdamW(
        params=[
            {"params": filter(lambda p: p.requires_grad, backbone.parameters())},  # LoRA可训练参数
            {"params": head.parameters()}  # 分类头参数
        ],
        lr=args.lr,  # 学习率(微调时通常较小,如0.0005)
        weight_decay=args.weight_decay  # 权重衰减,防止过拟合
    )

# 学习率调度器(预热+多项式衰减)
args.warmup_step = args.num_image // args.total_batch_size * args.warmup_epoch  # 预热步数
args.total_step = args.num_image // args.total_batch_size * args.num_epoch  # 总步数
lr_scheduler = PolynomialLRWarmup(
    optimizer=opt,
    warmup_iters=args.warmup_step,  # 预热阶段(学习率从0升到初始值)
    total_iters=args.total_step  # 总阶段(学习率多项式衰减)
)
  • 关键设计:预热阶段避免初始高学习率破坏预训练权重,多项式衰减让训练后期学习率逐步降低,稳定收敛。
(6)训练循环(核心!前向→反向→更新)
# 加载IQA模型(BRISQUE或CNNIQA)
if args.iqa == "brisque":
    iqa = pyiqa.create_metric('brisque').cuda()
    threshold = args.threshold  # 质量阈值(由用户指定,如0.3)

# 迭代训练轮次
for epoch in range(start_epoch, args.num_epoch):
    # 分布式采样器更新epoch(确保每轮数据打乱)
    if isinstance(train_loader, DataLoader):
        train_loader.sampler.set_epoch(epoch)
    
    # 迭代每个批次
    for _, (img, local_labels) in enumerate(train_loader):
        global_step += 1

        # 1. 计算图像质量权重alpha(核心!IQA融入训练)
        alpha = generate_alpha(img.clone(), iqa, threshold)  # img.clone()避免修改原始图像
        
        # 2. 前向传播:图像+质量权重 → 特征 → 损失
        local_embeddings = backbone(img, alpha)  # 骨干网络输出1024维特征(融入质量权重)
        loss = head(local_embeddings, local_labels)  # 分类头计算损失(CombinedMarginLoss)

        # 3. 反向传播与参数更新
        if args.fp16:  # 混合精度训练(加速+省显存)
            amp.scale(loss).backward()  # 损失缩放,避免梯度下溢
            if global_step % args.gradient_acc == 0:  # 梯度累积(模拟更大批次)
                amp.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)  # 梯度裁剪,防止梯度爆炸
                amp.step(opt)  # 更新参数
                amp.update()
                opt.zero_grad()  # 清空梯度
        else:  # 普通精度训练
            loss.backward()
            if global_step % args.gradient_acc == 0:
                torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
                opt.step()
                opt.zero_grad()
        
        # 4. 学习率更新
        lr_scheduler.step()

        # 5. 日志记录与可视化(损失、学习率)
        with torch.no_grad():
            loss_am.update(loss.item(), 1)  # 平均损失统计
            callback_logging(...)  # 打印训练日志
            if wandb_logger:
                wandb_logger.log({"Loss": loss.item()})  # WandB可视化

        # 6. 定期验证(评估模型在验证集上的性能)
        if global_step % args.verbose == 0 and global_step > 0:
            callback_verification(global_step, backbone)  # 验证LFW/IJB-S等数据集
  • 核心流程拆解
    ① 对每个批次图像,先计算质量权重 alpha
    ② 骨干网络接收「图像+alpha」,输出质量自适应的特征;
    ③ 分类头用Margin损失优化特征判别性;
    ④ 仅更新LoRA层和分类头的参数,冻结其他骨干网络参数;
    ⑤ 定期验证模型性能,确保训练不跑偏。
(7)模型保存与分布式训练收尾
# 保存检查点(所有GPU的参数,用于断点续训)
if args.save_all_states:
    checkpoint = {
        "epoch": epoch + 1,
        "state_dict_backbone": backbone.module.state_dict(),  # 骨干网络参数
        "state_dict_softmax_fc": head.state_dict(),  # 分类头参数
        "state_optimizer": opt.state_dict(),  # 优化器状态
        "state_lr_scheduler": lr_scheduler.state_dict()  # 调度器状态
    }
    torch.save(checkpoint, os.path.join(args.output, f"checkpoint_gpu_{rank}.pt"))

# 主进程保存最终模型(供推理/验证使用)
if rank == 0:
    torch.save(backbone.module.state_dict(), os.path.join(args.output, "model.pt"))
    if wandb_logger:
        wandb_logger.log_artifact(...)  # WandB保存模型 artifacts

# 分布式训练收尾(销毁进程组)
torch.distributed.barrier()  # 等待所有GPU完成
destroy_process_group()
  • 关键细节backbone.module 是获取DDP包装后的原始骨干网络(DDP会对模型进行封装,需通过.module访问原始模型)。
三、核心亮点与项目创新的对应

这段代码完美体现了PETALface的两大核心创新:

  1. 参数高效微调(LoRA):通过冻结骨干网络、仅训练LoRA层,解决低样本低分辨率数据的灾难性遗忘问题;
  2. 质量感知加权(IQA+alpha):通过IQA模型计算质量分数,动态调整LoRA权重,缩小高低分辨率图像的域差异。
四、关键参数总结(从命令行/配置文件传入)
参数名 作用 示例值
--network 骨干网络类型(需带IQA) swin_256new_iqa
--use_lora 是否启用LoRA微调 True
--lora_rank LoRA低秩矩阵的秩 8
--load_pretrained 预训练模型路径 ./pretrained
--iqa 图像质量评估模型 brisque
--threshold 质量分数阈值(控制alpha权重范围) 0.3
--num_classes 数据集类别数(如TinyFace为2570) 2570
--batch_size 单GPU批次大小 8
--lr 学习率(微调时较小) 0.0005

通过调整这些参数,可适配不同的低分辨率数据集和训练需求。

3. 看一下模型的网络结构

PETALface 支持多种成熟的人脸识别骨干网络作为基础模型,例如IResNet ,mobilefacenet等;具体的模型列表可以在backbones目录中的__init__.py里面找到;

from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
from .mobilefacenet import get_mbf


def get_model(name, **kwargs):
    # resnet
    if name == "r18":
        r = kwargs.pop("r", 4)
        scale = kwargs.pop("scale", 1)
        use_lora = kwargs.pop('use_lora', False)
        return iresnet18(lora_rank=r, lora_scale=scale, pretrained=False, progress=True, use_lora=use_lora, **kwargs)
    elif name == "r34":
        r = kwargs.pop("r", 4)
        scale = kwargs.pop("scale", 1)
        use_lora = kwargs.pop('use_lora', False)
        return iresnet34(lora_rank=r, lora_scale=scale, pretrained=False, progress=True, use_lora=use_lora, **kwargs)
    elif name == "r50":
        r = kwargs.pop("r", 4)
        scale = kwargs.pop("scale", 1)
        use_lora = kwargs.pop('use_lora', False)
        return iresnet50(lora_rank=r, lora_scale=scale, pretrained=False, progress=True, use_lora=use_lora, **kwargs)
    elif name == "r100":
        r = kwargs.pop("r", 4)
        scale = kwargs.pop("scale", 1)
        use_lora = kwargs.pop('use_lora', False)
        return iresnet100(lora_rank=r, lora_scale=scale, pretrained=False, progress=True, use_lora=use_lora, **kwargs)
    elif name == "r200":
        r = kwargs.pop("r", 4)
        scale = kwargs.pop("scale", 1)
        use_lora = kwargs.pop('use_lora', False)
        return iresnet200(lora_rank=r, lora_scale=scale, pretrained=False, progress=True, use_lora=use_lora, **kwargs)
    elif name == "r2060":
        from .iresnet2060 import iresnet2060
        return iresnet2060(False, **kwargs)

    elif name == "mbf":
        fp16 = kwargs.get("fp16", False)
        num_features = kwargs.get("num_features", 512)
        return get_mbf(fp16=fp16, num_features=num_features)

    elif name == "mbf_large":
        from .mobilefacenet import get_mbf_large
        fp16 = kwargs.get("fp16", False)
        num_features = kwargs.get("num_features", 512)
        return get_mbf_large(fp16=fp16, num_features=num_features)

    elif name == "vit_t":
        num_features = kwargs.get("num_features", 512)
        from .vit import VisionTransformer
        return VisionTransformer(
            img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12,
            num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1)

    elif name == "vit_t_dp005_mask0": # For WebFace42M
        num_features = kwargs.get("num_features", 512)
        from .vit import VisionTransformer
        return VisionTransformer(
            img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12,
            num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0)

    elif name == "vit_s":
        num_features = kwargs.get("num_features", 512)
        from .vit import VisionTransformer
        return VisionTransformer(
            img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12,
            num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1)
    
    elif name == "vit_s_dp005_mask_0":  # For WebFace42M
        num_features = kwargs.get("num_features", 512)
        from .vit import VisionTransformer
        return VisionTransformer(
            img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12,
            num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0)
    
    elif name == "vit_b":
        # this is a feature
        num_features = kwargs.get("num_features", 512)
        r = kwargs.pop("r", 4)
        scale = kwargs.pop("scale", 1)
        use_lora = kwargs.pop('use_lora', False)
        from .vit import VisionTransformer
        return VisionTransformer(
            lora_rank=r, lora_scale=scale, img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24, 
            num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1, using_checkpoint=True, use_lora=use_lora)

    elif name == "vit_b_iqa":
        # this is a feature
        num_features = kwargs.get("num_features", 512)
        r = kwargs.pop("r", 4)
        scale = kwargs.pop("scale", 1)
        use_lora = kwargs.pop('use_lora', False)
        from .vit_iqa import VisionTransformer
        return VisionTransformer(
            lora_rank=r, lora_scale=scale, img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24, 
            num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1, using_checkpoint=True, use_lora=use_lora)

    elif name == "vit_b_dp005_mask_005":  # For WebFace42M
        # this is a feature
        num_features = kwargs.get("num_features", 512)
        from .vit import VisionTransformer
        return VisionTransformer(
            img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24,
            num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True)

    elif name == "vit_l_dp005_mask_005":  # For WebFace42M
        # this is a feature
        num_features = kwargs.get("num_features", 512)
        from .vit import VisionTransformer
        return VisionTransformer(
            img_size=112, patch_size=9, num_classes=num_features, embed_dim=768, depth=24,
            num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True)
        
    elif name == "vit_h":  # For WebFace42M
        num_features = kwargs.get("num_features", 512)
        from .vit import VisionTransformer
        return VisionTransformer(
            img_size=112, patch_size=9, num_classes=num_features, embed_dim=1024, depth=48,
            num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0, using_checkpoint=True)

    elif name=="swin_256new":
        num_features = kwargs.get("num_features", 512)
        r = kwargs.pop("r", 4)
        scale = kwargs.pop("scale", 1)
        use_lora = kwargs.pop('use_lora', False)
        from .swin_models import  SwinTransformer
        kwargs['reso']=120
        return  SwinTransformer(lora_rank=r, lora_scale=scale, img_size=120, patch_size=6, in_chans=3, num_classes=512,
                 embed_dim=384, depths=[2,18,2], num_heads=[ 8, 16,16],
                 window_size=5, use_lora=use_lora, **kwargs) 

    elif name=="swin_256new_iqa":
        num_features = kwargs.get("num_features", 512)
        r = kwargs.pop("r", 4)
        scale = kwargs.pop("scale", 1)
        use_lora = kwargs.pop('use_lora', False)
        from .swin_models_iqa import  SwinTransformer
        kwargs['reso']=120
        return  SwinTransformer(lora_rank=r, lora_scale=scale, img_size=120, patch_size=6, in_chans=3, num_classes=512,
                 embed_dim=384, depths=[2,18,2], num_heads=[ 8, 16,16],
                 window_size=5, use_lora=use_lora, **kwargs) 
    else:
        raise ValueError()

4. 最后,利用本模型来进行人脸识别:

(1)录入人脸数据
    def encode_face_dataset(self, image_paths, names):
        face_encodings = []
        for index, path in enumerate(tqdm(image_paths)):
            #---------------------------------------------------#
            #   打开人脸图片
            #---------------------------------------------------#
            # 打开图片并转换为RGB(去除Alpha通道)
            image = Image.open(path).convert('RGB')  # 关键:强制转为3通道RGB
            image = np.array(image, np.float32)  # 转为数组
            #---------------------------------------------------#
            #   对输入图像进行一个备份
            #---------------------------------------------------#
            old_image   = image.copy()
            #---------------------------------------------------#
            #   计算输入图片的高和宽
            #---------------------------------------------------#
            im_height, im_width, _ = np.shape(image)
            #---------------------------------------------------#
            #   计算scale,用于将获得的预测框转换成原图的高宽
            #---------------------------------------------------#
            scale = [
                np.shape(image)[1], np.shape(image)[0], np.shape(image)[1], np.shape(image)[0]
            ]
            scale_for_landmarks = [
                np.shape(image)[1], np.shape(image)[0], np.shape(image)[1], np.shape(image)[0],
                np.shape(image)[1], np.shape(image)[0], np.shape(image)[1], np.shape(image)[0],
                np.shape(image)[1], np.shape(image)[0]
            ]
            if self.letterbox_image:
                image = letterbox_image(image, [self.retinaface_input_shape[1], self.retinaface_input_shape[0]])
                anchors = self.anchors
            else:
                anchors = Anchors(self.cfg, image_size=(im_height, im_width)).get_anchors()

            #---------------------------------------------------#
            #   将处理完的图片传入Retinaface网络当中进行预测
            #---------------------------------------------------#
            with torch.no_grad():
                #-----------------------------------------------------------#
                #   图片预处理,归一化。
                #-----------------------------------------------------------#
                image = torch.from_numpy(preprocess_input(image).transpose(2, 0, 1)).unsqueeze(0).type(torch.FloatTensor)

                if self.cuda:
                    image               = image.cuda()
                    anchors             = anchors.cuda()

                loc, conf, landms = self.net(image)
                #-----------------------------------------------------------#
                #   对预测框进行解码
                #-----------------------------------------------------------#
                boxes   = decode(loc.data.squeeze(0), anchors, self.cfg['variance'])
                #-----------------------------------------------------------#
                #   获得预测结果的置信度
                #-----------------------------------------------------------#
                conf    = conf.data.squeeze(0)[:, 1:2]
                #-----------------------------------------------------------#
                #   对人脸关键点进行解码
                #-----------------------------------------------------------#
                landms  = decode_landm(landms.data.squeeze(0), anchors, self.cfg['variance'])

                #-----------------------------------------------------------#
                #   对人脸检测结果进行堆叠
                #-----------------------------------------------------------#
                boxes_conf_landms = torch.cat([boxes, conf, landms], -1)
                boxes_conf_landms = non_max_suppression(boxes_conf_landms, self.confidence)

                if len(boxes_conf_landms) <= 0:
                    print(names[index], ":未检测到人脸")
                    continue
                if self.letterbox_image:
                    boxes_conf_landms = retinaface_correct_boxes(boxes_conf_landms, \
                        np.array([self.retinaface_input_shape[0], self.retinaface_input_shape[1]]), np.array([im_height, im_width]))

            boxes_conf_landms[:, :4] = boxes_conf_landms[:, :4] * scale
            boxes_conf_landms[:, 5:] = boxes_conf_landms[:, 5:] * scale_for_landmarks

            #---------------------------------------------------#
            #   选取最大的人脸框。
            #---------------------------------------------------#
            best_face_location  = None
            biggest_area        = 0
            for result in boxes_conf_landms:
                left, top, right, bottom = result[0:4]

                w = right - left
                h = bottom - top
                if w * h > biggest_area:
                    biggest_area = w * h
                    best_face_location = result

            #---------------------------------------------------#
            #   截取图像
            #---------------------------------------------------#
            crop_img = old_image[int(best_face_location[1]):int(best_face_location[3]), int(best_face_location[0]):int(best_face_location[2])]
            landmark = np.reshape(best_face_location[5:],(5,2)) - np.array([int(best_face_location[0]),int(best_face_location[1])])
            crop_img,_ = Alignment_1(crop_img,landmark)

            crop_img = np.array(letterbox_image(np.uint8(crop_img),(self.facenet_input_shape[1],self.facenet_input_shape[0])))/255
            crop_img = crop_img.transpose(2, 0, 1)
            crop_img = np.expand_dims(crop_img,0)
            #---------------------------------------------------#
            #   利用图像算取长度为512的特征向量
            #---------------------------------------------------#
            with torch.no_grad():
                crop_img = torch.from_numpy(crop_img).type(torch.FloatTensor)
                if self.cuda:
                    crop_img = crop_img.cuda()

                face_encoding = self.petalface(crop_img)[0].cpu().numpy()
                face_encodings.append(face_encoding)

        np.save("{你的保存路径}.npy".format(backbone=self.facenet_backbone),face_encodings)
        np.save("{你的路径}.npy".format(backbone=self.facenet_backbone),names)

(2)之后可以进行人脸的信息比对,我使用的是余弦相似度进行比较的
#---------------------------------#
#   比较人脸(PETALface专用)
#---------------------------------#
def compare_faces_petalface(known_face_encodings, face_encoding_to_check, tolerance=0.5):
    """
    比较待检测人脸与已知人脸库的匹配情况
    Args:
        known_face_encodings: 已知人脸特征列表,shape为(n, d)
        face_encoding_to_check: 待检测人脸特征,shape为(d,)
        tolerance: 相似度阈值,大于等于该值认为匹配(默认0.5,可根据模型校准)
    Returns:
        匹配结果列表(布尔值)和对应的相似度数组
    """
    # 计算相似度
    similarities = face_similarity_petalface(known_face_encodings, face_encoding_to_check)
    # 根据阈值判断匹配结果
    return list(similarities >= tolerance), similarities
    
def face_similarity_petalface(face_encodings, face_to_compare, use_cosine=True):
	"""
	计算一组人脸特征与目标人脸特征的相似度
	Args:
	    face_encodings: 已知人脸特征列表,shape为(n, d),n为数量,d为特征维度
	    face_to_compare: 待比对人脸特征,shape为(d,)
	    use_cosine: 是否使用余弦相似度(推荐),否则使用原始内积
	Returns:
	    相似度数组,shape为(n,),值越大表示越相似
	"""
if len(face_encodings) == 0:
    return np.empty((0))

# 确保输入为numpy数组
face_encodings = np.array(face_encodings)
face_to_compare = np.array(face_to_compare)

if use_cosine:
    # 使用L2标准化计算余弦相似度
    if sklearn is not None:
        # 使用sklearn标准化
        face_encodings_norm = sklearn.preprocessing.normalize(face_encodings, norm='l2')
        face_to_compare_norm = sklearn.preprocessing.normalize(face_to_compare.reshape(1, -1), norm='l2')
        return np.dot(face_encodings_norm, face_to_compare_norm.T).flatten()
    else:
        # 手动L2标准化
        face_encodings_norm = face_encodings / (np.linalg.norm(face_encodings, axis=1, keepdims=True) + 1e-8)
        face_to_compare_norm = face_to_compare / (np.linalg.norm(face_to_compare) + 1e-8)
        return np.dot(face_encodings_norm, face_to_compare_norm.T)
else:
    # 原始内积相似度
    return np.dot(face_encodings, face_to_compare.T)

总体而言,PETALface 不仅在技术上突破了低分辨率人脸识别的核心瓶颈,其 “预训练骨干 + 参数高效微调 + 场景感知适配” 的设计思路,也为其他小样本、跨域计算机视觉任务提供了重要借鉴。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐