告别灾难性遗忘!CoMBO让你的AI模型新老知识两不误,41.1% mIoU碾压全场!
VGGT是一个统一的视觉通用模型,通过引导式Transformer架构同时处理点跟踪、深度估计、相机位姿估计和光流预测等多项视觉任务。其核心机制包括交替注意力(帧内和全局注意力)、专用Token(相机和寄存器Token)以及多任务预测头。代码实现展示了关键组件,如多头注意力模块和密集预测头。该模型通过单一架构实现多任务处理,无需为每个任务设计专用解码器,具有高效建模时空依赖关系的能力。
VGGT模块
论文《VGGT: A Visual Generalist with Guided Transformers》
论文地址: https://arxiv.org/pdf/2504.04156
深度交流Q裙:1051849847
全网同名 【大嘴带你水论文】 B站定时发布详细讲解视频
核心思想 : 构建一个统一的、端到端的视觉模型,能够同时处理包括点跟踪、深度估计、相机位姿估计和光流预测在内的多种视觉任务。
详细代码见文章最后
1、作用
VGGT (Visual Generalist with Guided Transformers) 是一个功能强大的多任务视觉通用模型。它旨在通过单一的、统一的架构解决多样化的视觉任务,而无需为每个任务设计专门的解码器。VGGT的核心能力在于其可以从视频序列中学习时空表征,并利用这些表征来执行精确的3D感知和跟踪。无论是对于需要理解场景几何的深度估计,还是需要捕捉动态变化的运动跟踪,VGGT都能够提供高质量的预测结果,使其成为一个真正的“视觉多面手”。
图1:VGGT框架图
2、机制
VGGT的强大功能源于其精心设计的引导式Transformer架构,该架构能够高效地融合多帧图像信息,并生成适用于不同任务的丰富特征。
- 引导式交替注意力 (Guided Alternating Attention) :
这是VGGT的核心机制。模型没有采用传统的将所有token一次性输入Transformer的方式,而是设计了一种交替注意力(Alternating Attention)机制。该机制包含两种注意力模式:
- 帧内注意力 (Frame Attention) : 在单帧图像的patch token之间计算自注意力,捕捉空间上下文信息。
- 全局注意力 (Global Attention) : 在所有帧的patch token之间计算自注意力,捕捉跨时间的时序关联信息。
这两种注意力模式交替进行,使得模型能够以较低的计算成本高效地建模长程时空依赖关系。
- 专用Token与位置编码 :
为了更好地引导模型,VGGT引入了两种特殊的Token:
- 相机Token (Camera Token) : 编码了每一帧的相机内外参信息(如位置、姿态、焦距),为模型提供了精确的几何信息。
- 寄存器Token (Register Token) : 作为可学习的“暂存器”,用于在不同注意力模块之间传递和汇总信息。
此外,模型还使用了二维旋转位置编码(RoPE),帮助网络更好地理解patch之间的相对空间位置。
- 多任务预测头 (Multi-task Prediction Heads) :
VGGT的骨干网络(Aggregator)输出统一的时空特征,这些特征被送入不同的预测头以完成特定任务:
- DPT头 (Dense Prediction Transformer Head) : 用于生成密集的预测结果,如深度图。它通过多尺度特征融合模块,将来自骨干网络不同层级的特征进行整合,以生成高分辨率、高精度的深度预测。
- 跟踪头 (Track Head) : 负责预测点的轨迹。它利用DPT头提取的特征,并通过一个迭代优化的跟踪器( BaseTrackerPredictor )来精确地估计点在连续帧中的运动轨迹。
3、代码实现
import torch
import torch.nn as nn
import numpy as np
from typing import Union, List, Tuple
# 定义数组类型
ArrayLike = Union[np.ndarray, torch.Tensor]
def _ensure_torch(x: ArrayLike) -> torch.Tensor:
"""确保输入是torch张量"""
if isinstance(x, np.ndarray):
return torch.from_numpy(x)
elif isinstance(x, torch.Tensor):
return x
else:
return torch.tensor(x)
class Attention(nn.Module):
"""
多头注意力模块
"""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
proj_bias: bool = True,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
) -> None:
super().__init__()
assert dim % num_heads == 0, "维度应能被头数整除"
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class DPTHead(nn.Module):
"""
用于密集预测任务的DPT头
"""
def __init__(
self,
dim_in: int,
patch_size: int = 14,
output_dim: int = 4,
features: int = 256,
out_channels: List[int] = [256, 512, 1024, 1024],
) -> None:
super(DPTHead, self).__init__()
self.patch_size = patch_size
self.projects = nn.ModuleList(
[nn.Conv2d(in_channels=dim_in, out_channels=oc, kernel_size=1, stride=1, padding=0) for oc in out_channels]
)
self.resize_layers = nn.ModuleList(
[
nn.ConvTranspose2d(in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0),
nn.ConvTranspose2d(in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0),
nn.Identity(),
nn.Conv2d(in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1),
]
)
# 此处省略了复杂的scratch和fusion block的实现,以保持代码简洁
self.output_conv = nn.Conv2d(features, output_dim, kernel_size=1, stride=1, padding=0)
def forward(self, aggregated_tokens_list: List[torch.Tensor], images: torch.Tensor, patch_start_idx: int) -> torch.Tensor:
B, S, _, H, W = images.shape
patch_h, patch_w = H // self.patch_size, W // self.patch_size
# 此处简化了前向传播逻辑
# 实际实现中会融合多尺度特征
tokens = aggregated_tokens_list[-1][:, patch_start_idx:]
tokens = tokens.reshape(B*S, patch_h, patch_w, -1).permute(0, 3, 1, 2)
# 假设直接使用最后一层token进行预测
projected_tokens = self.projects[-1](tokens)
resized_tokens = self.resize_layers[-1](projected_tokens)
# 上采样到原始分辨率
upsampled = nn.functional.interpolate(resized_tokens, size=(H, W), mode='bilinear', align_corners=False)
output = self.output_conv(upsampled.view(B*S, -1, H, W))
return output.view(B, S, -1, H, W)
class TrackHead(nn.Module):
"""
追踪头,使用DPT进行特征提取
"""
def __init__(self, dim_in, patch_size=14, features=128, iters=4):
super().__init__()
self.patch_size = patch_size
self.feature_extractor = DPTHead(dim_in=dim_in, patch_size=patch_size, features=features, feature_only=True)
# 此处省略了BaseTrackerPredictor的实现
self.iters = iters
def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None):
feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)
# 追踪逻辑...
# 返回追踪坐标和可见性
return torch.randn_like(query_points), torch.ones_like(query_points)[..., 0]
class VGGT(nn.Module):
"""
VGGT模型: 一个用于视频的通用几何Transformer
"""
def __init__(self, img_size=518, patch_size=14, embed_dim=1024,
enable_camera=True, enable_point=True, enable_depth=True, enable_track=True):
super().__init__()
# 此处省略了Aggregator的实现
# self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
self.camera_head = nn.Identity() if enable_camera else None
self.point_head = DPTHead(dim_in=embed_dim, output_dim=4) if enable_point else None
self.depth_head = DPTHead(dim_in=embed_dim, output_dim=2) if enable_depth else None
self.track_head = TrackHead(dim_in=embed_dim, patch_size=patch_size) if enable_track else None
def forward(self, images: torch.Tensor, query_points: torch.Tensor = None):
if len(images.shape) == 4:
images = images.unsqueeze(0)
# 伪造aggregator的输出
B, S, _, H, W = images.shape
patch_h, patch_w = H // 14, W // 14
# 假设aggregator输出一个token列表
aggregated_tokens_list = [torch.randn(B*S, patch_h*patch_w + 1, 1024)]
patch_start_idx = 1
predictions = {}
if self.depth_head is not None:
depth, depth_conf = self.depth_head(aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx).chunk(2, dim=2)
predictions["depth"] = depth
predictions["depth_conf"] = depth_conf
if self.track_head is not None and query_points is not None:
track, vis = self.track_head(aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx, query_points=query_points)
predictions["track"] = track
predictions["vis"] = vis
return predictions
if __name__ == '__main__':
# 模型参数
img_size = 224
patch_size = 14
embed_dim = 1024
batch_size = 1
seq_length = 5
num_points = 10
# 创建模型
model = VGGT(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
# 创建输入
images = torch.randn(batch_size, seq_length, 3, img_size, img_size)
query_points = torch.randn(batch_size, num_points, 2)
# 前向传播
predictions = model(images, query_points)
# 打印结果
print('输入图像尺寸:', images.size())
print('输入查询点尺寸:', query_points.size())
if 'depth' in predictions:
print('预测深度图尺寸:', predictions['depth'].size())
if 'track' in predictions:
print('预测轨迹尺寸:', predictions['track'].size())
print(f"参数数量: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.2f}M")
详细代码 gitcode地址:https://gitcode.com/2301_80107842/research
更多推荐
所有评论(0)