《第五十一篇》DETR 技术详解(Detection Transformer)
DETR(Detection Transformer)是Facebook AI提出的端到端目标检测模型,首次将Transformer引入目标检测任务。它摒弃了传统anchor和NMS机制,直接通过Transformer的Set Prediction框架预测边界框。模型结构包含:主干网络(如ResNet-50)提取特征、位置编码、Transformer编码器建模全局特征、解码器生成object qu
DETR 技术详解(Detection Transformer)
一、前言
DETR(DEtection TRansformer)是 Facebook AI 在 2020 年提出的一种端到端目标检测模型,它首次将 Transformer 结构引入目标检测任务中,取代了传统的 anchor-based 和 NMS 后处理机制。
它的核心思想是:
“使用 Transformer 的 Set Prediction 框架,直接预测一组 bounding boxes,并通过匈牙利匹配机制进行 loss 计算。”
本文将围绕 DETR 的模型结构、训练流程、损失函数、推理过程等进行详细讲解。
二、DETR 的完整模型结构流程图(输入图像:800×800×3)
Input Image (800x800x3)
│
├— Backbone: ResNet-50 / Swin Transformer → 提取多尺度特征 P3/P4/P5
│ ├— 输出 feature map `[B, C, H, W]`(如 `[1, 2048, 25, 25]`)
│
├— Neck: Feature Pyramid Network(FPN)→ 可选模块(部分变体启用)
│ ├— 上采样 + Concatenate(增强小目标识别能力)
│
├— Positional Encoding → 添加位置信息给 feature map
│
├— Transformer Encoder → 自注意力建模全局特征
│
├— Transformer Decoder + Learnable Queries → 解码器生成 object queries
│
└— Detection Head:
├— Bounding Box Reg Branch(回归 `(x_center, y_center, width, height)`)
└— Class Confidence Branch(分类置信度)
三、DETR 的完整模型结构详解
1. 主干网络(Backbone)
- 使用标准的 CNN 或 Vision Transformer 提取图像特征;
- 常见 backbone:
ResNet-50
(默认)ResNet-101
Swin-Tiny
/Swin-Base
(DETR-DC5 / Deformable DETR 中使用)
输出为 feature map:[B, C, H, W]
例如:[1, 2048, 25, 25]
(ResNet-50 输出)
2. 特征编码(Positional Encoding)
- 将 feature map 展平为
[B, C, HW]
; - 添加可学习的位置编码(positional encoding);
- 输入给 Transformer encoder;
pos_encoding = PositionEmbeddingSine()
flatten_feature_map = feature_map.flatten(2).permute(2, 0, 1) # [HW, B, C]
input_with_pos = flatten_feature_map + pos_encoding(flatten_feature_map)
3. Transformer Encoder
- 标准 Transformer 编码器;
- 对 feature map 进行全局自注意力建模;
- 输出为
[HW, B, C]
形式的编码后特征;
class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
def forward(self, src):
for layer in self.layers:
src = layer(src)
return src
4. Learnable Object Queries(解码器输入)
- 初始化为 learnable embeddings;
- 数量通常设为 100(支持最多 100 个 objects);
- 作为 decoder 的初始输入;
query_embed = nn.Embedding(num_queries=100, embedding_dim=d_model)
queries = query_embed.weight.unsqueeze(1).repeat(1, batch_size, 1)
5. Transformer Decoder
- 使用 cross-attention 查询 encoder 输出;
- 输出为
[Q, B, D]
形式的 object embeddings; - Q 为 object queries 数量(如 100);
class TransformerDecoder(nn.Module):
def __init__(self, decoder_layer, num_layers):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
def forward(self, tgt, memory, ...):
for layer in self.layers:
tgt = layer(tgt, memory, ...)
return tgt
6. Detection Head(边界框 + 分类分支)
- 每个 object query 由两个 head 预测:
reg_head
: 回归 bounding box;cls_head
: 分类 confidence;
class BBoxHead(nn.Module):
def __init__(self, d_model, num_classes=91):
super().__init__()
self.bbox_embed = MLP(d_model, d_model, 4, 3) # 回归头
self.class_embed = nn.Linear(d_model, num_classes) # 分类头
def forward(self, outputs):
"""
outputs: [Q, B, D] ← Transformer Decoder 输出
"""
outputs_class = self.class_embed(outputs) # [Q, B, 91]
outputs_coord = self.bbox_embed(outputs).sigmoid() # [Q, B, 4]
return outputs_class, outputs_coord
四、DETR 的完整训练流程详解
Step 1: 数据预处理
git clone https://github.com/facebookresearch/detr
cd detr
pip install -e .
加载 COCO 数据集并进行归一化:
from datasets import build_dataset
dataset = build_dataset(image_set='train', args=args)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
Step 2: 图像编码器提取特征
features = backbone(image) # 输出 feature map
- 使用 ResNet-50 提取特征;
- 输出为
[B, C, H, W]
; - 可替换为其他 backbone(如 Swin Transformer);
Step 3: 添加位置编码(Positional Encoding)
pos = position_embedding(features)
src = features.flatten(2).permute(2, 0, 1) # [HW, B, C]
- 将 feature map 展平并添加位置信息;
- 输入给 Transformer encoder;
Step 4: Transformer Encoder 编码
memory = transformer.encoder(src, pos=pos)
- 使用自注意力建模全局特征;
- 输出为
[HW, B, C]
;
Step 5: Transformer Decoder 解码
output = transformer.decoder(query_embed, memory)
- 使用 learnable queries 查询 encoder 输出;
- 输出为
[Q, B, D]
,其中 Q=100; - 每个 query 表示一个 object;
Step 6: 边界框与分类预测
class_logits = class_embed(output) # [Q, B, 91]
bounding_boxes = bbox_embed(output).sigmoid() # [Q, B, 4]
- 每个 query 经过 head 输出 bounding box 和 class probs;
- 最终输出为
[Q, 4]
和[Q, 91]
;
Step 7: 匈牙利匹配(Hungarian Matching)
DETR 不使用传统 anchor 匹配方式,而是使用 匈牙利算法 找出 predictions 与 GT 的最佳匹配。
cost_matrix = compute_cost(pred_boxes, gt_boxes, pred_logits, gt_labels)
matched_indices = linear_sum_assignment(cost_matrix)
- 构建 cost matrix(分类误差 + IoU);
- 使用 scipy.optimize.linear_sum_assignment;
- 多对多匹配,避免重复预测;
Step 8: 损失函数计算(Hungarian Loss)
DETR 使用组合损失函数:
L t o t a l = λ c l s ⋅ L c l s ( p r e d _ c l s , g t _ c l s ) + λ l 1 ⋅ L l 1 ( p r e d _ b b o x , g t _ b b o x ) + λ g i o u ⋅ L g i o u ( p r e d _ b b o x , g t _ b b o x ) \mathcal{L}_{total} = \lambda_{cls} \cdot \mathcal{L}_{cls}(pred\_cls, gt\_cls) + \lambda_{l1} \cdot \mathcal{L}_{l1}(pred\_bbox, gt\_bbox) + \lambda_{giou} \cdot \mathcal{L}_{giou}(pred\_bbox, gt\_bbox) Ltotal=λcls⋅Lcls(pred_cls,gt_cls)+λl1⋅Ll1(pred_bbox,gt_bbox)+λgiou⋅Lgiou(pred_bbox,gt_bbox)
- 分类损失:BCEWithLogitsLoss;
- 回归损失:L1Loss + GIoULoss;
- 总 loss 为加权求和;
五、DETR 的完整推理流程详解
Step 1: 图像输入与预处理
python main.py --image test.jpg --model detr-resnet-50
内部执行流程如下:
image = cv2.imread("test.jpg")
resized_image = cv2.resize(image, (800, 800)) / 255.0
input_tensor = np.expand_dims(resized_image, axis=0) # 添加 batch 维度
Step 2: 推理输出(PyTorch)
outputs = model(input_tensor) # 输出 Q 个 bounding boxes
输出示例(简化表示):
{
"pred_logits": [100, 91], # 类别置信度
"pred_boxes": [100, 4], # (x_center, y_center, width, height)
}
Step 3: 解码 bounding box(Anchor-Free)
YOLOv10 与 DETR 均采用 Anchor-Free 模式:
def decode_box(output_tensor):
bboxes = output_tensor["pred_boxes"] # [100, 4]
scores = output_tensor["pred_logits"].softmax(dim=-1)[:, :-1].max(-1).values # 去掉 no-object 类
labels = output_tensor["pred_logits"].softmax(dim=-1)[:, :-1].argmax(-1) # 获取类别编号
# 应用阈值过滤低分框
keep_indices = scores > 0.7
final_bboxes = bboxes[keep_indices]
final_scores = scores[keep_indices]
final_labels = labels[keep_indices]
return final_bboxes, final_scores, final_labels
Step 4: 推理后处理(无需 NMS)
由于 DETR 是 set prediction 框架,每个预测框独立且互斥,因此不需要 NMS 后处理。
六、DETR 的完整改进点汇总表
改进方向 | 内容 |
---|---|
端到端设计 | 不需要 anchor 和 NMS |
使用 Transformer | 替代传统 CNN 特征金字塔 |
Hungarian Matching | 动态选择最优匹配 |
L1 + GIoU Loss | 提升定位精度 |
多任务统一接口 | detect / segment / classify(后续版本支持) |
七、DETR 的完整模型变体支持
模型版本 | 主干网络 | 是否支持部署 |
---|---|---|
DETR | ResNet-50 | 是 |
DETR-DC5 | ResNet-50-dc5 | 是 |
DETR-Swin | Swin-Tiny / Base | 是 |
Conditional DETR | 改进版解码器 | 是 |
DAB-DETR | Dynamic Anchor Boxes | 是 |
Deformable DETR | 可变形 attention | 是 |
八、DETR 的完整训练 & 推理流程总结
训练流程:
DataLoader → Mosaic/CopyPaste → ResNet-50 → FPN(可选) → Positional Encoding → Transformer Encoder → Transformer Decoder → Hungarian Matching → Loss Calculation (L1 + GIoU + BCE) → Backpropagation
推理流程:
Image → Preprocess → ResNet-50 → Flatten → Positional Encoding → Transformer Encoder → Transformer Decoder → Decode Boxes → Output Final Detections(无 NMS)
九、DETR 的完整改进点对比表
改进点 | 内容 | 是否首次提出 | 是否开源实现 |
---|---|---|---|
无 anchor 设计 | 不依赖预设 anchor | 是 | 是 |
无 NMS 后处理 | 推理阶段不使用 NMS | 是 | 是 |
Hungarian Matching | 成本矩阵 + 匹配机制 | 是 | 是 |
L1 + GIoU Loss | 边界框回归 | 是 | 是 |
多任务统一接口 | detect / segment / classify | 是(后续版本支持) | 是 |
支持 ONNX 导出 | 可转换为 ONNX / TensorRT | 是(需手动导出) | 社区已有尝试 |
十、DETR 的完整性能表现
模型版本 | mAP@COCO val | FPS(V100) | 参数数量 |
---|---|---|---|
DETR-ResNet50 | ~42.0% | ~25 | ~40M |
DETR-ResNet101 | ~43.5% | ~20 | ~60M |
Conditional DETR | ~44.2% | ~20 | ~40M |
DAB-DETR | ~45.1% | ~18 | ~40M |
Deformable DETR | ~46.6% | ~30 | ~40M |
十一、DETR 的完整训练过程模拟代码(简化版)
import torch
from models.detr import build_detr
from datasets.coco import build_coco
# Step 1: 加载数据集
dataset_train = build_coco(image_set='train', args=args)
data_loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=16, shuffle=True)
# Step 2: 初始化模型
model, criterion, postprocessors = build_detr(args)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Step 3: 开始训练
for samples, targets in data_loader_train:
samples = samples.to(device)
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
outputs = model(samples)
loss_dict = criterion(outputs, targets)
weight_dict = criterion.weight_dict
losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys())
losses.backward()
optimizer.step()
optimizer.zero_grad()
十二、DETR 的完整推理流程模拟代码(输入一张图像)
from PIL import Image
import requests
import torchvision.transforms as T
transform = T.Compose([
T.Resize(800),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def load_image(url):
image = Image.open(requests.get(url, stream=True).raw).convert('RGB')
image = transform(image)
return image.unsqueeze(0) # 添加 batch 维度
image = load_image("https://example.com/test.jpg")
# 加载预训练模型
model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
model.eval()
# 推理
with torch.no_grad():
outputs = model(image)
logits = outputs['pred_logits']
bboxes = outputs['pred_boxes']
# 解码输出
scores = logits.softmax(-1)[..., :-1].max(-1).values
labels = logits.softmax(-1)[..., :-1].argmax(-1)
bboxes = bboxes * image.shape[-2:] # 转换为像素坐标
# 过滤低分框
keep_indices = scores > 0.7
final_bboxes = bboxes[keep_indices]
final_scores = scores[keep_indices]
final_labels = labels[keep_indices]
print(final_bboxes)
十三、DETR 的完整训练 & 推理流程总结
阶段 | 内容 |
---|---|
预处理 | 图像归一化 + Tokenization |
视觉编码 | ResNet-50 提取 patch embeddings |
位置编码 | 添加 positional encoding |
Transformer 编码器 | 自注意力建模全局关系 |
Transformer 解码器 | 使用 queries 查询目标 |
语言解码 | 分类 + 边界框回归 |
损失函数 | Hungarian Loss(L1 + GIoU + BCE) |
推理输出 | 不使用 NMS,直接输出 top-k 框 |
十四、DETR 的完整训练 & 推理流程总结
训练流程:
DataLoader → 图像 + GT 框 → ResNet-50 → Patch Embeddings → Transformer Encoder → Transformer Decoder → Hungarian Matching → Loss Calculation(L1 + GIoU + BCE) → Backpropagation
推理流程:
Image → Preprocess → ResNet-50 → Patch Embeddings → Transformer Encoder → Transformer Decoder → Hungarian Matching → Decode → Final Bounding Boxes(无 NMS)
十五、DETR 的完整模型结构可视化方式
你可以通过以下方式查看 DETR 的模型结构:
方法一:使用 Netron 查看 ONNX 模型结构
# 导出模型(需手动实现导出脚本)
torch.onnx.export(model, input_data, "detr.onnx", export_params=True, opset_version=13)
# 使用在线工具打开 .onnx 文件
# 地址:https://netron.app/
方法二:查看论文结构图(Figure 2)
DETR 论文中提供了完整的模型结构图,展示了 ResNet + Transformer 的联合结构。
🔗 DETR 论文
十六、结语
DETR 是目前最先进的 端到端目标检测模型之一,它的核心技术亮点包括:
- 使用 Transformer 替代传统 CNN;
- 引入 Set Prediction 框架;
- 支持匈牙利匹配机制;
- 推理阶段消除了 NMS;
- 提供完整的部署支持(ONNX / TensorRT);
📌 欢迎点赞 + 收藏 + 关注我,我会持续更新更多关于 DETR、YOLO系列、深度学习等内容!
更多推荐
所有评论(0)