1.模块代码调整:

class CoordGate(nn.Module):
    """Gate a feature map with a coordinate-based spatial mask and a channel mask.

    The spatial mask is obtained by feeding normalized row and column
    coordinate grids (values in [-1, 1]) through two 7x7 convolutions and a
    sigmoid; it depends only on the spatial size of the input, not on its
    content. The channel mask is a bottleneck MLP applied to the globally
    average-pooled input, followed by a sigmoid.

    Args:
        in_channels: Number of channels ``C`` of the input feature map.
        reduction: Bottleneck ratio of the channel MLP. The hidden width is
            ``in_channels // reduction``, clamped to at least 1.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        self.h_conv = nn.Conv2d(1, 1, kernel_size=7, stride=1, padding=3)
        self.w_conv = nn.Conv2d(1, 1, kernel_size=7, stride=1, padding=3)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Clamp to 1: with in_channels < reduction the floor division is 0,
        # which would create a zero-width hidden layer and make the channel
        # gate independent of the input.
        hidden = max(1, in_channels // reduction)
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, in_channels),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Apply the spatial and channel gates to ``x``.

        Args:
            x: Tensor of shape ``(B, C, H, W)``.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        b, c, h, w = x.size()

        # Coordinate grids are created directly on the input's device and cast
        # to its dtype so the convolutions also work when the module runs in
        # half or double precision.
        y_coord = torch.linspace(-1, 1, h, device=x.device).to(x.dtype)
        x_coord = torch.linspace(-1, 1, w, device=x.device).to(x.dtype)
        y_coord = y_coord.view(1, 1, h, 1).expand(1, 1, h, w)
        x_coord = x_coord.view(1, 1, 1, w).expand(1, 1, h, w)

        # Spatial gate: identical for every sample, so it is computed once
        # with batch size 1 and broadcast over the batch in the product below.
        coord_att = self.sigmoid(self.h_conv(y_coord) + self.w_conv(x_coord))

        # Channel gate: global average pool -> bottleneck MLP -> sigmoid.
        pooled = self.avg_pool(x).view(b, c)
        channel_att = self.sigmoid(self.mlp(pooled)).view(b, c, 1, 1)

        # Combine spatial (1, 1, H, W) and channel (B, C, 1, 1) gates.
        return x * coord_att * channel_att

2.插入和改进方法:

tasks.py文件注册导入该模块后,将代码args = [c1, c2, *args[1:]]改为

args = [c1, *args[1:]] if m in {CBAM,CoordGate} else [c1, c2, *args[1:]]即可

参考

3.在tasks.py中通过 from ultralytics.nn.modules import CoordGate 导入该模块即可(注意类名大小写为 CoordGate,与上文类定义保持一致)

最终,训练正常

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐