PETALface代码解读与人脸识别
PETALface 以 “参数高效迁移学习” 为核心,通过 backbones/lora_layers.py 实现 LoRA 模块,结合 train_iqa.py 中的图像质量加权策略,在低分辨率数据集上微调时仅更新少量参数,避免灾难性遗忘。数据集处理、模型训练、多场景验证的模块化设计,使其能高效支持高低分辨率人脸识别任务的研究与部署。参数名作用示例值--network骨干网络类型(需带IQA)-
PETALface代码链接:https://github.com/Kartik-3004/PETALface
PETALface论文链接:https://arxiv.org/abs/2412.07771
PETALface预训练权重链接:https://huggingface.co/kartiknarayan/PETALface
一、代码解读
1. 首先我们需要进行代码结构的解读
核心配置与入口文件
config.py: 项目核心配置文件,定义训练参数(如学习率、批处理大小)、数据集路径、模型超参数等,是整个项目的参数中枢。
train.py / train_iqa.py: 训练入口脚本。train.py 用于基础模型训练,train_iqa.py 专门针对集成图像质量评估(IQA)的模型训练,包含训练循环、损失计算、参数更新等逻辑。
losses.py: 实现损失函数,为模型训练提供分类损失计算支持。
lr_scheduler.py: 学习率调度器,控制训练过程中学习率的衰减策略,优化模型收敛效果。
骨干网络模块(backbones/)存放各类特征提取网络结构,支持不同模型架构的灵活切换,包含:
iresnet.py / iresnet2060.py: 基于残差网络(ResNet)的改进版本,适用于人脸特征提取。
swin_models.py / swin_models_iqa.py:Swin Transformer 模型实现,其中 swin_models_iqa.py 集成了图像质量评估相关的适配层。
vit.py / vit_iqa.py: 视觉 Transformer(ViT)模型,vit_iqa.py 针对低分辨率场景加入了质量感知模块。
lora_layers.py: 核心创新点实现,定义了 LoRA(Low-Rank Adaptation)低秩适配层,支持参数高效微调,仅更新少量参数即可适配低分辨率数据。
mobilefacenet.py: 轻量级 MobileFaceNet 模型,适合资源受限场景。
工具函数(utils/)
提供训练、评估、日志等辅助功能:
evaluate_utils.py: 评估工具函数。
utils_logging.py:日志管理工具,记录训练过程中的损失、精度等信息,支持可视化工具集成。
utils_distributed_sampler.py: 分布式训练采样器,支持多 GPU 环境下的数据分发。
utils_callbacks.py: 训练回调函数,如模型保存、早停策略等。
plot.py: 可视化工具,绘制训练曲线、特征分布等图表。
数据集处理(dataset/)
dataset.py: 数据集加载与预处理逻辑,支持读取 WebFace4M、TinyFace 等数据集,包含图像加载、增强、对齐等操作。
label_to_idx.py: 标签映射工具,将人脸身份标签转换为模型可识别的索引,处理多数据集的标签一致性。
头部网络(heads/)
partial_fc.py: 实现 Partial FC 头部,用于大规模人脸分类任务,优化类别不平衡问题,配合 margin 损失提升特征判别性。
验证模块(validation_*/)
针对不同数据集的验证逻辑,用于评估模型在高低分辨率场景下的性能:
validation_ijb/: 针对 IJB-B、IJB-C 等混合质量数据集的验证,包含 eval_ijb.py 等脚本,实现特征提取与比对。
validation_hq/: 高分辨率数据集(如 LFW、CFP)的验证,测试模型在高清场景下的泛化能力。
validation_lq/: 低分辨率数据集(如 TinyFace、BRIAR、IJB-S)的验证,核心文件包括:
PFE/ijbs.py: 解析 IJB-S 数据集协议,划分 gallery(图库)和 probe(探针),组织验证模板。
PFE/utils.py: 验证过程中的日志记录工具,支持 TensorBoard 等可视化 summary 写入。
脚本与文档(scripts/ / docs/)
scripts/: 存放训练、评估的脚本示例,简化命令行操作。
docs/: 项目文档,包括算法框架图、动机介绍、使用说明等,index.html 为项目网页,可视化展示核心思想与实验结果。
其他文件
environment.yml: conda 环境配置文件,定义项目依赖(如 PyTorch、OpenCV 等),一键搭建运行环境。
README.md: 项目说明文档,包含动机、贡献、安装步骤、数据集结构等关键信息。
LICENSE:MIT 许可证,明确代码的使用权限与限制。
整体逻辑总结
PETALface 以 “参数高效迁移学习” 为核心,通过 backbones/lora_layers.py 实现 LoRA 模块,结合 train_iqa.py 中的图像质量加权策略,在低分辨率数据集上微调时仅更新少量参数,避免灾难性遗忘。数据集处理、模型训练、多场景验证的模块化设计,使其能高效支持高低分辨率人脸识别任务的研究与部署。
2. 在清楚了代码结构之后,我们看一下这个项目是怎么进行训练的
由于我们的目标是低质量的人脸识别,所以主要分析的是train_iqa.py这个代码:
下面按「代码结构+核心功能+关键细节」的顺序,逐部分拆解解释:
一、整体代码结构
代码遵循「初始化→配置→数据加载→模型搭建→训练循环→保存验证」的标准深度学习训练流程,可分为 8 个核心模块:
- 依赖导入与分布式训练初始化
- 图像质量权重生成函数(
generate_alpha) - 主函数入口(
main):包含所有训练逻辑- 环境配置与日志初始化
- 数据加载(训练集)
- 模型搭建(骨干网络+LoRA+分类头)
- 优化器与学习率调度器初始化
- 训练循环(前向传播+反向传播+参数更新)
- 验证与模型保存
- 脚本入口(
if __name__ == "__main__")
二、逐模块详细解释
1. 依赖导入与分布式训练初始化
import argparse
import logging
import os
# ... 其他依赖导入(略)
assert torch.__version__ >= "1.12.0", "torch版本需≥1.12.0"
# 分布式训练参数(从环境变量读取,由torchrun或slurm启动时自动设置)
rank = int(os.environ["RANK"]) # 当前进程的全局排名(多GPU时区分不同GPU)
local_rank = int(os.environ["LOCAL_RANK"]) # 当前进程的本地排名(单节点内的GPU编号)
world_size = int(os.environ["WORLD_SIZE"]) # 总进程数(=总GPU数)
distributed.init_process_group("nccl") # 初始化分布式训练(使用NCCL通信后端,GPU间高效通信)
- 核心作用:搭建多GPU分布式训练环境,让多个GPU协同训练(提升速度、处理更大批量数据)。
- 关键依赖:
torch.distributed是PyTorch分布式训练的核心模块,nccl是NVIDIA的GPU通信协议,效率最高。
2. 核心工具函数:generate_alpha(图像质量权重生成)
def generate_alpha(img, iqa, thresh):
device = img.device # 获取图像所在设备(GPU)
BS, C, H, W = img.shape # 批量大小、通道数、高、宽
alpha = torch.zeros((BS, 1), dtype=torch.float32, device=device) # 初始化质量权重(每个样本1个权重)
score = iqa(img) # 调用IQA模型计算每张图像的质量分数(0~1,分数越高质量越好)
threshold = thresh # 质量阈值(由参数--threshold指定)
for i in range(BS):
if score[i] == threshold:
alpha[i] = 0.5 # 分数等于阈值时,权重为0.5
elif score[i] < threshold: # 低质量图像(分数<阈值)
alpha[i] = 0.5 - (threshold - score[i]) # 权重随分数降低而减小(最低接近0)
else: # 高质量图像(分数>阈值)
alpha[i] = 0.5 + (score[i] - threshold) # 权重随分数升高而增大(最高接近1)
return alpha
- 核心目标:将图像质量分数转化为「模型可理解的权重
alpha」,实现「质量感知的特征提取」。 - 逻辑本质:高质量图像给更高权重(模型更依赖其特征),低质量图像给更低权重(降低其对训练的干扰),最终缩小高低分辨率图像的域差异。
- 依赖:
iqa是通过pyiqa库加载的预训练质量评估模型(如BRISQUE、CNNIQA)。
3. 主函数 main:训练核心逻辑
(1)环境配置与日志初始化
def main(args):
setup_seed(seed=args.seed, cuda_deterministic=False) # 设置随机种子,保证实验可复现
torch.cuda.set_device(local_rank) # 将当前进程绑定到指定GPU(local_rank)
os.makedirs(args.output, exist_ok=True) # 创建模型保存目录
init_logging(rank, args.output) # 初始化日志(仅主进程rank=0打印日志,避免多GPU重复输出)
# TensorBoard/WandB可视化(仅主进程启用,用于监控训练损失、精度)
summary_writer = SummaryWriter(...) if rank == 0 else None
wandb_logger = wandb.init(...) if args.using_wandb and rank == 0 else None
- 关键细节:
setup_seed固定CPU/GPU的随机种子,确保每次训练结果一致;rank==0是主进程,负责日志、可视化、模型保存等核心操作,其他进程仅负责计算。
(2)数据加载
train_loader = get_dataloader(
args.rec, # 数据集路径
local_rank, # 本地GPU编号(分布式采样用)
args.batch_size, # 单GPU批次大小
args.image_size, # 图像尺寸(如120×120,低分辨率数据)
args.dali, # 是否使用DALI加速数据加载(NVIDIA的高效数据加载库)
# ... 其他数据增强、workers参数
)
- 核心作用:读取低分辨率人脸数据集(如TinyFace、BRIAR),并进行预处理(缩放、归一化、数据增强),返回批量数据迭代器。
- 分布式适配:
get_dataloader内部会创建分布式采样器,确保多GPU间数据不重复、不遗漏。
(3)模型搭建:骨干网络 + LoRA + 分类头
# 1. 搭建骨干网络(支持LoRA和IQA适配,如Swin Transformer/IQA版本)
backbone = get_model(
args.network, # 骨干网络类型(如swin_256new_iqa)
dropout=0.0,
fp16=args.fp16, # 是否启用混合精度训练(加速训练、减少显存占用)
num_features=args.embedding_size, # 特征向量维度(默认1024)
r=args.lora_rank, # LoRA低秩矩阵的秩(如8)
use_lora=args.use_lora # 是否启用LoRA
).cuda()
# 2. 包装为分布式数据并行(DDP),支持多GPU训练
backbone = torch.nn.parallel.DistributedDataParallel(
module=backbone,
device_ids=[local_rank],
find_unused_parameters=True # 允许部分参数不参与计算(如LoRA冻结其他层时)
)
backbone.register_comm_hook(None, fp16_compress_hook) # FP16通信压缩,加速GPU间数据传输
# 3. 搭建分类头(Partial FC)和损失函数(CombinedMarginLoss=ArcFace/CosFace)
margin_loss = CombinedMarginLoss(...) # 带margin的损失函数,增强特征判别性
head = get_head(
args.head, # 分类头类型(默认partial_fc)
margin_loss=margin_loss,
embedding_size=args.embedding_size, # 输入特征维度(1024)
num_classes=args.num_classes # 数据集类别数(如TinyFace为2570)
).cuda()
- 核心创新点:骨干网络支持接收
alpha质量权重(backbone(img, alpha)),结合LoRA层实现「质量自适应的参数高效微调」。 - Partial FC:针对大规模数据集(如WebFace4M)的高效分类头,避免全连接层参数爆炸(如百万类别时全连接层参数过大)。
(4)LoRA微调配置(核心参数高效逻辑)
if args.use_lora:
# 加载高分辨率预训练模型权重
weights_path = os.path.join(args.load_pretrained, f"checkpoint_gpu_{rank}.pt")
if os.path.exists(weights_path):
dict_checkpoint = torch.load(weights_path)
backbone.module.load_state_dict(dict_checkpoint["state_dict_backbone"], strict=False)
else:
dict_checkpoint = torch.load(os.path.join(args.load_pretrained, "model.pt"))
backbone.module.load_state_dict(dict_checkpoint, strict=False)
# 冻结骨干网络大部分参数(仅LoRA层可训练)
for p in backbone.parameters():
p.requires_grad = False # 冻结所有参数
for name, p in backbone.named_parameters():
if 'trainable_lora' in name: # 仅LoRA相关参数解冻(可训练)
p.requires_grad = True
# 分类头参数全部可训练
for p in head.parameters():
p.requires_grad = True
- 核心逻辑:LoRA微调的核心是「冻结预训练骨干网络,仅训练少量低秩矩阵参数」,避免灾难性遗忘(忘记高分辨率图像的特征提取能力),同时适配低分辨率数据。
- 参数占比:仅训练约0.48%的总参数(LoRA层的A、B矩阵),大幅降低计算成本和过拟合风险。
(5)优化器与学习率调度器
# 选择优化器(SGD或AdamW),仅优化可训练参数(LoRA层+分类头)
if args.optimizer == "adamw":
opt = torch.optim.AdamW(
params=[
{"params": filter(lambda p: p.requires_grad, backbone.parameters())}, # LoRA可训练参数
{"params": head.parameters()} # 分类头参数
],
lr=args.lr, # 学习率(微调时通常较小,如0.0005)
weight_decay=args.weight_decay # 权重衰减,防止过拟合
)
# 学习率调度器(预热+多项式衰减)
args.warmup_step = args.num_image // args.total_batch_size * args.warmup_epoch # 预热步数
args.total_step = args.num_image // args.total_batch_size * args.num_epoch # 总步数
lr_scheduler = PolynomialLRWarmup(
optimizer=opt,
warmup_iters=args.warmup_step, # 预热阶段(学习率从0升到初始值)
total_iters=args.total_step # 总阶段(学习率多项式衰减)
)
- 关键设计:预热阶段避免初始高学习率破坏预训练权重,多项式衰减让训练后期学习率逐步降低,稳定收敛。
(6)训练循环(核心!前向→反向→更新)
# 加载IQA模型(BRISQUE或CNNIQA)
if args.iqa == "brisque":
iqa = pyiqa.create_metric('brisque').cuda()
threshold = args.threshold # 质量阈值(由用户指定,如0.3)
# 迭代训练轮次
for epoch in range(start_epoch, args.num_epoch):
# 分布式采样器更新epoch(确保每轮数据打乱)
if isinstance(train_loader, DataLoader):
train_loader.sampler.set_epoch(epoch)
# 迭代每个批次
for _, (img, local_labels) in enumerate(train_loader):
global_step += 1
# 1. 计算图像质量权重alpha(核心!IQA融入训练)
alpha = generate_alpha(img.clone(), iqa, threshold) # img.clone()避免修改原始图像
# 2. 前向传播:图像+质量权重 → 特征 → 损失
local_embeddings = backbone(img, alpha) # 骨干网络输出1024维特征(融入质量权重)
loss = head(local_embeddings, local_labels) # 分类头计算损失(CombinedMarginLoss)
# 3. 反向传播与参数更新
if args.fp16: # 混合精度训练(加速+省显存)
amp.scale(loss).backward() # 损失缩放,避免梯度下溢
if global_step % args.gradient_acc == 0: # 梯度累积(模拟更大批次)
amp.unscale_(opt)
torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5) # 梯度裁剪,防止梯度爆炸
amp.step(opt) # 更新参数
amp.update()
opt.zero_grad() # 清空梯度
else: # 普通精度训练
loss.backward()
if global_step % args.gradient_acc == 0:
torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
opt.step()
opt.zero_grad()
# 4. 学习率更新
lr_scheduler.step()
# 5. 日志记录与可视化(损失、学习率)
with torch.no_grad():
loss_am.update(loss.item(), 1) # 平均损失统计
callback_logging(...) # 打印训练日志
if wandb_logger:
wandb_logger.log({"Loss": loss.item()}) # WandB可视化
# 6. 定期验证(评估模型在验证集上的性能)
if global_step % args.verbose == 0 and global_step > 0:
callback_verification(global_step, backbone) # 验证LFW/IJB-S等数据集
- 核心流程拆解:
① 对每个批次图像,先计算质量权重alpha;
② 骨干网络接收「图像+alpha」,输出质量自适应的特征;
③ 分类头用Margin损失优化特征判别性;
④ 仅更新LoRA层和分类头的参数,冻结其他骨干网络参数;
⑤ 定期验证模型性能,确保训练不跑偏。
(7)模型保存与分布式训练收尾
# 保存检查点(所有GPU的参数,用于断点续训)
if args.save_all_states:
checkpoint = {
"epoch": epoch + 1,
"state_dict_backbone": backbone.module.state_dict(), # 骨干网络参数
"state_dict_softmax_fc": head.state_dict(), # 分类头参数
"state_optimizer": opt.state_dict(), # 优化器状态
"state_lr_scheduler": lr_scheduler.state_dict() # 调度器状态
}
torch.save(checkpoint, os.path.join(args.output, f"checkpoint_gpu_{rank}.pt"))
# 主进程保存最终模型(供推理/验证使用)
if rank == 0:
torch.save(backbone.module.state_dict(), os.path.join(args.output, "model.pt"))
if wandb_logger:
wandb_logger.log_artifact(...) # WandB保存模型 artifacts
# 分布式训练收尾(销毁进程组)
torch.distributed.barrier() # 等待所有GPU完成
destroy_process_group()
- 关键细节:
backbone.module是获取DDP包装后的原始骨干网络(DDP会对模型进行封装,需通过.module访问原始模型)。
三、核心亮点与项目创新的对应
这段代码完美体现了PETALface的两大核心创新:
- 参数高效微调(LoRA):通过冻结骨干网络、仅训练LoRA层,解决低样本低分辨率数据的灾难性遗忘问题;
- 质量感知加权(IQA+alpha):通过IQA模型计算质量分数,动态调整LoRA权重,缩小高低分辨率图像的域差异。
四、关键参数总结(从命令行/配置文件传入)
| 参数名 | 作用 | 示例值 |
|---|---|---|
--network |
骨干网络类型(需带IQA) | swin_256new_iqa |
--use_lora |
是否启用LoRA微调 | True |
--lora_rank |
LoRA低秩矩阵的秩 | 8 |
--load_pretrained |
预训练模型路径 | ./pretrained |
--iqa |
图像质量评估模型 | brisque |
--threshold |
质量分数阈值(控制alpha权重范围) | 0.3 |
--num_classes |
数据集类别数(如TinyFace为2570) | 2570 |
--batch_size |
单GPU批次大小 | 8 |
--lr |
学习率(微调时较小) | 0.0005 |
通过调整这些参数,可适配不同的低分辨率数据集和训练需求。
3. 看一下模型的网络结构
PETALface 支持多种成熟的人脸识别骨干网络作为基础模型,例如IResNet ,mobilefacenet等;具体的模型列表可以在backbones目录中的__init__.py里面找到;
from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
from .mobilefacenet import get_mbf
def get_model(name, **kwargs):
# resnet
if name == "r18":
r = kwargs.pop("r", 4)
scale = kwargs.pop("scale", 1)
use_lora = kwargs.pop('use_lora', False)
return iresnet18(lora_rank=r, lora_scale=scale, pretrained=False, progress=True, use_lora=use_lora, **kwargs)
elif name == "r34":
r = kwargs.pop("r", 4)
scale = kwargs.pop("scale", 1)
use_lora = kwargs.pop('use_lora', False)
return iresnet34(lora_rank=r, lora_scale=scale, pretrained=False, progress=True, use_lora=use_lora, **kwargs)
elif name == "r50":
r = kwargs.pop("r", 4)
scale = kwargs.pop("scale", 1)
use_lora = kwargs.pop('use_lora', False)
return iresnet50(lora_rank=r, lora_scale=scale, pretrained=False, progress=True, use_lora=use_lora, **kwargs)
elif name == "r100":
r = kwargs.pop("r", 4)
scale = kwargs.pop("scale", 1)
use_lora = kwargs.pop('use_lora', False)
return iresnet100(lora_rank=r, lora_scale=scale, pretrained=False, progress=True, use_lora=use_lora, **kwargs)
elif name == "r200":
r = kwargs.pop("r", 4)
scale = kwargs.pop("scale", 1)
use_lora = kwargs.pop('use_lora', False)
return iresnet200(lora_rank=r, lora_scale=scale, pretrained=False, progress=True, use_lora=use_lora, **kwargs)
elif name == "r2060":
from .iresnet2060 import iresnet2060
return iresnet2060(False, **kwargs)
elif name == "mbf":
fp16 = kwargs.get("fp16", False)
num_features = kwargs.get("num_features", 512)
return get_mbf(fp16=fp16, num_features=num_features)
elif name == "mbf_large":
from .mobilefacenet import get_mbf_large
fp16 = kwargs.get("fp16", False)
num_features = kwargs.get("num_features", 512)
return get_mbf_large(fp16=fp16, num_features=num_features)
elif name == "vit_t":
num_features = kwargs.get("num_features", 512)
from .vit import VisionTransformer
return VisionTransformer(
img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12,
num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1)
elif name == "vit_t_dp005_mask0": # For WebFace42M
num_features = kwargs.get("num_features", 512)
from .vit import VisionTransformer
return VisionTransformer(
img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12,
num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0)
elif name == "vit_s":
num_features = kwargs.get("num_features", 512)
from .vit import VisionTransformer
return VisionTransformer(
img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12,
num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1)
elif name == "vit_s_dp005_mask_0": # For WebFace42M
num_features = kwargs.get("num_features", 512)
from .vit import VisionTransformer
return VisionTransformer(
img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12,
num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0)
elif name == "vit_b":
# this is a feature
num_features = kwargs.get("num_features", 512)
r = kwargs.pop("r", 4)
scale = kwargs.pop("scale", 1)
use_lora = kwargs.pop('use_lora', False)
from .vit import VisionTransformer
return VisionTransformer(
lora_rank=r, lora_scale=scale, img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24,
num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1, using_checkpoint=True, use_lora=use_lora)
elif name == "vit_b_iqa":
# this is a feature
num_features = kwargs.get("num_features", 512)
r = kwargs.pop("r", 4)
scale = kwargs.pop("scale", 1)
use_lora = kwargs.pop('use_lora', False)
from .vit_iqa import VisionTransformer
return VisionTransformer(
lora_rank=r, lora_scale=scale, img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24,
num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1, using_checkpoint=True, use_lora=use_lora)
elif name == "vit_b_dp005_mask_005": # For WebFace42M
# this is a feature
num_features = kwargs.get("num_features", 512)
from .vit import VisionTransformer
return VisionTransformer(
img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24,
num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True)
elif name == "vit_l_dp005_mask_005": # For WebFace42M
# this is a feature
num_features = kwargs.get("num_features", 512)
from .vit import VisionTransformer
return VisionTransformer(
img_size=112, patch_size=9, num_classes=num_features, embed_dim=768, depth=24,
num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True)
elif name == "vit_h": # For WebFace42M
num_features = kwargs.get("num_features", 512)
from .vit import VisionTransformer
return VisionTransformer(
img_size=112, patch_size=9, num_classes=num_features, embed_dim=1024, depth=48,
num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0, using_checkpoint=True)
elif name=="swin_256new":
num_features = kwargs.get("num_features", 512)
r = kwargs.pop("r", 4)
scale = kwargs.pop("scale", 1)
use_lora = kwargs.pop('use_lora', False)
from .swin_models import SwinTransformer
kwargs['reso']=120
return SwinTransformer(lora_rank=r, lora_scale=scale, img_size=120, patch_size=6, in_chans=3, num_classes=512,
embed_dim=384, depths=[2,18,2], num_heads=[ 8, 16,16],
window_size=5, use_lora=use_lora, **kwargs)
elif name=="swin_256new_iqa":
num_features = kwargs.get("num_features", 512)
r = kwargs.pop("r", 4)
scale = kwargs.pop("scale", 1)
use_lora = kwargs.pop('use_lora', False)
from .swin_models_iqa import SwinTransformer
kwargs['reso']=120
return SwinTransformer(lora_rank=r, lora_scale=scale, img_size=120, patch_size=6, in_chans=3, num_classes=512,
embed_dim=384, depths=[2,18,2], num_heads=[ 8, 16,16],
window_size=5, use_lora=use_lora, **kwargs)
else:
raise ValueError()
4. 最后,利用本模型来进行人脸识别:
(1)录入人脸数据
def encode_face_dataset(self, image_paths, names):
face_encodings = []
for index, path in enumerate(tqdm(image_paths)):
#---------------------------------------------------#
# 打开人脸图片
#---------------------------------------------------#
# 打开图片并转换为RGB(去除Alpha通道)
image = Image.open(path).convert('RGB') # 关键:强制转为3通道RGB
image = np.array(image, np.float32) # 转为数组
#---------------------------------------------------#
# 对输入图像进行一个备份
#---------------------------------------------------#
old_image = image.copy()
#---------------------------------------------------#
# 计算输入图片的高和宽
#---------------------------------------------------#
im_height, im_width, _ = np.shape(image)
#---------------------------------------------------#
# 计算scale,用于将获得的预测框转换成原图的高宽
#---------------------------------------------------#
scale = [
np.shape(image)[1], np.shape(image)[0], np.shape(image)[1], np.shape(image)[0]
]
scale_for_landmarks = [
np.shape(image)[1], np.shape(image)[0], np.shape(image)[1], np.shape(image)[0],
np.shape(image)[1], np.shape(image)[0], np.shape(image)[1], np.shape(image)[0],
np.shape(image)[1], np.shape(image)[0]
]
if self.letterbox_image:
image = letterbox_image(image, [self.retinaface_input_shape[1], self.retinaface_input_shape[0]])
anchors = self.anchors
else:
anchors = Anchors(self.cfg, image_size=(im_height, im_width)).get_anchors()
#---------------------------------------------------#
# 将处理完的图片传入Retinaface网络当中进行预测
#---------------------------------------------------#
with torch.no_grad():
#-----------------------------------------------------------#
# 图片预处理,归一化。
#-----------------------------------------------------------#
image = torch.from_numpy(preprocess_input(image).transpose(2, 0, 1)).unsqueeze(0).type(torch.FloatTensor)
if self.cuda:
image = image.cuda()
anchors = anchors.cuda()
loc, conf, landms = self.net(image)
#-----------------------------------------------------------#
# 对预测框进行解码
#-----------------------------------------------------------#
boxes = decode(loc.data.squeeze(0), anchors, self.cfg['variance'])
#-----------------------------------------------------------#
# 获得预测结果的置信度
#-----------------------------------------------------------#
conf = conf.data.squeeze(0)[:, 1:2]
#-----------------------------------------------------------#
# 对人脸关键点进行解码
#-----------------------------------------------------------#
landms = decode_landm(landms.data.squeeze(0), anchors, self.cfg['variance'])
#-----------------------------------------------------------#
# 对人脸检测结果进行堆叠
#-----------------------------------------------------------#
boxes_conf_landms = torch.cat([boxes, conf, landms], -1)
boxes_conf_landms = non_max_suppression(boxes_conf_landms, self.confidence)
if len(boxes_conf_landms) <= 0:
print(names[index], ":未检测到人脸")
continue
if self.letterbox_image:
boxes_conf_landms = retinaface_correct_boxes(boxes_conf_landms, \
np.array([self.retinaface_input_shape[0], self.retinaface_input_shape[1]]), np.array([im_height, im_width]))
boxes_conf_landms[:, :4] = boxes_conf_landms[:, :4] * scale
boxes_conf_landms[:, 5:] = boxes_conf_landms[:, 5:] * scale_for_landmarks
#---------------------------------------------------#
# 选取最大的人脸框。
#---------------------------------------------------#
best_face_location = None
biggest_area = 0
for result in boxes_conf_landms:
left, top, right, bottom = result[0:4]
w = right - left
h = bottom - top
if w * h > biggest_area:
biggest_area = w * h
best_face_location = result
#---------------------------------------------------#
# 截取图像
#---------------------------------------------------#
crop_img = old_image[int(best_face_location[1]):int(best_face_location[3]), int(best_face_location[0]):int(best_face_location[2])]
landmark = np.reshape(best_face_location[5:],(5,2)) - np.array([int(best_face_location[0]),int(best_face_location[1])])
crop_img,_ = Alignment_1(crop_img,landmark)
crop_img = np.array(letterbox_image(np.uint8(crop_img),(self.facenet_input_shape[1],self.facenet_input_shape[0])))/255
crop_img = crop_img.transpose(2, 0, 1)
crop_img = np.expand_dims(crop_img,0)
#---------------------------------------------------#
# 利用图像算取长度为512的特征向量
#---------------------------------------------------#
with torch.no_grad():
crop_img = torch.from_numpy(crop_img).type(torch.FloatTensor)
if self.cuda:
crop_img = crop_img.cuda()
face_encoding = self.petalface(crop_img)[0].cpu().numpy()
face_encodings.append(face_encoding)
np.save("{你的保存路径}.npy".format(backbone=self.facenet_backbone),face_encodings)
np.save("{你的路径}.npy".format(backbone=self.facenet_backbone),names)
(2)之后可以进行人脸的信息比对,我使用的是余弦相似度进行比较的
#---------------------------------#
# 比较人脸(PETALface专用)
#---------------------------------#
def compare_faces_petalface(known_face_encodings, face_encoding_to_check, tolerance=0.5):
"""
比较待检测人脸与已知人脸库的匹配情况
Args:
known_face_encodings: 已知人脸特征列表,shape为(n, d)
face_encoding_to_check: 待检测人脸特征,shape为(d,)
tolerance: 相似度阈值,大于等于该值认为匹配(默认0.5,可根据模型校准)
Returns:
匹配结果列表(布尔值)和对应的相似度数组
"""
# 计算相似度
similarities = face_similarity_petalface(known_face_encodings, face_encoding_to_check)
# 根据阈值判断匹配结果
return list(similarities >= tolerance), similarities
def face_similarity_petalface(face_encodings, face_to_compare, use_cosine=True):
"""
计算一组人脸特征与目标人脸特征的相似度
Args:
face_encodings: 已知人脸特征列表,shape为(n, d),n为数量,d为特征维度
face_to_compare: 待比对人脸特征,shape为(d,)
use_cosine: 是否使用余弦相似度(推荐),否则使用原始内积
Returns:
相似度数组,shape为(n,),值越大表示越相似
"""
if len(face_encodings) == 0:
return np.empty((0))
# 确保输入为numpy数组
face_encodings = np.array(face_encodings)
face_to_compare = np.array(face_to_compare)
if use_cosine:
# 使用L2标准化计算余弦相似度
if sklearn is not None:
# 使用sklearn标准化
face_encodings_norm = sklearn.preprocessing.normalize(face_encodings, norm='l2')
face_to_compare_norm = sklearn.preprocessing.normalize(face_to_compare.reshape(1, -1), norm='l2')
return np.dot(face_encodings_norm, face_to_compare_norm.T).flatten()
else:
# 手动L2标准化
face_encodings_norm = face_encodings / (np.linalg.norm(face_encodings, axis=1, keepdims=True) + 1e-8)
face_to_compare_norm = face_to_compare / (np.linalg.norm(face_to_compare) + 1e-8)
return np.dot(face_encodings_norm, face_to_compare_norm.T)
else:
# 原始内积相似度
return np.dot(face_encodings, face_to_compare.T)
总体而言,PETALface 不仅在技术上突破了低分辨率人脸识别的核心瓶颈,其 “预训练骨干 + 参数高效微调 + 场景感知适配” 的设计思路,也为其他小样本、跨域计算机视觉任务提供了重要借鉴。
更多推荐



所有评论(0)