VX: shixiaodayyds,备注【即插即用】,添加即插即用模块交流群。


模块出处

在这里插入图片描述

Paper:Mesoscopic Insights: Orchestrating Multi-scale & Hybrid Architecture for Image
Manipulation Localization

Code:https://github.com/scu-zjz/Mesorch

模块介绍

在这里插入图片描述
Mesoscopic-Orchestration(Mesorch):

  • 将 Transformer 和 CNN 并行结合,Transformers 提取宏观信息和 CNN 捕获微观细节。
  • 在不同尺度上探索,无缝评估微观和宏观信息。

现有频率域方法(如直接 FFT/DCT 后丢弃高频)存在不足:

  • 固定频率划分:手动选择频率阈值(如仅保留前 10% 低频),无法自适应不同任务需求(如分割需更多高频,分类需更多低频);
  • 缺乏归一化:频率域特征数值范围波动大(从 - 1e3 到 1e3),直接输入后续模块易导致梯度爆炸;
  • 无动态适配:DCT 矩阵固定为单一尺寸,无法适配不同输入分辨率(如 32×32、64×64 特征图)。

Mesorch 通过三大创新解决上述问题:

  1. 自适应频率划分:通过alpha参数灵活控制高低频比例(如高频alpha=0.05保留 95% 高频,低频alpha=0.95保留 95% 低频),适配不同任务;
  2. 动态 DCT 矩阵:根据输入特征尺寸(H/W)实时生成对应大小的 DCT 矩阵,支持任意正方形特征图;
  3. 全局归一化:将逆 DCT 后的特征映射到 0~1,避免数值波动,提升后续模块稳定性。

模块提出的动机(Motivation)

介观水平是宏观世界和微观世界之间的桥梁,解决了两者忽略的差距。图像处理定位(IML)是一种从假图像中追求真理的关键技术,长期以来一直依赖于低级(微观层面的)痕迹。然而,在实践中,大多数篡改旨在通过改变图像语义来欺骗受众。因此,操作通常发生在对象级别(宏观级别),这与微观跟踪同样重要。因此,将这两个级别集成到介观级别为 IML 研究提供了新的视角。

因此,引入Mesorach架构来协调两者,解决传统视觉模型在 “细节捕捉” 与 “全局建模” 上的矛盾。

适用范围与模块效果

适用范围:适用于通用视觉领域,特别是需要高低频分离,频率差距明显的数据集上的任务。

模块优劣在这里插入图片描述

缝合位置:需要特征提取的位置。

模块效果:复杂度更低,架构更合理,性能更优。
在这里插入图片描述
在这里插入图片描述
消融:CNN对应高频DCT,Transformer对应低频DCT更优。这与CNN和Transformer本身的特性对应。在这里插入图片描述

模块代码及使用方式

代码与模块结构图对应关系:

在这里插入图片描述

模块代码(详细注释与特征流前向传播过程中的维度变化):

import torch
import torch.nn as nn
import torch.fft
import math

class HighDctFrequencyExtractor(nn.Module):
    """
    高频DCT频率提取模块:通过二维离散余弦变换(DCT)将特征从空间域转换到频率域,
    保留高频分量并抑制低频分量,最终将高频特征映射回空间域并归一化
    核心作用:捕捉特征中的细节信息(如边缘、纹理),适配需要精细特征的任务
    """

    def __init__(self, alpha=0.05):
        """
        Args:
            alpha (float): 低频抑制比例(0<alpha<1),默认0.05表示抑制前5%的低频区域
                           alpha越小,保留的高频区域越大;alpha越大,仅保留极高频率
        """
        super(HighDctFrequencyExtractor, self).__init__()
        # 校验alpha有效性(必须在0~1之间,否则无法正确划分高低频)
        if alpha <= 0 or alpha >= 1:
            raise ValueError("alpha must be between 0 and 1 (exclusive)")
        self.alpha = alpha
        # 初始化DCT变换矩阵(高度和宽度方向),None表示动态生成(适配不同输入尺寸)
        self.dct_matrix_h = None  # 高度方向DCT矩阵
        self.dct_matrix_w = None  # 宽度方向DCT矩阵

    def create_dct_matrix(self, N):
        """
        生成N×N的二维DCT变换矩阵(基于DCT-II标准公式)
        Args:
            N (int): DCT矩阵尺寸(对应特征图的高度或宽度)
        Returns:
            torch.Tensor: N×N的DCT矩阵,形状 [N, N]
        """
        # 生成索引矩阵:n为列索引(1×N),k为行索引(N×1)
        n = torch.arange(N, dtype=torch.float32).reshape((1, N))
        k = torch.arange(N, dtype=torch.float32).reshape((N, 1))
        # DCT-II公式:dct_matrix[k][n] = sqrt(2/N) * cos(π*k*(2n+1)/(2N))
        # 首行特殊处理:dct_matrix[0][n] = 1/sqrt(N)(直流分量归一化)
        dct_matrix = torch.sqrt(torch.tensor(2.0 / N)) * torch.cos(math.pi * k * (2 * n + 1) / (2 * N))
        dct_matrix[0, :] = 1 / math.sqrt(N)  # 直流分量行(k=0)单独归一化
        return dct_matrix

    def dct_2d(self, x):
        """
        二维DCT变换:将输入特征从空间域转换到频率域
        Args:
            x (torch.Tensor): 输入特征,形状 [B, C, H, W](B=批量,C=通道,H=高度,W=宽度)
        Returns:
            torch.Tensor: DCT变换后的频率域特征,形状与输入一致 [B, C, H, W]
        """
        H, W = x.size(-2), x.size(-1)  # 提取特征图的高度和宽度
        # 动态生成/更新DCT矩阵(若输入尺寸变化,重新生成对应尺寸的矩阵)
        if self.dct_matrix_h is None or self.dct_matrix_h.size(0) != H:
            self.dct_matrix_h = self.create_dct_matrix(H).to(x.device)
        if self.dct_matrix_w is None or self.dct_matrix_w.size(0) != W:
            self.dct_matrix_w = self.create_dct_matrix(W).to(x.device)

        # 二维DCT计算:x_dct = DCT_h × x × DCT_w^T(矩阵乘法实现空间域→频率域)
        # 先对宽度方向做DCT(x × DCT_w^T),再对高度方向做DCT(DCT_h × 结果)
        return torch.matmul(self.dct_matrix_h, torch.matmul(x, self.dct_matrix_w.t()))

    def idct_2d(self, x):
        """
        二维逆DCT变换:将频率域特征映射回空间域
        Args:
            x (torch.Tensor): 频率域特征,形状 [B, C, H, W]
        Returns:
            torch.Tensor: 逆DCT后的空间域特征,形状与输入一致 [B, C, H, W]
        """
        H, W = x.size(-2), x.size(-1)
        # 动态生成/更新DCT矩阵(与dct_2d共享矩阵,避免重复计算)
        if self.dct_matrix_h is None or self.dct_matrix_h.size(0) != H:
            self.dct_matrix_h = self.create_dct_matrix(H).to(x.device)
        if self.dct_matrix_w is None or self.dct_matrix_w.size(0) != W:
            self.dct_matrix_w = self.create_dct_matrix(W).to(x.device)

        # 二维逆DCT计算:x_idct = DCT_h^T × x × DCT_w(DCT矩阵正交性,逆变换=转置)
        return torch.matmul(self.dct_matrix_h.t(), torch.matmul(x, self.dct_matrix_w))

    def high_pass_filter(self, x, alpha):
        """
        高频滤波:生成频率域掩码,抑制低频区域,保留高频区域
        Args:
            x (torch.Tensor): 频率域特征,形状 [B, C, H, W]
            alpha (float): 低频抑制比例(与__init__中的alpha一致)
        Returns:
            torch.Tensor: 滤波后的频率域特征(低频被置0,高频保留)
        """
        h, w = x.shape[-2:]  # 频率域特征的高度和宽度
        # 生成全1掩码(初始保留所有频率)
        mask = torch.ones(h, w, device=x.device)
        # 计算低频区域大小:前alpha*h行、前alpha*w列(频率域左上角为低频,右下角为高频)
        alpha_h, alpha_w = int(alpha * h), int(alpha * w)
        mask[:alpha_h, :alpha_w] = 0  # 抑制低频区域(置0)
        return x * mask  # 频率域特征与掩码相乘,保留高频

    def forward(self, x):
        """
        前向传播流程:空间域→DCT→高频滤波→逆DCT→空间域归一化→输出
        Args:
            x (torch.Tensor): 输入空间域特征,形状 [B, C, H, W]
        Returns:
            torch.Tensor: 输出高频空间域特征,形状与输入一致 [B, C, H, W],值已归一化到0~1
        """
        # 步骤1:空间域→频率域(DCT变换)
        xq = self.dct_2d(x)
        # 步骤2:频率域滤波(抑制低频,保留高频)
        xq_high = self.high_pass_filter(xq, self.alpha)
        # 步骤3:频率域→空间域(逆DCT变换)
        xh = self.idct_2d(xq_high)

        # 步骤4:特征归一化(将高频特征映射到0~1,避免数值波动影响后续模块)
        B = xh.shape[0]  # 批量大小
        # 计算每个样本的全局最小值和最大值(按通道和空间维度展平)
        min_vals = xh.reshape(B, -1).min(dim=1, keepdim=True).values.view(B, 1, 1, 1)
        max_vals = xh.reshape(B, -1).max(dim=1, keepdim=True).values.view(B, 1, 1, 1)
        #  min-max归一化:(x - min)/(max - min)
        xh = (xh - min_vals) / (max_vals - min_vals)

        return xh


class LowDctFrequencyExtractor(nn.Module):
    """
    低频DCT频率提取模块:通过二维DCT将特征转换到频率域,保留低频分量并抑制高频分量,
    最终映射回空间域并归一化
    核心作用:捕捉特征中的全局趋势(如整体轮廓、亮度分布),适配需要全局信息的任务
    """

    def __init__(self, alpha=0.95):
        """
        Args:
            alpha (float): 高频抑制比例(0<alpha<1),默认0.95表示保留前95%的低频区域
                           alpha越大,保留的低频区域越大;alpha越小,仅保留极低频率
        """
        super(LowDctFrequencyExtractor, self).__init__()
        if alpha <= 0 or alpha >= 1:
            raise ValueError("alpha must be between 0 and 1 (exclusive)")
        self.alpha = alpha
        self.dct_matrix_h = None  # 高度方向DCT矩阵(动态生成)
        self.dct_matrix_w = None  # 宽度方向DCT矩阵(动态生成)

    def create_dct_matrix(self, N):
        """生成N×N的DCT矩阵,与HighDctFrequencyExtractor完全一致,复用DCT-II公式"""
        n = torch.arange(N, dtype=torch.float32).reshape((1, N))
        k = torch.arange(N, dtype=torch.float32).reshape((N, 1))
        dct_matrix = torch.sqrt(torch.tensor(2.0 / N)) * torch.cos(math.pi * k * (2 * n + 1) / (2 * N))
        dct_matrix[0, :] = 1 / math.sqrt(N)
        return dct_matrix

    def dct_2d(self, x):
        """二维DCT变换,与HighDctFrequencyExtractor完全一致,空间域→频率域"""
        H, W = x.size(-2), x.size(-1)
        if self.dct_matrix_h is None or self.dct_matrix_h.size(0) != H:
            self.dct_matrix_h = self.create_dct_matrix(H).to(x.device)
        if self.dct_matrix_w is None or self.dct_matrix_w.size(0) != W:
            self.dct_matrix_w = self.create_dct_matrix(W).to(x.device)
        return torch.matmul(self.dct_matrix_h, torch.matmul(x, self.dct_matrix_w.t()))

    def idct_2d(self, x):
        """二维逆DCT变换,与HighDctFrequencyExtractor完全一致,频率域→空间域"""
        H, W = x.size(-2), x.size(-1)
        if self.dct_matrix_h is None or self.dct_matrix_h.size(0) != H:
            self.dct_matrix_h = self.create_dct_matrix(H).to(x.device)
        if self.dct_matrix_w is None or self.dct_matrix_w.size(0) != W:
            self.dct_matrix_w = self.create_dct_matrix(W).to(x.device)
        return torch.matmul(self.dct_matrix_h.t(), torch.matmul(x, self.dct_matrix_w))

    def high_pass_filter(self, x, alpha):
        """
        低频滤波:生成频率域掩码,抑制高频区域,保留低频区域
        与HighDctFrequencyExtractor的滤波逻辑相反:抑制右下角高频,保留左上角低频
        Args:
            x (torch.Tensor): 频率域特征,形状 [B, C, H, W]
            alpha (float): 高频抑制比例(与__init__中的alpha一致)
        Returns:
            torch.Tensor: 滤波后的频率域特征(高频被置0,低频保留)
        """
        h, w = x.shape[-2:]
        mask = torch.ones(h, w, device=x.device)
        # 计算高频区域大小:后(1-alpha)*h行、后(1-alpha)*w列(频率域右下角为高频)
        alpha_h, alpha_w = int(alpha * h), int(alpha * w)
        mask[-alpha_h:, -alpha_w:] = 0  # 抑制高频区域(置0)
        return x * mask

    def forward(self, x):
        """
        前向传播流程:空间域→DCT→低频滤波→逆DCT→空间域归一化→输出
        Args:
            x (torch.Tensor): 输入空间域特征,形状 [B, C, H, W]
        Returns:
            torch.Tensor: 输出低频空间域特征,形状与输入一致 [B, C, H, W],值已归一化到0~1
        """
        # 步骤1:空间域→频率域(DCT变换)
        xq = self.dct_2d(x)
        # 步骤2:频率域滤波(抑制高频,保留低频)
        xq_high = self.high_pass_filter(xq, self.alpha)
        # 步骤3:频率域→空间域(逆DCT变换)
        xh = self.idct_2d(xq_high)

        # 步骤4:特征归一化(与HighDctFrequencyExtractor一致,映射到0~1)
        B = xh.shape[0]
        min_vals = xh.reshape(B, -1).min(dim=1, keepdim=True).values.view(B, 1, 1, 1)
        max_vals = xh.reshape(B, -1).max(dim=1, keepdim=True).values.view(B, 1, 1, 1)
        xh = (xh - min_vals) / (max_vals - min_vals)

        return xh

if __name__ == "__main__":
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    x = torch.randn(1, 64, 32, 32).to(device)
    model1 = HighDctFrequencyExtractor()
    model2 = LowDctFrequencyExtractor()

    model1.to(device)
    model2.to(device)

    y1 = model1(x)
    y2 = model2(x)

    print("微信公众号:十小大的底层视觉工坊")
    print("知乎、CSDN:十小大")

    print("输入特征维度:", x.shape)
    print("高频输出特征维度:", y1.shape)
    print("低频输出特征维度:", y2.shape)

运行结果:

在这里插入图片描述


至此本文结束。

如果本文对你有所帮助,请点赞收藏,创作不易,感谢您的支持!

点击下方👇公众号区域,扫码关注,可免费领取一份200+即插即用模块资料

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐