YOLOV8添加模块COORDGATE
args = [c1, *args[1:]] if m in {CBAM,CoordGate} else [c1, c2, *args[1:]]即可。tasks.py文件注册导入该模块后,将代码args = [c1, c2, *args[1:]]改为。3.from ultralytics.nn.modules import Coordgate即可。
·
1.模块代码调整:
class CoordGate(nn.Module):
def __init__(self, in_channels, reduction=16):
super(CoordGate, self).__init__()
self.h_conv = nn.Conv2d(1, 1, kernel_size=7, stride=1, padding=3)
self.w_conv = nn.Conv2d(1, 1, kernel_size=7, stride=1, padding=3)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.mlp = nn.Sequential(
nn.Linear(in_channels, in_channels // reduction),
nn.ReLU(inplace=True),
nn.Linear(in_channels // reduction, in_channels)
)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
b, c, h, w = x.size()
# 生成坐标注意力
y_coord = torch.linspace(-1, 1, h).view(1, 1, h, 1).expand(b, 1, h, w).to(x.device)
x_coord = torch.linspace(-1, 1, w).view(1, 1, 1, w).expand(b, 1, h, w).to(x.device)
y_att = self.h_conv(y_coord)
x_att = self.w_conv(x_coord)
coord_att = self.sigmoid(y_att + x_att)
# 通道注意力
avg_pool = self.avg_pool(x).view(b, c)
channel_att = self.sigmoid(self.mlp(avg_pool)).view(b, c, 1, 1)
# 组合空间和通道注意力
out = x * coord_att * channel_att
return out
2.插入和改进方法:
tasks.py文件注册导入该模块后,将代码args = [c1, c2, *args[1:]]改为
args = [c1, *args[1:]] if m in {CBAM,CoordGate} else [c1, c2, *args[1:]]即可
参考

3.from ultralytics.nn.modules import Coordgate即可
最终,训练正常
更多推荐


所有评论(0)