🔥 AI 即插即用 | 你的CV涨点模块“军火库”已开源!🔥

为了方便大家在CV科研和项目中高效涨点,我创建并维护了一个即插即用模块的GitHub代码仓库。

仓库里不仅有:

  • 核心模块即插即用代码
  • 论文精读总结
  • 架构图深度解析

更有海量SOTA模型的创新模块汇总,致力于打造一个“AI即插即用”的百宝箱,方便大家快速实验、组合创新!

🚀 GitHub 仓库链接https://github.com/AITricks/AITricks

觉得有帮助的话,欢迎大家 Star, Fork, PR 一键三连,共同维护!

即插即用涨点系列(十五)AAAI 2025 SOTA | ConDSeg:基于“语义信息解耦”与“对比驱动聚合”的通用医学图像分割新标杆 (含原理+代码)

论文原文 (Paper)https://arxiv.org/abs/2412.08345
官方代码 (Code)https://github.com/Mengqi-Lei/ConDSeg

ConDSeg:基于对比驱动特征增强的通用医学图像分割框架

1. 核心思想

ConDSeg 针对医学图像分割中普遍存在的“边界模糊”和“共现误导”两大挑战,提出了一种名为 对比驱动(Contrast-Driven) 的通用分割框架。该框架通过两阶段策略:第一阶段利用 一致性增强(Consistency Reinforcement, CR) 提升编码器在恶劣环境下的鲁棒性;第二阶段通过 语义信息解耦(SID)对比驱动特征聚合(CDFA) 模块,利用前景与背景的对比信息来引导特征融合。同时,引入 尺寸感知解码器(SA-Decoder) 来分别处理不同尺度的目标,从而有效解决了共现现象带来的特征混淆问题,在多个医学影像数据集上实现了 SOTA 性能。

2. 背景与动机

  • 文本角度总结
    医学图像分割是辅助临床诊断的关键,但受限于成像原理,医学图像常面临光照差、对比度低等问题,导致前景与背景之间存在“软边界”(soft boundary),难以区分。此外,医学图像中器官或病灶(如息肉)往往呈现固定的共现模式(例如大息肉旁常伴随小息肉),这种规律性容易误导模型学习到错误的共现特征,导致在单发病灶时产生虚警(预测出不存在的目标)。现有的方法虽然引入了边界监督,但并未从根本上提升模型在不确定区域的判别能力,也忽视了共现现象的干扰。

  • 动机图解分析

    • 图 1(Figure 1):挑战展示

      • 左图(模糊边界):自然图像(上)中狗与草地的边界清晰;而医学图像(中、下)中,息肉或病灶与正常组织的边界模糊不清,且受到低光照和低对比度的严重影响,这就是“软边界”问题。
      • 右图(共现挑战):展示了共现(Co-occurrence)和单发(Single occurrence)两种情况。在共现图中,大小病灶同时出现;但在单发图中,模型(如 TGANet)受共现先验误导,在仅有一个病灶时错误地预测出了额外的虚假病灶。ConDSeg 的动机正是为了解决这种因过度依赖上下文共现而导致的误判。
    • 图 3(Figure 3):Grad-CAM 可视化对比

      • TGANet(第三列):在“单发”情况下,TGANet 的热力图不仅覆盖了真实的病灶,还在周围空白区域产生了高响应,这证实了现有模型容易受共现特征误导。
      • ConDSeg(第四列):ConDSeg 的热力图精准地集中在真实病灶上,没有被共现模式误导,证明了其解决共现问题的有效性。

3. 主要贡献点

  • [贡献点 1]:提出了一致性增强(CR)训练策略
    针对光照差、对比度低的问题,设计了一种预训练策略。通过强制编码器对原始图像和强增强(改变光照、颜色等)后的图像输出一致的预测,迫使编码器学习对环境变化鲁棒的高质量特征,从源头提升了特征提取能力。

  • [贡献点 2]:设计了语义信息解耦(SID)与对比驱动特征聚合(CDFA)模块
    SID 将深层特征解耦为前景、背景和不确定区域三部分,并通过专门的损失函数逐步压缩不确定区域。CDFA 则利用解耦出的前景和背景特征作为“对比线索”,指导浅层特征的融合与增强,使得模型能更敏锐地分辨边界。

  • [贡献点 3]:引入了尺寸感知解码器(SA-Decoder)
    针对共现问题,设计了多路解码器结构。不同尺度的解码器分别负责预测小、中、大目标,利用不同层级的特征(浅层细节对应小目标,深层语义对应大目标),从而避免了模型因混淆不同尺寸目标而产生的错误共现联想。

4. 方法细节(最重要)

  • 整体网络架构(对应 Figure 2)

    在这里插入图片描述

    • 阶段一(Stage I):预训练阶段。输入图像经过强数据增强,与原图一起送入共享权重的 Encoder。通过计算两者预测掩码的一致性损失( L c o n s \mathcal{L}_{cons} Lcons),训练 Encoder 对恶劣环境的鲁棒性。
    • 阶段二(Stage II):正式分割阶段。
      1. 编码器:ResNet-50 提取四层特征 f 1 , f 2 , f 3 , f 4 f_1, f_2, f_3, f_4 f1,f2,f3,f4
      2. SID 模块:最深层特征 f 4 f_4 f4 进入 SID,解耦出前景 f f g f_{fg} ffg、背景 f b g f_{bg} fbg 和不确定 f u c f_{uc} fuc 特征。
      3. CDFA 模块:特征 f f g f_{fg} ffg f b g f_{bg} fbg 被送入各级 CDFA 模块。CDFA 接收上一级的输出和当前级的 Encoder 特征,利用前景背景的对比信息进行加权融合,输出增强特征 F ~ 1 , F ~ 2 , F ~ 3 , F ~ 4 \tilde{F}_1, \tilde{F}_2, \tilde{F}_3, \tilde{F}_4 F~1,F~2,F~3,F~4
      4. SA-Decoder:增强特征被分配给三个并行的解码器(Small, Medium, Large)。例如, F ~ 1 , F ~ 2 \tilde{F}_1, \tilde{F}_2 F~1,F~2 喂给 Decoder_Small, F ~ 3 , F ~ 4 \tilde{F}_3, \tilde{F}_4 F~3,F~4 喂给 Decoder_Large。
      5. 输出:三个解码器的预测结果拼接后通过 Sigmoid 生成最终掩码。
  • 核心创新模块详解

    • 模块 A:语义信息解耦 (SID) 模块

      • 内部结构:包含三个并行分支,分别对应前景、背景、不确定区域。
      • 数据流:输入深层特征 f 4 f_4 f4,经过 3 × 3 3\times3 3×3 卷积和 1 × 1 1\times1 1×1 卷积后,分别生成特征图 f f g , f b g , f u c f_{fg}, f_{bg}, f_{uc} ffg,fbg,fuc
      • 辅助监督:这些特征图通过辅助头(Auxiliary Head)生成对应的概率掩码 M f g , M b g , M u c M^{fg}, M^{bg}, M^{uc} Mfg,Mbg,Muc
      • 设计目的:通过互补损失 L c o m p l \mathcal{L}_{compl} Lcompl 强制三个掩码之和为 1(即每个像素必须归属一类),并利用加权损失迫使 M u c M^{uc} Muc 区域最小化,从而在特征层面将模糊的边界“逼”向确定的前景或背景。
    • 模块 B:对比驱动特征聚合 (CDFA) 模块(对应 Figure 4)

      在这里插入图片描述

      • 内部结构:基于注意力机制的特征融合单元。
      • 数据流
        1. 输入:主输入是待融合特征 F F F(来自上一级或 Encoder),引导输入是 f f g f_{fg} ffg f b g f_{bg} fbg
        2. 值生成:输入 F F F 经过卷积生成 Value 向量,并按 K × K K\times K K×K 窗口展开(Unfold)。
        3. 权重生成 f f g f_{fg} ffg f b g f_{bg} fbg 分别通过线性层生成前景注意力权重 A f g A_{fg} Afg 和背景注意力权重 A b g A_{bg} Abg
        4. 双重加权:Value 向量先被 A b g A_{bg} Abg 加权(抑制背景),再被 A f g A_{fg} Afg 加权(增强前景)。公式为: V ~ = Softmax ( A f g ) ⊗ ( Softmax ( A b g ) ⊗ V ) \tilde{V} = \text{Softmax}(A_{fg}) \otimes (\text{Softmax}(A_{bg}) \otimes V) V~=Softmax(Afg)(Softmax(Abg)V)
        5. 重构:加权后的局部窗口特征被折叠(Fold/Aggregate)回特征图尺寸,输出 F ~ \tilde{F} F~
      • 设计理念:利用深层确定的语义信息(前景/背景)作为“探针”,去浅层特征中“捞取”相关的细节,同时抑制无关的噪声。
    • 模块 C:尺寸感知解码器 (SA-Decoder)(对应 Figure 10)

      • 内部结构:三个独立的 U-Net 风格解码器。
      • 数据流
        • Decoder_Small:接收浅层高分辨率特征(如 f 1 , f 2 f_1, f_2 f1,f2),专注小目标。
        • Decoder_Medium:接收中层特征。
        • Decoder_Large:接收深层低分辨率特征(如 f 3 , f 4 f_3, f_4 f3,f4),专注大目标。
      • 设计理念:强制模型在不同的尺度空间分别寻找目标,打破了模型对“大目标旁必有小目标”这种单一尺度共现特征的依赖。
  • 理念与机制总结
    ConDSeg 的核心理念是 “对比与分治”

    • 对比:通过 SID 将特征二分为“黑(背景)白(前景)”和“灰(不确定)”,并在 CDFA 中利用“黑白”特征的对比来消除“灰”色区域的歧义,从而解决软边界问题。
    • 分治:通过 SA-Decoder 将不同大小的目标分配给不同的解码器负责,使得模型在处理单发或共现目标时,都能独立地做出判断,不再受共现先验的干扰。
  • 图解总结

    • Figure 2 展示了从训练策略(左)到特征解耦(中)再到特征融合与分尺度解码(右)的完整流线。
    • Figure 3 的 Grad-CAM 热力图直观证明了 ConDSeg 相比 TGANet 更能聚焦于病灶本体,消除了共现带来的虚警。
    • Figure 8 的 t-SNE 可视化进一步证实了 SID 模块成功将特征空间分离为了清晰的前景、背景和逐渐缩小的不确定区域。

5. 即插即用模块的作用

ConDSeg 中的创新模块设计独立性较好,可迁移至其他分割任务:

  1. 一致性增强(CR)策略

    • 适用场景:所有需要应对光照变化、低对比度或数据稀缺的分割/检测任务。
    • 具体应用:可以作为一个通用的 预训练(Pre-training) 步骤。在正式训练任何分割网络(如 U-Net, DeepLab)之前,先用 CR 策略在未标注或已标注数据上预训练 Encoder,可显著提升模型在恶劣成像条件下的鲁棒性。
  2. 语义信息解耦(SID)模块

    • 适用场景:边界模糊、存在“过渡带”的分割任务(如伪装目标检测、云层分割)。
    • 具体应用:可作为一个 辅助监督头(Auxiliary Head) 插入到任何 Encoder 的末端。通过引入 f u c f_{uc} fuc(不确定区域)并最小化其范围,可以强迫主干网络学习更清晰的边界特征。
  3. 对比驱动特征聚合(CDFA)模块

    • 适用场景:需要多级特征融合的密集预测任务。
    • 具体应用:可替代 FPN 或 U-Net 中的 Skip Connection(跳跃连接)。传统的跳跃连接是直接拼接(Concat),而 CDFA 可以利用深层语义特征作为 Query,对浅层细节特征进行“清洗”和“筛选”,从而减少浅层噪声的干扰。

6.即插即用模块

"""
即插即用模块集合 (Plug-and-Play Modules)
从ConDSeg模型中提取的可复用模块,可以用于不同的backbone架构
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# ==================== 基础模块 ====================

class CBR(nn.Module):
    """基础卷积块: Conv + BatchNorm + ReLU"""
    def __init__(self, in_c, out_c, kernel_size=3, padding=1, dilation=1, stride=1, act=True):
        super().__init__()
        self.act = act

        self.conv = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size, padding=padding, dilation=dilation, bias=False, stride=stride),
            nn.BatchNorm2d(out_c)
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        if self.act == True:
            x = self.relu(x)
        return x


class channel_attention(nn.Module):
    """通道注意力机制"""
    def __init__(self, in_planes, ratio=16):
        super(channel_attention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x0 = x
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return x0 * self.sigmoid(out)


class spatial_attention(nn.Module):
    """空间注意力机制"""
    def __init__(self, kernel_size=7):
        super(spatial_attention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x0 = x  # [B,C,H,W]
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return x0 * self.sigmoid(x)


# ==================== 特征增强模块 ====================

class dilated_conv(nn.Module):
    """空洞卷积模块 (FEM - Feature Enhancement Module)"""
    def __init__(self, in_c, out_c):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)

        self.c1 = nn.Sequential(CBR(in_c, out_c, kernel_size=1, padding=0), channel_attention(out_c))
        self.c2 = nn.Sequential(CBR(in_c, out_c, kernel_size=(3, 3), padding=6, dilation=6), channel_attention(out_c))
        self.c3 = nn.Sequential(CBR(in_c, out_c, kernel_size=(3, 3), padding=12, dilation=12), channel_attention(out_c))
        self.c4 = nn.Sequential(CBR(in_c, out_c, kernel_size=(3, 3), padding=18, dilation=18), channel_attention(out_c))
        self.c5 = CBR(out_c * 4, out_c, kernel_size=3, padding=1, act=False)
        self.c6 = CBR(in_c, out_c, kernel_size=1, padding=0, act=False)
        self.sa = spatial_attention()

    def forward(self, x):
        x1 = self.c1(x)
        x2 = self.c2(x)
        x3 = self.c3(x)
        x4 = self.c4(x)
        xc = torch.cat([x1, x2, x3, x4], axis=1)
        xc = self.c5(xc)
        xs = self.c6(x)
        x = self.relu(xc + xs)
        x = self.sa(x)
        return x


# ==================== 特征解耦模块 ====================

class DecoupleLayer(nn.Module):
    """特征解耦层:将特征分解为前景、背景和不确定性特征"""
    def __init__(self, in_c=1024, out_c=256):
        super(DecoupleLayer, self).__init__()
        self.cbr_fg = nn.Sequential(
            CBR(in_c, 512, kernel_size=3, padding=1),
            CBR(512, out_c, kernel_size=3, padding=1),
            CBR(out_c, out_c, kernel_size=1, padding=0)
        )
        self.cbr_bg = nn.Sequential(
            CBR(in_c, 512, kernel_size=3, padding=1),
            CBR(512, out_c, kernel_size=3, padding=1),
            CBR(out_c, out_c, kernel_size=1, padding=0)
        )
        self.cbr_uc = nn.Sequential(
            CBR(in_c, 512, kernel_size=3, padding=1),
            CBR(512, out_c, kernel_size=3, padding=1),
            CBR(out_c, out_c, kernel_size=1, padding=0)
        )

    def forward(self, x):
        f_fg = self.cbr_fg(x)  # 前景特征
        f_bg = self.cbr_bg(x)  # 背景特征
        f_uc = self.cbr_uc(x)  # 不确定性特征
        return f_fg, f_bg, f_uc


# ==================== 辅助头模块 ====================

class AuxiliaryHead(nn.Module):
    """辅助预测头:生成前景、背景和不确定性的辅助预测"""
    def __init__(self, in_c):
        super(AuxiliaryHead, self).__init__()
        self.branch_fg = nn.Sequential(
            CBR(in_c, 256, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),  # 1/8
            CBR(256, 256, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),  # 1/4
            CBR(256, 128, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),  # 1/2
            CBR(128, 64, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),  # 1
            CBR(64, 64, kernel_size=3, padding=1),
            nn.Conv2d(64, 1, kernel_size=1, padding=0),
            nn.Sigmoid()
        )
        self.branch_bg = nn.Sequential(
            CBR(in_c, 256, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),  # 1/8
            CBR(256, 256, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),  # 1/4
            CBR(256, 128, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),  # 1/2
            CBR(128, 64, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),  # 1
            CBR(64, 64, kernel_size=3, padding=1),
            nn.Conv2d(64, 1, kernel_size=1, padding=0),
            nn.Sigmoid()
        )
        self.branch_uc = nn.Sequential(
            CBR(in_c, 256, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),  # 1/8
            CBR(256, 256, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),  # 1/4
            CBR(256, 128, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),  # 1/2
            CBR(128, 64, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),  # 1
            CBR(64, 64, kernel_size=3, padding=1),
            nn.Conv2d(64, 1, kernel_size=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, f_fg, f_bg, f_uc):
        mask_fg = self.branch_fg(f_fg)
        mask_bg = self.branch_bg(f_bg)
        mask_uc = self.branch_uc(f_uc)
        return mask_fg, mask_bg, mask_uc


# ==================== CDFA核心模块 ====================

class ContrastDrivenFeatureAggregation(nn.Module):
    """
    对比驱动特征聚合模块 (CDFA - Contrast-Driven Feature Aggregation)
    这是核心的即插即用模块,可以用于任何backbone架构
    """
    def __init__(self, in_c, dim, num_heads, kernel_size=3, padding=1, stride=1,
                 attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.head_dim = dim // num_heads

        self.scale = self.head_dim ** -0.5

        self.v = nn.Linear(dim, dim)
        self.attn_fg = nn.Linear(dim, kernel_size ** 4 * num_heads)
        self.attn_bg = nn.Linear(dim, kernel_size ** 4 * num_heads)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.unfold = nn.Unfold(kernel_size=kernel_size, padding=padding, stride=stride)
        self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True)

        self.input_cbr = nn.Sequential(
            CBR(in_c, dim, kernel_size=3, padding=1),
            CBR(dim, dim, kernel_size=3, padding=1),
        )
        self.output_cbr = nn.Sequential(
            CBR(dim, dim, kernel_size=3, padding=1),
            CBR(dim, dim, kernel_size=3, padding=1),
        )

    def forward(self, x, fg, bg):
        """
        Args:
            x: 主特征图 [B, C, H, W]
            fg: 前景特征 [B, C, H, W]
            bg: 背景特征 [B, C, H, W]
        Returns:
            out: 增强后的特征 [B, C, H, W]
        """
        x = self.input_cbr(x)

        x = x.permute(0, 2, 3, 1)
        fg = fg.permute(0, 2, 3, 1)
        bg = bg.permute(0, 2, 3, 1)

        B, H, W, C = x.shape

        v = self.v(x).permute(0, 3, 1, 2)

        v_unfolded = self.unfold(v).reshape(B, self.num_heads, self.head_dim,
                                            self.kernel_size * self.kernel_size,
                                            -1).permute(0, 1, 4, 3, 2)
        attn_fg = self.compute_attention(fg, B, H, W, C, 'fg')

        x_weighted_fg = self.apply_attention(attn_fg, v_unfolded, B, H, W, C)

        v_unfolded_bg = self.unfold(x_weighted_fg.permute(0, 3, 1, 2)).reshape(B, self.num_heads, self.head_dim,
                                                                               self.kernel_size * self.kernel_size,
                                                                               -1).permute(0, 1, 4, 3, 2)
        attn_bg = self.compute_attention(bg, B, H, W, C, 'bg')

        x_weighted_bg = self.apply_attention(attn_bg, v_unfolded_bg, B, H, W, C)

        x_weighted_bg = x_weighted_bg.permute(0, 3, 1, 2)

        out = self.output_cbr(x_weighted_bg)

        return out

    def compute_attention(self, feature_map, B, H, W, C, feature_type):

        attn_layer = self.attn_fg if feature_type == 'fg' else self.attn_bg
        h, w = math.ceil(H / self.stride), math.ceil(W / self.stride)

        feature_map_pooled = self.pool(feature_map.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

        attn = attn_layer(feature_map_pooled).reshape(B, h * w, self.num_heads,
                                                      self.kernel_size * self.kernel_size,
                                                      self.kernel_size * self.kernel_size).permute(0, 2, 1, 3, 4)
        attn = attn * self.scale
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)
        return attn

    def apply_attention(self, attn, v, B, H, W, C):

        x_weighted = (attn @ v).permute(0, 1, 4, 3, 2).reshape(
            B, self.dim * self.kernel_size * self.kernel_size, -1)
        x_weighted = F.fold(x_weighted, output_size=(H, W), kernel_size=self.kernel_size,
                            padding=self.padding, stride=self.stride)
        x_weighted = self.proj(x_weighted.permute(0, 2, 3, 1))
        x_weighted = self.proj_drop(x_weighted)
        return x_weighted


# ==================== 预处理模块 ====================

class CDFAPreprocess(nn.Module):
    """CDFA预处理模块:调整特征图尺寸"""
    def __init__(self, in_c, out_c, up_scale):
        super().__init__()
        up_times = int(math.log2(up_scale))
        self.preprocess = nn.Sequential()
        self.c1 = CBR(in_c, out_c, kernel_size=3, padding=1)
        for i in range(up_times):
            self.preprocess.add_module(f'up_{i}', nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True))
            self.preprocess.add_module(f'conv_{i}', CBR(out_c, out_c, kernel_size=3, padding=1))

    def forward(self, x):
        x = self.c1(x)
        x = self.preprocess(x)
        return x


# ==================== 测试函数 ====================

def test_cdfa_module():
    """测试CDFA模块"""
    print("=" * 60)
    print("测试 CDFA 模块")
    print("=" * 60)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")
    
    # 创建测试数据
    batch_size = 2
    channels = 128
    height, width = 16, 16
    
    x = torch.randn(batch_size, channels, height, width).to(device)
    fg = torch.randn(batch_size, channels, height, width).to(device)
    bg = torch.randn(batch_size, channels, height, width).to(device)
    
    print(f"\n输入形状:")
    print(f"  主特征 x: {x.shape}")
    print(f"  前景特征 fg: {fg.shape}")
    print(f"  背景特征 bg: {bg.shape}")
    
    # 创建CDFA模块
    cdfa = ContrastDrivenFeatureAggregation(
        in_c=channels,
        dim=channels,
        num_heads=4,
        kernel_size=3,
        padding=1,
        stride=1
    ).to(device)
    
    # 前向传播
    with torch.no_grad():
        output = cdfa(x, fg, bg)
    
    print(f"\n输出形状:")
    print(f"  增强特征: {output.shape}")
    print("✓ CDFA模块测试通过!")
    return output


def test_decouple_layer():
    """测试解耦层"""
    print("\n" + "=" * 60)
    print("测试 DecoupleLayer 模块")
    print("=" * 60)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # 创建测试数据
    batch_size = 2
    in_channels = 1024
    out_channels = 128
    height, width = 16, 16
    
    x = torch.randn(batch_size, in_channels, height, width).to(device)
    print(f"\n输入形状: {x.shape}")
    
    # 创建解耦层
    decouple = DecoupleLayer(in_c=in_channels, out_c=out_channels).to(device)
    
    # 前向传播
    with torch.no_grad():
        f_fg, f_bg, f_uc = decouple(x)
    
    print(f"\n输出形状:")
    print(f"  前景特征 f_fg: {f_fg.shape}")
    print(f"  背景特征 f_bg: {f_bg.shape}")
    print(f"  不确定性特征 f_uc: {f_uc.shape}")
    print("✓ DecoupleLayer模块测试通过!")
    return f_fg, f_bg, f_uc


def test_dilated_conv():
    """测试空洞卷积模块"""
    print("\n" + "=" * 60)
    print("测试 dilated_conv 模块")
    print("=" * 60)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # 创建测试数据
    batch_size = 2
    in_channels = 64
    out_channels = 128
    height, width = 32, 32
    
    x = torch.randn(batch_size, in_channels, height, width).to(device)
    print(f"\n输入形状: {x.shape}")
    
    # 创建空洞卷积模块
    dconv = dilated_conv(in_c=in_channels, out_c=out_channels).to(device)
    
    # 前向传播
    with torch.no_grad():
        output = dconv(x)
    
    print(f"\n输出形状: {output.shape}")
    print("✓ dilated_conv模块测试通过!")
    return output


def test_auxiliary_head():
    """测试辅助头"""
    print("\n" + "=" * 60)
    print("测试 AuxiliaryHead 模块")
    print("=" * 60)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # 创建测试数据
    batch_size = 2
    channels = 128
    height, width = 16, 16
    
    f_fg = torch.randn(batch_size, channels, height, width).to(device)
    f_bg = torch.randn(batch_size, channels, height, width).to(device)
    f_uc = torch.randn(batch_size, channels, height, width).to(device)
    
    print(f"\n输入形状:")
    print(f"  f_fg: {f_fg.shape}")
    print(f"  f_bg: {f_bg.shape}")
    print(f"  f_uc: {f_uc.shape}")
    
    # 创建辅助头
    aux_head = AuxiliaryHead(in_c=channels).to(device)
    
    # 前向传播
    with torch.no_grad():
        mask_fg, mask_bg, mask_uc = aux_head(f_fg, f_bg, f_uc)
    
    print(f"\n输出形状:")
    print(f"  前景mask: {mask_fg.shape}")
    print(f"  背景mask: {mask_bg.shape}")
    print(f"  不确定性mask: {mask_uc.shape}")
    print("✓ AuxiliaryHead模块测试通过!")
    return mask_fg, mask_bg, mask_uc


def test_attention_modules():
    """测试注意力机制模块"""
    print("\n" + "=" * 60)
    print("测试 注意力机制 模块")
    print("=" * 60)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # 创建测试数据
    batch_size = 2
    channels = 128
    height, width = 32, 32
    
    x = torch.randn(batch_size, channels, height, width).to(device)
    print(f"\n输入形状: {x.shape}")
    
    # 测试通道注意力
    ca = channel_attention(in_planes=channels).to(device)
    with torch.no_grad():
        x_ca = ca(x)
    print(f"通道注意力输出: {x_ca.shape}")
    
    # 测试空间注意力
    sa = spatial_attention(kernel_size=7).to(device)
    with torch.no_grad():
        x_sa = sa(x)
    print(f"空间注意力输出: {x_sa.shape}")
    
    print("✓ 注意力机制模块测试通过!")


def test_integration():
    """集成测试:模拟完整的特征处理流程"""
    print("\n" + "=" * 60)
    print("集成测试:完整特征处理流程")
    print("=" * 60)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # 模拟backbone输出的特征
    batch_size = 1
    x4 = torch.randn(batch_size, 1024, 16, 16).to(device)  # 最后一层特征
    x3 = torch.randn(batch_size, 512, 32, 32).to(device)   # 倒数第二层
    x2 = torch.randn(batch_size, 256, 64, 64).to(device)   # 第三层
    x1 = torch.randn(batch_size, 64, 128, 128).to(device)  # 第一层
    
    print(f"\n输入特征层级:")
    print(f"  x1: {x1.shape}")
    print(f"  x2: {x2.shape}")
    print(f"  x3: {x3.shape}")
    print(f"  x4: {x4.shape}")
    
    # 1. 特征解耦
    decouple = DecoupleLayer(in_c=1024, out_c=128).to(device)
    f_fg, f_bg, f_uc = decouple(x4)
    print(f"\n1. 解耦后:")
    print(f"  f_fg: {f_fg.shape}, f_bg: {f_bg.shape}, f_uc: {f_uc.shape}")
    
    # 2. 预处理特征(调整到不同层级)
    preprocess_fg3 = CDFAPreprocess(128, 128, 2).to(device)  # 放大2倍
    preprocess_bg3 = CDFAPreprocess(128, 128, 2).to(device)
    f_fg3 = preprocess_fg3(f_fg)
    f_bg3 = preprocess_bg3(f_bg)
    print(f"\n2. 预处理后 (放大到x3层级):")
    print(f"  f_fg3: {f_fg3.shape}, f_bg3: {f_bg3.shape}")
    
    # 3. 空洞卷积特征增强
    dconv3 = dilated_conv(512, 128).to(device)
    d3 = dconv3(x3)
    print(f"\n3. 特征增强后:")
    print(f"  d3: {d3.shape}")
    
    # 4. CDFA特征聚合
    cdfa3 = ContrastDrivenFeatureAggregation(128, 128, 4).to(device)
    f3 = cdfa3(d3, f_fg3, f_bg3)
    print(f"\n4. CDFA聚合后:")
    print(f"  f3: {f3.shape}")
    
    # 5. 辅助头预测
    aux_head = AuxiliaryHead(128).to(device)
    mask_fg, mask_bg, mask_uc = aux_head(f_fg, f_bg, f_uc)
    print(f"\n5. 辅助预测:")
    print(f"  mask_fg: {mask_fg.shape}")
    print(f"  mask_bg: {mask_bg.shape}")
    print(f"  mask_uc: {mask_uc.shape}")
    
    print("\n" + "=" * 60)
    print("✓ 集成测试通过! 所有即插即用模块工作正常")
    print("=" * 60)


if __name__ == "__main__":
    print("\n")
    print("╔" + "=" * 58 + "╗")
    print("║" + " " * 15 + "即插即用模块测试" + " " * 25 + "║")
    print("╚" + "=" * 58 + "╝")
    print("\n")
    
    try:
        # 测试各个模块
        test_cdfa_module()
        test_decouple_layer()
        test_dilated_conv()
        test_auxiliary_head()
        test_attention_modules()
        test_integration()
        
        print("\n" + "=" * 60)
        print("所有测试完成! ✓")
        print("=" * 60)
        print("\n这些模块可以即插即用地用于不同的backbone架构:")
        print("  - ResNet (network/model.py)")
        print("  - PVTv2 (network_pvt/model.py)")
        print("  - 或其他自定义backbone")
        print("\n")
        
    except Exception as e:
        print(f"\n❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()



Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐