1. 代码结构与参数传递


关键代码片段分析:

def block(x, filters, strides=1, groups=32, conv_shortcut=True):
    if conv_shortcut:
        shortcut = Conv2D(filters * 2, kernel_size=1, strides=strides)(x)  # 调整通道数为filters*2
    else:
        shortcut = x  # 直接使用输入作为捷径

    # 主路径:最后一层输出通道数为filters*2
    x = Conv2D(filters * 2, kernel_size=1)(x)
    x = Add()([x, shortcut])  # 要求x和shortcut的通道数一致
    return x

def stack(x, filters, blocks, strides, groups=32):
    x = block(x, filters, strides=strides, conv_shortcut=True)  # 第一个块调整通道数
    for _ in range(1, blocks):
        x = block(x, filters, conv_shortcut=False)  # 后续块复用通道数
    return x

2. 通道数对齐的关键逻辑

(1) 第一个块显式对齐通道数
stack函数的第一个块:  
  参数conv_shortcut=True,执行以下操作:  
    输入`x`的通道数为`C_in`(例如256),通过1x1卷积升维至`filters * 2`(例如128*2=256)。  
    输出通道数与输入通道数一致(`C_in = filters * 2`)。  

(2) 后续块复用通道数
从第二个块开始:  
   参数`conv_shortcut=False`,直接使用输入作为捷径(`shortcut = x`)。  
  输入通道数已对齐**:由于第一个块的输出通道数为`filters * 2`,后续块的输入通道数自然为`filters * 2`。  
  主路径的输出通道数**:最后一层卷积固定输出`filters * 2`通道(与输入一致)。  
   因此,`Add()`操作的`x`和`shortcut`的通道数始终相同,无需调整。

3. 数学验证

假设某`stack`的参数为`filters=128`,输入通道数`C_in=256`:  
第一个块**:  
   主路径输出通道数:`128 * 2 = 256`。  
  捷径通道数:通过1x1卷积调整为`256`。  
  `Add()`操作合法(256 vs 256)。  
后续块:  
  输入通道数已为`256`(等于`128 * 2`)。  
  主路径输出通道数仍为`256`,捷径直接使用输入(256)。  
  Add()操作合法(256 vs 256)。  

4. 总结

不报错的核心原因:  
1. 首个块显式对齐通道数:通过`conv_shortcut=True`的1x1卷积,确保输入输出通道数一致。 
2. 后续块复用对齐后的通道数:输入和主路径输出均为`filters * 2`,无需额外调整。  
3. 参数传递一致性:同一`stack`中所有块的`filters`参数相同,保证通道数统一。  

这种设计既符合残差网络的基本思想(恒等映射),又通过首个块的显式调整避免了维度不匹配问题,是ResNet/ResNeXt系列的核心优化技巧。

可以在 conv_shortcut=False 时,打印 x.shapeshortcut.shape

print("x shape:", x.shape)
print("shortcut shape:", shortcut.shape)

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐