🔥 AI Plug-and-Play | Your CV performance-boosting module "arsenal" is now open source! 🔥

Hi everyone! To help you boost performance efficiently in CV research and projects, I created and maintain a GitHub repository of plug-and-play modules.

The repository includes not only:

  • Plug-and-play code for core modules
  • Paper deep-dive summaries
  • In-depth analyses of architecture diagrams
  • Full sentence-by-sentence translations and application examples

It also collects innovative modules from a large number of SOTA models, aiming to build an "AI plug-and-play" toolbox so you can experiment and combine ideas quickly!

🚀 GitHub repository: https://github.com/AITricks/AITricks

If you find it helpful, please Star, Fork, and open PRs so we can maintain it together!

Plug-and-Play Performance Boost Series (12): TDCNet Explained! AAAI 2026 SOTA, restructuring spatio-temporal features, with TDCR difference convolution and TDCSTA attention setting a new accuracy bar for moving infrared small target detection

Paper: https://arxiv.org/abs/2511.09352
Official code: https://github.com/IVPLaboratory/TDCNet

Paper Deep Dive: TDCNet

1. Core Idea

This paper targets the two main challenges of moving infrared small target detection (moving IRSTD), weak target features and complex background clutter, and proposes a new network named TDCNet. Its central argument: existing 3D convolutions can extract spatio-temporal features but lack explicit awareness of motion dynamics along the temporal dimension, whereas temporal differencing captures motion but discards spatial semantics. The paper therefore introduces Temporal Difference Convolution (TDC), which folds the temporal-difference operation and 3D convolution into a single unified convolution, and uses re-parameterization (TDCR) so that multi-scale motion context is modeled at zero extra inference cost. On top of that, a TDC-guided spatio-temporal attention mechanism (TDCSTA) uses the strong motion cues extracted by TDC to guide and enhance the semantic expression of spatio-temporal features, achieving SOTA performance in low signal-to-clutter-ratio (SCR) scenes.

2. Background and Motivation

  • Summary in text
    Moving infrared small target detection matters in applications such as UAV surveillance, but it must cope with faint targets and severe background clutter. Existing methods fall into two camps: single-frame methods (2D convolution) ignore temporal information and are prone to false alarms; among multi-frame methods, plain temporal differencing lacks spatial semantics, while standard 3D convolution models the spatio-temporal domain yet has no explicit inductive bias toward the tiny pixel-level changes between frames (i.e., the motion cues). The motivation of this paper is to break that dichotomy and design a unified architecture that captures motion explicitly, like differencing, while retaining rich spatio-temporal features, like 3D convolution.

  • Motivation figure analysis (based on Figure 1)
    [Figure 1]

    • Figure description: Figure 1 compares three detection paradigms.
    • (a) Single-frame methods (2D conv): a single frame is fed into a 2D network; background clutter remains strong in the feature map and the target is drowned out (missed detection). Spatial features alone struggle to separate target-like background noise.
    • (b) Multi-frame methods (3D conv): a frame sequence is processed with standard 3D convolution. Temporal information is used, yet the target remains inconspicuous in the feature map (missed detection). This exposes the limitation of standard 3D convolution on weak motion signals: it mixes spatio-temporal information without emphasizing change.
    • (c) This paper (TDC): with temporal difference convolution, the background is strongly suppressed (near black) in the feature map while the response of the moving target is clearly preserved (detection result).
    • Takeaway: the comparison pinpoints the "semantic gap" (2D lacks temporal semantics) and the "motion-awareness bottleneck" (3D lacks explicit motion modeling) of existing methods, motivating the paper's core strategy of explicit motion modeling with TDC to suppress static background.

3. Main Contributions

  • [Contribution 1]: The TDCNet architecture
    A new network for moving infrared small target detection. Rather than a simple stack of modules, it is built as a three-stream architecture (a 2D spatial stream, a 3D spatio-temporal stream, and a TDC motion stream) dedicated to extracting and enhancing spatio-temporal features while suppressing complex backgrounds.

  • [Contribution 2]: The Temporal Difference Convolution Re-parameterization (TDCR) module
    The paper's core operator innovation. Three parallel branches, short-term (S-TDC), mid-term (M-TDC), and long-term (L-TDC), capture motion dependencies over different temporal spans. Crucially, the branches work independently during training but are merged at inference into a single equivalent 3D convolution kernel via re-parameterization, so the model gains explicit multi-scale differencing while keeping the same inference cost as a plain 3D convolution.

  • [Contribution 3]: The TDC-guided spatio-temporal attention (TDCSTA) mechanism
    A novel cross-attention design. Instead of conventional feature fusion, the features extracted by the TDC backbone (strongly motion-aware, high signal-to-clutter ratio) serve as the Query to retrieve and refine the semantic information in the 3D backbone (spatio-temporal features) and the 2D backbone (spatial features). This establishes an explicit dependency between motion cues and global semantics and guides the network to focus on the key target regions.

  • [Contribution 4]: The IRSTD-UAV benchmark dataset
    To remedy the limited scene diversity of existing datasets, the authors build a new dataset of 15,106 real infrared frames covering multiple UAV target types and complex dynamic backgrounds (cities, trees, clouds), providing a more challenging benchmark for the field.

4. Method Details (Most Important)

  • Overall network architecture (based on Figure 2)

    [Figure 2]

    • Input streams: the network takes a frame sequence plus the current frame as input.
    • Backbones: three parallel paths:
      1. TDC backbone: the core path, built from stacked TDCR layers, extracts explicit motion features (TDC features) from the frame sequence; a background-alignment step is applied beforehand.
      2. 3D backbone: processes the frame sequence and extracts conventional spatio-temporal features.
      3. 2D backbone: processes only the current frame and extracts fine-grained spatial features.
    • Feature fusion (TDCSTA): the three feature streams are fed into the TDC-Guided Spatio-Temporal Attention module.
    • Output: the enhanced features (STEF) are aggregated by the Neck, and the Detection Head produces the final detections.
  • Core innovation modules in detail

    • Module A: Temporal Difference Convolution Re-parameterization (TDCR) (based on Figure 3 & Figure 4)

      [Figure 3]

      • Internal structure: during training the module consists of three parallel branches: S-TDC (short-term), M-TDC (mid-term), and L-TDC (long-term).

      • Data flow and design intent

        [Figure 4]

        • Take L-TDC as an example (Figure 4b): it performs no explicit subtraction. Instead, a carefully arranged kernel-weight configuration embeds the operation "current frame minus all past frames" directly into the convolution.
        • S-TDC attends to adjacent-frame differences and captures fast motion.
        • M-TDC attends to differences between non-adjacent (interval) frames and captures medium-speed motion.
      • Re-parameterization mechanism: at inference, the linearity of convolution allows the parameters of the three branches (BN layers included) to be merged into a single standard $5 \times 3 \times 3$ 3D kernel, achieving "multi-scale difference enhancement during training, a single efficient convolution at inference" (a numerical equivalence check is sketched after the code in section 6).

    • Module B: TDC-guided spatio-temporal attention (TDCSTA) (based on Figure 2)

      • Internal structure: a self-attention stage followed by a cross-attention stage.
      • Data flow
        1. First, the TDC features (TDCF), spatio-temporal features (STF), and spatial features (SF) each pass through self-attention to strengthen their own semantic representation.
        2. Then comes the cross-attention stage, the key design: TDCF serves as the Query (Q), STF as the Key (K), and SF as the Value (V). (The core logic: use the "motion features" to retrieve the relevant parts of the "spatio-temporal / spatial features".)
      • Design rationale: because TDCF has passed through difference convolution, background clutter is already heavily suppressed and target regions stand out (it best answers "where the target is"). Using it as the Query guides the network to accurately extract the target's detailed appearance and spatio-temporal information from the noisier STF and SF, refining features in a "motion guides semantics" fashion.
  • Idea and mechanism summary
    The core idea is to "turn the explicit difference into a convolution". A conventional difference, $Input_t - Input_{t-1}$, is a parameter-free pre-processing step. The proposed TDC turns it into $W * (Input_t - Input_{t-1})$ and further rewrites it as $W' * Input_{Sequence}$ (a toy 1D verification of this identity is sketched right after this section).

    • Formula reading: the paper shows how a particular weight arrangement makes the convolution output equivalent to convolving the "difference feature map".
    • Working mechanism: this gives the kernel the ability to perceive "change" directly in the spatio-temporal domain, injecting a motion-driven inductive bias that filters out static background at the earliest stages of feature extraction.
  • Figure-level summary
    The figures show how the two designs work in concert: the TDCR module (Figure 3) answers "how to extract clean motion features efficiently" (addressing the heavy background interference highlighted in Figure 1), while the TDCSTA module (Figure 2) answers "how to use motion features to recover the lost detail". The former suppresses the background, the latter enhances the target, and together they deliver highly sensitive detection in complex backgrounds.
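To make the rewrite $W * (Input_t - Input_{t-1}) \Rightarrow W' * Input_{Sequence}$ concrete, here is a minimal 1D sketch of the identity (my own toy example, not the paper's 5-frame kernel configuration): a 2-tap kernel applied to frame differences is reproduced exactly by a 3-tap kernel applied to the raw sequence.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 1, 10)                 # toy single-pixel intensity over 10 frames
w = torch.randn(1, 1, 2)                  # 2-tap temporal kernel meant for frame differences

# Path 1: explicit temporal difference, then convolution (difference as pre-processing)
diff = x[:, :, 1:] - x[:, :, :-1]         # d[t] = x[t+1] - x[t]
out_diff = F.conv1d(diff, w)

# Path 2: fold the subtraction into the kernel itself (the TDC idea)
# w0*d[t] + w1*d[t+1] = -w0*x[t] + (w0 - w1)*x[t+1] + w1*x[t+2]
w_fold = torch.stack([-w[..., 0], w[..., 0] - w[..., 1], w[..., 1]], dim=-1)
out_fold = F.conv1d(x, w_fold)

print(torch.allclose(out_diff, out_fold, atol=1e-6))   # True: one convolution, same result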

5. Where the Plug-and-Play Modules Help

The proposed components are highly generic and can serve as plug-and-play modules in a wide range of video analysis tasks:

  1. TDCR (Temporal Difference Convolution Re-parameterization) module

    • Applicable scenarios: any feature extractor built on 3D convolution or video sequences, especially motion-sensitive tasks with tight compute budgets.
    • Concrete uses
      • Video action recognition: replace existing C3D or (2+1)D convolution layers to sharpen the network's sensitivity to subtle motions (e.g., gesture recognition) without adding inference latency.
      • Video salient object detection: quickly localize moving foreground objects against dynamic backgrounds.
      • Video anomaly detection: exploit the module's sensitivity to motion patterns to flag abnormal motion events in surveillance footage (sudden running, falls).
  2. TDCSTA (TDC-Guided Spatio-Temporal Attention) mechanism

    • Applicable scenarios: multi-modal or multi-stream architectures in which one stream has a high signal-to-noise ratio but weak semantics (e.g., a motion or depth stream) and another is semantically rich but noisy (see the cross-attention sketch after this list).
    • Concrete uses
      • RGB-infrared / RGB-thermal fusion tracking: use the infrared stream (sensitive to heat sources) as the Query to guide feature extraction from the RGB stream (rich texture), improving all-weather tracking.
      • Video deraining / dehazing: use difference features as the Query to focus the network on the dynamic rain/haze regions and recover background details more precisely.
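As a hedged illustration of the Q/K/V assignment, the sketch below runs the CrossAttention class from section 6 on random tensors standing in for the three streams; the shapes, channel width, and window size here are illustrative assumptions, not the paper's settings.

import torch
# Assumes the CrossAttention class defined in section 6 below is in scope.
tdcf = torch.randn(1, 4, 16, 16, 64)   # (B, T, H, W, C) motion-aware TDC features -> Query
stf  = torch.randn(1, 4, 16, 16, 64)   # 3D-backbone spatio-temporal features      -> Key
sf   = torch.randn(1, 4, 16, 16, 64)   # 2D-backbone spatial features              -> Value
cross = CrossAttention(dim=64, window_size=(2, 4, 4), num_heads=4)
refined = cross(tdcf, stf, sf)          # motion cues select which semantic content to keep
print(refined.shape)                    # torch.Size([1, 4, 16, 16, 64])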

6. Plug-and-Play Modules

"""
Collection of plug-and-play modules distilled from the original TDCNet
implementation. These modules correspond to the key building blocks shown
across Figures 1-4 of the paper:

- Temporal Difference Convolution (TDC) block
- Temporal Difference Convolution Re-parameterization (TDCR / RepConv3D)
- TDC-guided Self-Attention and Cross-Attention modules

They are copied from the source files under ``model/TDCNet`` so they can be
imported independently wherever lightweight reuse is needed.
"""

from functools import reduce
from operator import mul
import time

import torch
import torch.nn as nn
import torch.nn.functional as F


# ---------------------------------------------------------------------------
# Temporal Difference Convolution (Figure 4)
# ---------------------------------------------------------------------------
class TDC(nn.Module):
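    """Temporal Difference Convolution (Figure 4).

    The frame-difference operation is folded directly into the 3D conv weights,
    so the forward pass stays a single ``F.conv3d`` over the raw frame sequence.
    ``step`` selects which difference pattern is baked into the kernel:
    ``step=1``  -> S-TDC: adjacent-frame differences (short-term motion);
    ``step=2``  -> M-TDC: differences between frames two steps apart (mid-term);
    ``step=-1`` -> L-TDC: current frame minus every preceding frame (long-term).
    """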
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(5, 3, 3),
        stride=1,
        padding=(2, 1, 1),
        groups=1,
        bias=False,
        step=1,
    ):
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=bias,
        )
        self.step = step
        self.groups = groups

    def get_time_gradient_weight(self):
        weight = self.conv.weight
        kT, kH, kW = weight.shape[2:]
        grad_weight = torch.zeros_like(weight, device=weight.device, dtype=weight.dtype)
        if kT == 5:
            if self.step == -1:
                grad_weight[:, :, :, :, :] = -weight[:, :, :, :, :]
                grad_weight[:, :, 4, :, :] = (
                    weight[:, :, 0, :, :]
                    + weight[:, :, 1, :, :]
                    + weight[:, :, 2, :, :]
                    + weight[:, :, 3, :, :]
                    + weight[:, :, 4, :, :]
                )
            elif self.step == 1:
                grad_weight[:, :, 4, :, :] = weight[:, :, 4, :, :]
                grad_weight[:, :, 3, :, :] = weight[:, :, 3, :, :] - weight[:, :, 4, :, :]
                grad_weight[:, :, 2, :, :] = weight[:, :, 2, :, :] - weight[:, :, 3, :, :]
                grad_weight[:, :, 1, :, :] = weight[:, :, 1, :, :] - weight[:, :, 2, :, :]
                grad_weight[:, :, 0, :, :] = -weight[:, :, 1, :, :]
            elif self.step == 2:
                grad_weight[:, :, 4, :, :] = weight[:, :, 4, :, :]
                grad_weight[:, :, 3, :, :] = weight[:, :, 3, :, :]
                grad_weight[:, :, 2, :, :] = weight[:, :, 2, :, :] - weight[:, :, 4, :, :]
                grad_weight[:, :, 1, :, :] = -weight[:, :, 3, :, :]
                grad_weight[:, :, 0, :, :] = -weight[:, :, 2, :, :]
        else:
            grad_weight = weight
        bias = self.conv.bias
        if bias is None:
            bias = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype)
        return grad_weight, bias

    def forward(self, x):
        weight, bias = self.get_time_gradient_weight()
        x_diff = F.conv3d(
            x,
            weight,
            bias,
            stride=self.conv.stride,
            groups=self.groups,
            padding=self.conv.padding,
        )
        return x_diff


# ---------------------------------------------------------------------------
# Temporal Difference Convolution Re-parameterization (Figure 3)
# ---------------------------------------------------------------------------
class RepConv3D(nn.Module):
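    """TDCR block (Figure 3): three parallel TDC branches (S-/M-/L-TDC), each
    followed by BatchNorm, summed and passed through ReLU during training.
    ``switch_to_deploy`` fuses the branches and their BN statistics into one
    (5, 3, 3) Conv3d, so inference costs the same as a plain 3D convolution.
    """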
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(5, 3, 3),
        stride=1,
        padding=(2, 1, 1),
        groups=1,
        deploy=False,
    ):
        super(RepConv3D, self).__init__()
        self.deploy = deploy
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.groups = groups
        if self.deploy:
            self.conv_reparam = nn.Conv3d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                groups=groups,
                bias=True,
            )
        else:
            self.l_tdc = nn.Sequential(
                TDC(
                    in_channels,
                    out_channels,
                    kernel_size,
                    stride,
                    padding,
                    groups=groups,
                    bias=False,
                    step=-1,
                ),
                nn.BatchNorm3d(out_channels),
            )
            self.s_tdc = nn.Sequential(
                TDC(
                    in_channels,
                    out_channels,
                    kernel_size,
                    stride,
                    padding,
                    groups=groups,
                    bias=False,
                    step=1,
                ),
                nn.BatchNorm3d(out_channels),
            )
            self.m_tdc = nn.Sequential(
                TDC(
                    in_channels,
                    out_channels,
                    kernel_size,
                    stride,
                    padding,
                    groups=groups,
                    bias=False,
                    step=2,
                ),
                nn.BatchNorm3d(out_channels),
            )

    def forward(self, x):
        if self.deploy:
            out = F.relu(self.conv_reparam(x))
        else:
            out = self.s_tdc(x) + self.m_tdc(x) + self.l_tdc(x)
            out = F.relu(out)
        return out

    def get_equivalent_kernel_bias(self):
        kernel_s_tdc, bias_s_tdc = self._fuse_conv_bn(self.s_tdc)
        kernel_m_tdc, bias_m_tdc = self._fuse_conv_bn(self.m_tdc)
        kernel_l_tdc, bias_l_tdc = self._fuse_conv_bn(self.l_tdc)
        kernel = kernel_s_tdc + kernel_m_tdc + kernel_l_tdc
        bias = bias_s_tdc + bias_m_tdc + bias_l_tdc
        return kernel, bias

    def switch_to_deploy(self):
        if self.deploy:
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.conv_reparam = nn.Conv3d(
            self.in_channels,
            self.out_channels,
            (5, 3, 3),
            self.stride,
            (2, 1, 1),
            groups=self.groups,
            bias=True,
        )
        self.conv_reparam.weight.data = kernel
        self.conv_reparam.bias.data = bias
        self.deploy = True
        del self.s_tdc
        del self.m_tdc
        del self.l_tdc

    @staticmethod
    def _fuse_conv_bn(branch):
        if branch is None:
            return 0, 0

        def find_conv(module):
            if isinstance(module, nn.Conv3d):
                return module
            for child in module.children():
                conv = find_conv(child)
                if conv is not None:
                    return conv
            return None

        conv = find_conv(branch[0])
        bn = branch[1]
        if hasattr(branch[0], "get_time_gradient_weight"):
            w, bias = branch[0].get_time_gradient_weight()
        else:
            w = conv.weight
            if conv.bias is not None:
                bias = conv.bias
            else:
                bias = torch.zeros_like(bn.running_mean)
        mean = bn.running_mean
        var_sqrt = torch.sqrt(bn.running_var + bn.eps)
        gamma = bn.weight
        beta = bn.bias
        w = w * (gamma / var_sqrt).reshape(-1, 1, 1, 1, 1)
        bias = (bias - mean) / var_sqrt * gamma + beta
        return w, bias


# ---------------------------------------------------------------------------
# TDC-Guided Spatio-Temporal Attention (Figure 2)
# ---------------------------------------------------------------------------
class WindowAttention3D(nn.Module):
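    """3D window attention with relative position bias.

    Behaves as self-attention when only ``x`` is given (q, k, v come from the
    internal qkv projection) and as cross-attention when external ``k`` and
    ``v`` tensors are passed alongside the query ``x``.
    """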
    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (T, H, W)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(
                (2 * window_size[0] - 1)
                * (2 * window_size[1] - 1)
                * (2 * window_size[2] - 1),
                num_heads,
            )
        )
        coords_t = torch.arange(self.window_size[0])
        coords_h = torch.arange(self.window_size[1])
        coords_w = torch.arange(self.window_size[2])
        coords = torch.stack(torch.meshgrid(coords_t, coords_h, coords_w, indexing="ij"))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 2] += self.window_size[2] - 1
        relative_coords[:, :, 0] *= (
            2 * self.window_size[1] - 1
        ) * (2 * self.window_size[2] - 1)
        relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x, k=None, v=None, mask=None):
        B_, N, C = x.shape
        if k is None or v is None:
            qkv = (
                self.qkv(x)
                .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
                .permute(2, 0, 3, 1, 4)
            )
            q, k, v = qkv[0], qkv[1], qkv[2]
        else:
            q = x.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
            k = k.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
            v = v.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index[:N, :N].reshape(-1)
        ].reshape(N, N, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


def window_partition(x, window_size):
    B, T, H, W, C = x.shape
    window_size = list(window_size)
    if T < window_size[0]:
        window_size[0] = T
    if H < window_size[1]:
        window_size[1] = H
    if W < window_size[2]:
        window_size[2] = W
    x = x.view(
        B,
        T // window_size[0] if window_size[0] > 0 else 1,
        window_size[0],
        H // window_size[1] if window_size[1] > 0 else 1,
        window_size[1],
        W // window_size[2] if window_size[2] > 0 else 1,
        window_size[2],
        C,
    )
    windows = (
        x.permute(0, 1, 3, 5, 2, 4, 6, 7)
        .contiguous()
        .view(-1, reduce(mul, window_size), C)
    )
    return windows


def window_reverse(windows, window_size, B, T, H, W):
    x = windows.view(
        B,
        T // window_size[0],
        H // window_size[1],
        W // window_size[2],
        window_size[0],
        window_size[1],
        window_size[2],
        -1,
    )
    x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, T, H, W, -1)
    return x


def get_window_size(x_size, window_size, shift_size=None):
    use_window_size = list(window_size)
    if shift_size is not None:
        use_shift_size = list(shift_size)
    for i in range(len(x_size)):
        if x_size[i] <= window_size[i]:
            use_window_size[i] = x_size[i]
            if shift_size is not None:
                use_shift_size[i] = 0
    if shift_size is None:
        return tuple(use_window_size)
    else:
        return tuple(use_window_size), tuple(use_shift_size)


class SelfAttention(nn.Module):
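    """Window-based 3D self-attention block (optionally with shifted windows),
    applied to each feature stream before the TDCSTA cross-attention stage.
    Expects inputs of shape (B, T, H, W, C).
    """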
    def __init__(
        self,
        dim,
        window_size=(2, 8, 8),
        num_heads=8,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        use_shift=False,
        shift_size=None,
        mlp_ratio=2.0,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.use_shift = use_shift
        self.shift_size = (
            shift_size
            if shift_size is not None
            else tuple([w // 2 for w in window_size]) if use_shift else tuple([0] * len(window_size))
        )
        self.attn1 = WindowAttention3D(
            dim,
            window_size=self.window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )
        self.attn2 = WindowAttention3D(
            dim,
            window_size=self.window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.norm3 = norm_layer(dim)
        self.norm4 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp1 = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, dim),
        )
        self.mlp2 = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, dim),
        )

    def create_mask(self, x_shape, device):
        B, T, H, W, C = x_shape
        img_mask = torch.zeros((1, T, H, W, 1), device=device)
        cnt = 0
        t_slices = (
            slice(0, -self.window_size[0]),
            slice(-self.window_size[0], -self.shift_size[0]),
            slice(-self.shift_size[0], None),
        )
        h_slices = (
            slice(0, -self.window_size[1]),
            slice(-self.window_size[1], -self.shift_size[1]),
            slice(-self.shift_size[1], None),
        )
        w_slices = (
            slice(0, -self.window_size[2]),
            slice(-self.window_size[2], -self.shift_size[2]),
            slice(-self.shift_size[2], None),
        )
        for t in t_slices:
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, t, h, w, :] = cnt
                    cnt += 1
        mask_windows = window_partition(img_mask, self.window_size)
        mask_windows = mask_windows.squeeze(-1)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
            attn_mask == 0, float(0.0)
        )
        return attn_mask

    def forward(self, x):
        B, T, H, W, C = x.shape
        window_size, shift_size = get_window_size((T, H, W), self.window_size, self.shift_size)
        shortcut = x
        x = self.norm1(x)
        pad_t = (window_size[0] - T % window_size[0]) % window_size[0]
        pad_h = (window_size[1] - H % window_size[1]) % window_size[1]
        pad_w = (window_size[2] - W % window_size[2]) % window_size[2]
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h, 0, pad_t))
        shortcut = F.pad(shortcut, (0, 0, 0, pad_w, 0, pad_h, 0, pad_t))
        _, Tp, Hp, Wp, _ = x.shape
        x_windows = window_partition(x, window_size)
        attn_windows = self.attn1(x_windows, mask=None)
        attn_windows = attn_windows.view(-1, *(window_size + (C,)))
        x = window_reverse(attn_windows, window_size, B, Tp, Hp, Wp)
        x = shortcut + x
        x = x + self.mlp1(self.norm2(x))
        shortcut = x
        x = self.norm3(x)
        if self.use_shift and any(i > 0 for i in shift_size):
            shifted_x = torch.roll(
                x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3)
            )
            attn_mask = self.create_mask((B, Tp, Hp, Wp, C), x.device)
            x_windows = window_partition(shifted_x, window_size)
            attn_windows = self.attn2(x_windows, mask=attn_mask)
            attn_windows = attn_windows.view(-1, *(window_size + (C,)))
            shifted_x = window_reverse(attn_windows, window_size, B, Tp, Hp, Wp)
            x = torch.roll(
                shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3)
            )
        if pad_t > 0:
            x = x[:, :T, :, :, :]
            shortcut = shortcut[:, :T, :, :, :]
        if pad_h > 0:
            x = x[:, :, :H, :, :]
            shortcut = shortcut[:, :, :H, :, :]
        if pad_w > 0:
            x = x[:, :, :, :W, :]
            shortcut = shortcut[:, :, :, :W, :]

        x = shortcut + x
        x = x + self.mlp2(self.norm4(x))
        return x


class CrossAttention(nn.Module):
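    """Windowed 3D cross-attention used by TDCSTA: in TDCNet the TDC features
    act as the Query, the 3D-backbone features as the Key, and the 2D-backbone
    features as the Value. ``q``, ``k``, ``v`` are (B, T, H, W, C) tensors of
    identical shape.
    """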
    def __init__(
        self,
        dim,
        window_size=(2, 8, 8),
        num_heads=8,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        mlp_ratio=2.0,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.norm1_q = norm_layer(dim)
        self.norm1_k = norm_layer(dim)
        self.norm1_v = norm_layer(dim)
        self.attn = WindowAttention3D(
            dim,
            window_size=self.window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, dim),
        )

    def forward(self, q, k, v):
        B, T, H, W, C = q.shape
        window_size = get_window_size((T, H, W), self.window_size)
        shortcut = v
        q = self.norm1_q(q)
        k = self.norm1_k(k)
        v = self.norm1_v(v)
        pad_t = (window_size[0] - T % window_size[0]) % window_size[0]
        pad_h = (window_size[1] - H % window_size[1]) % window_size[1]
        pad_w = (window_size[2] - W % window_size[2]) % window_size[2]
        q = F.pad(q, (0, 0, 0, pad_w, 0, pad_h, 0, pad_t))
        k = F.pad(k, (0, 0, 0, pad_w, 0, pad_h, 0, pad_t))
        v = F.pad(v, (0, 0, 0, pad_w, 0, pad_h, 0, pad_t))
        _, Tp, Hp, Wp, _ = q.shape

        q_windows = window_partition(q, window_size)
        k_windows = window_partition(k, window_size)
        v_windows = window_partition(v, window_size)
        attn_windows = self.attn(q_windows, k_windows, v_windows)
        attn_windows = attn_windows.view(-1, *(window_size + (C,)))
        shifted_x = window_reverse(attn_windows, window_size, B, Tp, Hp, Wp)
        x = shifted_x
        if pad_t > 0:
            x = x[:, :T, :, :, :]
        if pad_h > 0:
            x = x[:, :, :H, :, :]
        if pad_w > 0:
            x = x[:, :, :, :W, :]
        x = shortcut + x
        x = x + self.mlp(self.norm2(x))
        return x


# ---------------------------------------------------------------------------
# Simple benchmark entry point
# ---------------------------------------------------------------------------
def _benchmark_module(name, module, *inputs, warmup=5, iters=20):
    module.eval()
    with torch.no_grad():
        for _ in range(warmup):
            module(*inputs)
        start = time.time()
        for _ in range(iters):
            module(*inputs)
        end = time.time()
    avg = (end - start) / iters * 1000
    print(f"{name}: {avg:.3f} ms/iter over shape {[tuple(i.shape) for i in inputs]}")


if __name__ == "__main__":
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dummy_video = torch.randn(1, 64, 5, 32, 32).to(device)
    rep = RepConv3D(64, 128).to(device)
    _benchmark_module("RepConv3D", rep, dummy_video)

    # Attention modules expect [B, T, H, W, C]
    dummy_feat = torch.randn(1, 4, 16, 16, 128).to(device)
    sa = SelfAttention(128, window_size=(2, 4, 4), num_heads=4, use_shift=True).to(device)
    _benchmark_module("SelfAttention", sa, dummy_feat)

    ca = CrossAttention(128, window_size=(2, 4, 4), num_heads=4).to(device)
    _benchmark_module("CrossAttention", ca, dummy_feat, dummy_feat, dummy_feat)

