VGGT模块

论文《VGGT: A Visual Generalist with Guided Transformers》
论文地址: https://arxiv.org/pdf/2504.04156
深度交流Q裙:1051849847
全网同名 【大嘴带你水论文】 B站定时发布详细讲解视频
核心思想 : 构建一个统一的、端到端的视觉模型,能够同时处理包括点跟踪、深度估计、相机位姿估计和光流预测在内的多种视觉任务。
详细代码见文章最后

在这里插入图片描述

1、作用

VGGT (Visual Generalist with Guided Transformers) 是一个功能强大的多任务视觉通用模型。它旨在通过单一的、统一的架构解决多样化的视觉任务,而无需为每个任务设计专门的解码器。VGGT的核心能力在于其可以从视频序列中学习时空表征,并利用这些表征来执行精确的3D感知和跟踪。无论是对于需要理解场景几何的深度估计,还是需要捕捉动态变化的运动跟踪,VGGT都能够提供高质量的预测结果,使其成为一个真正的“视觉多面手”。

在这里插入图片描述
图1:VGGT框架图

2、机制

VGGT的强大功能源于其精心设计的引导式Transformer架构,该架构能够高效地融合多帧图像信息,并生成适用于不同任务的丰富特征。

  1. 引导式交替注意力 (Guided Alternating Attention) :
    这是VGGT的核心机制。模型没有采用传统的将所有token一次性输入Transformer的方式,而是设计了一种交替注意力(Alternating Attention)机制。该机制包含两种注意力模式:
  • 帧内注意力 (Frame Attention) : 在单帧图像的patch token之间计算自注意力,捕捉空间上下文信息。
  • 全局注意力 (Global Attention) : 在所有帧的patch token之间计算自注意力,捕捉跨时间的时序关联信息。
    这两种注意力模式交替进行,使得模型能够以较低的计算成本高效地建模长程时空依赖关系。
  1. 专用Token与位置编码 :
    为了更好地引导模型,VGGT引入了两种特殊的Token:
  • 相机Token (Camera Token) : 编码了每一帧的相机内外参信息(如位置、姿态、焦距),为模型提供了精确的几何信息。
  • 寄存器Token (Register Token) : 作为可学习的“暂存器”,用于在不同注意力模块之间传递和汇总信息。
    此外,模型还使用了二维旋转位置编码(RoPE),帮助网络更好地理解patch之间的相对空间位置。
  1. 多任务预测头 (Multi-task Prediction Heads) :
    VGGT的骨干网络(Aggregator)输出统一的时空特征,这些特征被送入不同的预测头以完成特定任务:
  • DPT头 (Dense Prediction Transformer Head) : 用于生成密集的预测结果,如深度图。它通过多尺度特征融合模块,将来自骨干网络不同层级的特征进行整合,以生成高分辨率、高精度的深度预测。
  • 跟踪头 (Track Head) : 负责预测点的轨迹。它利用DPT头提取的特征,并通过一个迭代优化的跟踪器( BaseTrackerPredictor )来精确地估计点在连续帧中的运动轨迹。

3、代码实现

import torch
import torch.nn as nn
import numpy as np
from typing import Union, List, Tuple

# 定义数组类型
ArrayLike = Union[np.ndarray, torch.Tensor]

def _ensure_torch(x: ArrayLike) -> torch.Tensor:
    """确保输入是torch张量"""
    if isinstance(x, np.ndarray):
        return torch.from_numpy(x)
    elif isinstance(x, torch.Tensor):
        return x
    else:
        return torch.tensor(x)

class Attention(nn.Module):
    """
    多头注意力模块
    """
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "维度应能被头数整除"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class DPTHead(nn.Module):
    """
    用于密集预测任务的DPT头
    """
    def __init__(
        self,
        dim_in: int,
        patch_size: int = 14,
        output_dim: int = 4,
        features: int = 256,
        out_channels: List[int] = [256, 512, 1024, 1024],
    ) -> None:
        super(DPTHead, self).__init__()
        self.patch_size = patch_size

        self.projects = nn.ModuleList(
            [nn.Conv2d(in_channels=dim_in, out_channels=oc, kernel_size=1, stride=1, padding=0) for oc in out_channels]
        )

        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0),
                nn.ConvTranspose2d(in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0),
                nn.Identity(),
                nn.Conv2d(in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1),
            ]
        )
        # 此处省略了复杂的scratch和fusion block的实现,以保持代码简洁
        self.output_conv = nn.Conv2d(features, output_dim, kernel_size=1, stride=1, padding=0)


    def forward(self, aggregated_tokens_list: List[torch.Tensor], images: torch.Tensor, patch_start_idx: int) -> torch.Tensor:
        B, S, _, H, W = images.shape
        patch_h, patch_w = H // self.patch_size, W // self.patch_size
        
        # 此处简化了前向传播逻辑
        # 实际实现中会融合多尺度特征
        tokens = aggregated_tokens_list[-1][:, patch_start_idx:]
        tokens = tokens.reshape(B*S, patch_h, patch_w, -1).permute(0, 3, 1, 2)
        
        # 假设直接使用最后一层token进行预测
        projected_tokens = self.projects[-1](tokens)
        resized_tokens = self.resize_layers[-1](projected_tokens)
        
        # 上采样到原始分辨率
        upsampled = nn.functional.interpolate(resized_tokens, size=(H, W), mode='bilinear', align_corners=False)
        
        output = self.output_conv(upsampled.view(B*S, -1, H, W))
        
        return output.view(B, S, -1, H, W)


class TrackHead(nn.Module):
    """
    追踪头,使用DPT进行特征提取
    """
    def __init__(self, dim_in, patch_size=14, features=128, iters=4):
        super().__init__()
        self.patch_size = patch_size
        self.feature_extractor = DPTHead(dim_in=dim_in, patch_size=patch_size, features=features, feature_only=True)
        # 此处省略了BaseTrackerPredictor的实现
        self.iters = iters

    def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None):
        feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)
        # 追踪逻辑...
        # 返回追踪坐标和可见性
        return torch.randn_like(query_points), torch.ones_like(query_points)[..., 0]

class VGGT(nn.Module):
    """
    VGGT模型: 一个用于视频的通用几何Transformer
    """
    def __init__(self, img_size=518, patch_size=14, embed_dim=1024,
                 enable_camera=True, enable_point=True, enable_depth=True, enable_track=True):
        super().__init__()
        # 此处省略了Aggregator的实现
        # self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
        self.camera_head = nn.Identity() if enable_camera else None
        self.point_head = DPTHead(dim_in=embed_dim, output_dim=4) if enable_point else None
        self.depth_head = DPTHead(dim_in=embed_dim, output_dim=2) if enable_depth else None
        self.track_head = TrackHead(dim_in=embed_dim, patch_size=patch_size) if enable_track else None

    def forward(self, images: torch.Tensor, query_points: torch.Tensor = None):
        if len(images.shape) == 4:
            images = images.unsqueeze(0)
        
        # 伪造aggregator的输出
        B, S, _, H, W = images.shape
        patch_h, patch_w = H // 14, W // 14
        
        # 假设aggregator输出一个token列表
        aggregated_tokens_list = [torch.randn(B*S, patch_h*patch_w + 1, 1024)]
        patch_start_idx = 1

        predictions = {}

        if self.depth_head is not None:
            depth, depth_conf = self.depth_head(aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx).chunk(2, dim=2)
            predictions["depth"] = depth
            predictions["depth_conf"] = depth_conf

        if self.track_head is not None and query_points is not None:
            track, vis = self.track_head(aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx, query_points=query_points)
            predictions["track"] = track
            predictions["vis"] = vis

        return predictions

if __name__ == '__main__':
    # 模型参数
    img_size = 224
    patch_size = 14
    embed_dim = 1024
    batch_size = 1
    seq_length = 5
    num_points = 10

    # 创建模型
    model = VGGT(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)

    # 创建输入
    images = torch.randn(batch_size, seq_length, 3, img_size, img_size)
    query_points = torch.randn(batch_size, num_points, 2)

    # 前向传播
    predictions = model(images, query_points)

    # 打印结果
    print('输入图像尺寸:', images.size())
    print('输入查询点尺寸:', query_points.size())
    if 'depth' in predictions:
        print('预测深度图尺寸:', predictions['depth'].size())
    if 'track' in predictions:
        print('预测轨迹尺寸:', predictions['track'].size())
    
    print(f"参数数量: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.2f}M")

详细代码 gitcode地址:https://gitcode.com/2301_80107842/research

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐