【即插即用模块】Transformer篇 | AAAI 2025 | Mesorch:自适应DCT调控,高低频精准分离,特征完美捕获,涨点起飞!
【摘要】Mesorch模块创新性结合Transformer与CNN,通过自适应频率划分(高频alpha=0.05保留95%高频)、动态DCT矩阵和全局归一化,解决传统频率域方法的三大痛点。该模块在图像篡改定位任务中表现优异,消融实验显示CNN+高频DCT、Transformer+低频DCT的组合最有效。核心代码实现二维DCT变换与逆变换,支持动态尺寸输入,特征流维度保持[B,C,H,W]不变。模块
VX: shixiaodayyds,备注【即插即用】,添加即插即用模块交流群。
模块出处

Code:https://github.com/scu-zjz/Mesorch
模块介绍

Mesoscopic-Orchestration(Mesorch):
- 将 Transformer 和 CNN 并行结合,Transformers 提取宏观信息和 CNN 捕获微观细节。
- 在不同尺度上探索,无缝评估微观和宏观信息。
现有频率域方法(如直接 FFT/DCT 后丢弃高频)存在不足:
- 固定频率划分:手动选择频率阈值(如仅保留前 10% 低频),无法自适应不同任务需求(如分割需更多高频,分类需更多低频);
- 缺乏归一化:频率域特征数值范围波动大(从 - 1e3 到 1e3),直接输入后续模块易导致梯度爆炸;
- 无动态适配:DCT 矩阵固定为单一尺寸,无法适配不同输入分辨率(如 32×32、64×64 特征图)。
Mesorch 通过三大创新解决上述问题:
- 自适应频率划分:通过alpha参数灵活控制高低频比例(如高频alpha=0.05保留 95% 高频,低频alpha=0.95保留 95% 低频),适配不同任务;
- 动态 DCT 矩阵:根据输入特征尺寸(H/W)实时生成对应大小的 DCT 矩阵,支持任意正方形特征图;
- 全局归一化:将逆 DCT 后的特征映射到 0~1,避免数值波动,提升后续模块稳定性。
模块提出的动机(Motivation)
介观水平是宏观世界和微观世界之间的桥梁,解决了两者忽略的差距。图像处理定位(IML)是一种从假图像中追求真理的关键技术,长期以来一直依赖于低级(微观层面的)痕迹。然而,在实践中,大多数篡改旨在通过改变图像语义来欺骗受众。因此,操作通常发生在对象级别(宏观级别),这与微观跟踪同样重要。因此,将这两个级别集成到介观级别为 IML 研究提供了新的视角。
因此,引入Mesorach架构来协调两者,解决传统视觉模型在 “细节捕捉” 与 “全局建模” 上的矛盾。
适用范围与模块效果
适用范围:适用于通用视觉领域,特别是需要高低频分离,频率差距明显的数据集上的任务。
模块优劣:
缝合位置:需要特征提取的位置。
模块效果:复杂度更低,架构更合理,性能更优。

消融:CNN对应高频DCT,Transformer对应低频DCT更优。这与CNN和Transformer本身的特性对应。
模块代码及使用方式
代码与模块结构图对应关系:

模块代码(详细注释与特征流前向传播过程中的维度变化):
import torch
import torch.nn as nn
import torch.fft
import math
class HighDctFrequencyExtractor(nn.Module):
"""
高频DCT频率提取模块:通过二维离散余弦变换(DCT)将特征从空间域转换到频率域,
保留高频分量并抑制低频分量,最终将高频特征映射回空间域并归一化
核心作用:捕捉特征中的细节信息(如边缘、纹理),适配需要精细特征的任务
"""
def __init__(self, alpha=0.05):
"""
Args:
alpha (float): 低频抑制比例(0<alpha<1),默认0.05表示抑制前5%的低频区域
alpha越小,保留的高频区域越大;alpha越大,仅保留极高频率
"""
super(HighDctFrequencyExtractor, self).__init__()
# 校验alpha有效性(必须在0~1之间,否则无法正确划分高低频)
if alpha <= 0 or alpha >= 1:
raise ValueError("alpha must be between 0 and 1 (exclusive)")
self.alpha = alpha
# 初始化DCT变换矩阵(高度和宽度方向),None表示动态生成(适配不同输入尺寸)
self.dct_matrix_h = None # 高度方向DCT矩阵
self.dct_matrix_w = None # 宽度方向DCT矩阵
def create_dct_matrix(self, N):
"""
生成N×N的二维DCT变换矩阵(基于DCT-II标准公式)
Args:
N (int): DCT矩阵尺寸(对应特征图的高度或宽度)
Returns:
torch.Tensor: N×N的DCT矩阵,形状 [N, N]
"""
# 生成索引矩阵:n为列索引(1×N),k为行索引(N×1)
n = torch.arange(N, dtype=torch.float32).reshape((1, N))
k = torch.arange(N, dtype=torch.float32).reshape((N, 1))
# DCT-II公式:dct_matrix[k][n] = sqrt(2/N) * cos(π*k*(2n+1)/(2N))
# 首行特殊处理:dct_matrix[0][n] = 1/sqrt(N)(直流分量归一化)
dct_matrix = torch.sqrt(torch.tensor(2.0 / N)) * torch.cos(math.pi * k * (2 * n + 1) / (2 * N))
dct_matrix[0, :] = 1 / math.sqrt(N) # 直流分量行(k=0)单独归一化
return dct_matrix
def dct_2d(self, x):
"""
二维DCT变换:将输入特征从空间域转换到频率域
Args:
x (torch.Tensor): 输入特征,形状 [B, C, H, W](B=批量,C=通道,H=高度,W=宽度)
Returns:
torch.Tensor: DCT变换后的频率域特征,形状与输入一致 [B, C, H, W]
"""
H, W = x.size(-2), x.size(-1) # 提取特征图的高度和宽度
# 动态生成/更新DCT矩阵(若输入尺寸变化,重新生成对应尺寸的矩阵)
if self.dct_matrix_h is None or self.dct_matrix_h.size(0) != H:
self.dct_matrix_h = self.create_dct_matrix(H).to(x.device)
if self.dct_matrix_w is None or self.dct_matrix_w.size(0) != W:
self.dct_matrix_w = self.create_dct_matrix(W).to(x.device)
# 二维DCT计算:x_dct = DCT_h × x × DCT_w^T(矩阵乘法实现空间域→频率域)
# 先对宽度方向做DCT(x × DCT_w^T),再对高度方向做DCT(DCT_h × 结果)
return torch.matmul(self.dct_matrix_h, torch.matmul(x, self.dct_matrix_w.t()))
def idct_2d(self, x):
"""
二维逆DCT变换:将频率域特征映射回空间域
Args:
x (torch.Tensor): 频率域特征,形状 [B, C, H, W]
Returns:
torch.Tensor: 逆DCT后的空间域特征,形状与输入一致 [B, C, H, W]
"""
H, W = x.size(-2), x.size(-1)
# 动态生成/更新DCT矩阵(与dct_2d共享矩阵,避免重复计算)
if self.dct_matrix_h is None or self.dct_matrix_h.size(0) != H:
self.dct_matrix_h = self.create_dct_matrix(H).to(x.device)
if self.dct_matrix_w is None or self.dct_matrix_w.size(0) != W:
self.dct_matrix_w = self.create_dct_matrix(W).to(x.device)
# 二维逆DCT计算:x_idct = DCT_h^T × x × DCT_w(DCT矩阵正交性,逆变换=转置)
return torch.matmul(self.dct_matrix_h.t(), torch.matmul(x, self.dct_matrix_w))
def high_pass_filter(self, x, alpha):
"""
高频滤波:生成频率域掩码,抑制低频区域,保留高频区域
Args:
x (torch.Tensor): 频率域特征,形状 [B, C, H, W]
alpha (float): 低频抑制比例(与__init__中的alpha一致)
Returns:
torch.Tensor: 滤波后的频率域特征(低频被置0,高频保留)
"""
h, w = x.shape[-2:] # 频率域特征的高度和宽度
# 生成全1掩码(初始保留所有频率)
mask = torch.ones(h, w, device=x.device)
# 计算低频区域大小:前alpha*h行、前alpha*w列(频率域左上角为低频,右下角为高频)
alpha_h, alpha_w = int(alpha * h), int(alpha * w)
mask[:alpha_h, :alpha_w] = 0 # 抑制低频区域(置0)
return x * mask # 频率域特征与掩码相乘,保留高频
def forward(self, x):
"""
前向传播流程:空间域→DCT→高频滤波→逆DCT→空间域归一化→输出
Args:
x (torch.Tensor): 输入空间域特征,形状 [B, C, H, W]
Returns:
torch.Tensor: 输出高频空间域特征,形状与输入一致 [B, C, H, W],值已归一化到0~1
"""
# 步骤1:空间域→频率域(DCT变换)
xq = self.dct_2d(x)
# 步骤2:频率域滤波(抑制低频,保留高频)
xq_high = self.high_pass_filter(xq, self.alpha)
# 步骤3:频率域→空间域(逆DCT变换)
xh = self.idct_2d(xq_high)
# 步骤4:特征归一化(将高频特征映射到0~1,避免数值波动影响后续模块)
B = xh.shape[0] # 批量大小
# 计算每个样本的全局最小值和最大值(按通道和空间维度展平)
min_vals = xh.reshape(B, -1).min(dim=1, keepdim=True).values.view(B, 1, 1, 1)
max_vals = xh.reshape(B, -1).max(dim=1, keepdim=True).values.view(B, 1, 1, 1)
# min-max归一化:(x - min)/(max - min)
xh = (xh - min_vals) / (max_vals - min_vals)
return xh
class LowDctFrequencyExtractor(nn.Module):
"""
低频DCT频率提取模块:通过二维DCT将特征转换到频率域,保留低频分量并抑制高频分量,
最终映射回空间域并归一化
核心作用:捕捉特征中的全局趋势(如整体轮廓、亮度分布),适配需要全局信息的任务
"""
def __init__(self, alpha=0.95):
"""
Args:
alpha (float): 高频抑制比例(0<alpha<1),默认0.95表示保留前95%的低频区域
alpha越大,保留的低频区域越大;alpha越小,仅保留极低频率
"""
super(LowDctFrequencyExtractor, self).__init__()
if alpha <= 0 or alpha >= 1:
raise ValueError("alpha must be between 0 and 1 (exclusive)")
self.alpha = alpha
self.dct_matrix_h = None # 高度方向DCT矩阵(动态生成)
self.dct_matrix_w = None # 宽度方向DCT矩阵(动态生成)
def create_dct_matrix(self, N):
"""生成N×N的DCT矩阵,与HighDctFrequencyExtractor完全一致,复用DCT-II公式"""
n = torch.arange(N, dtype=torch.float32).reshape((1, N))
k = torch.arange(N, dtype=torch.float32).reshape((N, 1))
dct_matrix = torch.sqrt(torch.tensor(2.0 / N)) * torch.cos(math.pi * k * (2 * n + 1) / (2 * N))
dct_matrix[0, :] = 1 / math.sqrt(N)
return dct_matrix
def dct_2d(self, x):
"""二维DCT变换,与HighDctFrequencyExtractor完全一致,空间域→频率域"""
H, W = x.size(-2), x.size(-1)
if self.dct_matrix_h is None or self.dct_matrix_h.size(0) != H:
self.dct_matrix_h = self.create_dct_matrix(H).to(x.device)
if self.dct_matrix_w is None or self.dct_matrix_w.size(0) != W:
self.dct_matrix_w = self.create_dct_matrix(W).to(x.device)
return torch.matmul(self.dct_matrix_h, torch.matmul(x, self.dct_matrix_w.t()))
def idct_2d(self, x):
"""二维逆DCT变换,与HighDctFrequencyExtractor完全一致,频率域→空间域"""
H, W = x.size(-2), x.size(-1)
if self.dct_matrix_h is None or self.dct_matrix_h.size(0) != H:
self.dct_matrix_h = self.create_dct_matrix(H).to(x.device)
if self.dct_matrix_w is None or self.dct_matrix_w.size(0) != W:
self.dct_matrix_w = self.create_dct_matrix(W).to(x.device)
return torch.matmul(self.dct_matrix_h.t(), torch.matmul(x, self.dct_matrix_w))
def high_pass_filter(self, x, alpha):
"""
低频滤波:生成频率域掩码,抑制高频区域,保留低频区域
与HighDctFrequencyExtractor的滤波逻辑相反:抑制右下角高频,保留左上角低频
Args:
x (torch.Tensor): 频率域特征,形状 [B, C, H, W]
alpha (float): 高频抑制比例(与__init__中的alpha一致)
Returns:
torch.Tensor: 滤波后的频率域特征(高频被置0,低频保留)
"""
h, w = x.shape[-2:]
mask = torch.ones(h, w, device=x.device)
# 计算高频区域大小:后(1-alpha)*h行、后(1-alpha)*w列(频率域右下角为高频)
alpha_h, alpha_w = int(alpha * h), int(alpha * w)
mask[-alpha_h:, -alpha_w:] = 0 # 抑制高频区域(置0)
return x * mask
def forward(self, x):
"""
前向传播流程:空间域→DCT→低频滤波→逆DCT→空间域归一化→输出
Args:
x (torch.Tensor): 输入空间域特征,形状 [B, C, H, W]
Returns:
torch.Tensor: 输出低频空间域特征,形状与输入一致 [B, C, H, W],值已归一化到0~1
"""
# 步骤1:空间域→频率域(DCT变换)
xq = self.dct_2d(x)
# 步骤2:频率域滤波(抑制高频,保留低频)
xq_high = self.high_pass_filter(xq, self.alpha)
# 步骤3:频率域→空间域(逆DCT变换)
xh = self.idct_2d(xq_high)
# 步骤4:特征归一化(与HighDctFrequencyExtractor一致,映射到0~1)
B = xh.shape[0]
min_vals = xh.reshape(B, -1).min(dim=1, keepdim=True).values.view(B, 1, 1, 1)
max_vals = xh.reshape(B, -1).max(dim=1, keepdim=True).values.view(B, 1, 1, 1)
xh = (xh - min_vals) / (max_vals - min_vals)
return xh
if __name__ == "__main__":
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
x = torch.randn(1, 64, 32, 32).to(device)
model1 = HighDctFrequencyExtractor()
model2 = LowDctFrequencyExtractor()
model1.to(device)
model2.to(device)
y1 = model1(x)
y2 = model2(x)
print("微信公众号:十小大的底层视觉工坊")
print("知乎、CSDN:十小大")
print("输入特征维度:", x.shape)
print("高频输出特征维度:", y1.shape)
print("低频输出特征维度:", y2.shape)
运行结果:

至此本文结束。
如果本文对你有所帮助,请点赞收藏,创作不易,感谢您的支持!
点击下方👇公众号区域,扫码关注,可免费领取一份200+即插即用模块资料!
更多推荐



所有评论(0)