权重初始化
·
权重初始化复用指南
概述
本指南基于VAN(Visual Attention Network)项目中经过验证的权重初始化方法,提供了一个通用、高效的权重初始化解决方案,特别适用于包含卷积层、全连接层和归一化层的深度学习模型。
核心优势
- ✅ 现代最佳实践: 结合了He初始化和截断正态分布
- ✅ 分组卷积支持: 正确处理深度可分离卷积等分组操作
- ✅ 全面覆盖: 支持Conv2d, Linear, LayerNorm等常用层
- ✅ 训练稳定: 经过VAN项目验证,确保稳定的梯度流
- ✅ 即插即用: 可直接集成到任何PyTorch模型中
快速开始
1. 基础版本(推荐)
import torch
import torch.nn as nn
import math
from timm.models.layers import trunc_normal_
class YourModel(nn.Module):
def __init__(self):
super().__init__()
# 您的模型定义
self.conv1 = nn.Conv2d(3, 64, 3, 1, 1, groups=1)
self.dwconv = nn.Conv2d(64, 64, 3, 1, 1, groups=64) # 深度卷积
self.norm1 = nn.LayerNorm(64)
self.fc1 = nn.Linear(64, 10)
# 应用权重初始化
self.apply(self._init_weights)
def _init_weights(self, m):
"""VAN项目验证的权重初始化方法"""
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups # 关键:正确处理分组卷积
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
2. 无依赖版本
如果不想依赖timm库,可以使用以下自实现版本:
import torch
import torch.nn as nn
import math
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
"""截断正态分布初始化(简化版本)"""
with torch.no_grad():
tensor.normal_(mean, std)
tensor.clamp_(min=a*std + mean, max=b*std + mean)
return tensor
class YourModel(nn.Module):
def __init__(self):
super().__init__()
# 模型定义
self.apply(self._init_weights)
def _init_weights(self, m):
"""无依赖的权重初始化方法"""
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
技术原理详解
1. Linear层初始化
trunc_normal_(m.weight, std=.02)
nn.init.constant_(m.bias, 0)
原理说明:
- 使用截断正态分布,避免极值权重
- 标准差0.02确保初始权重较小,有利于训练稳定
- 偏置初始化为0是标准做法
2. LayerNorm初始化
nn.init.constant_(m.bias, 0) # β = 0
nn.init.constant_(m.weight, 1.0) # γ = 1
原理说明:
- 权重初始化为1.0,偏置初始化为0
- 确保LayerNorm在训练开始时相当于恒等变换
- 让网络从"无归一化"状态开始,逐渐学习合适的参数
3. Conv2d层初始化(核心创新)
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups # 关键修正
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
原理说明:
- 基于改进的He初始化(Kaiming初始化)
fan_out //= m.groups
是关键创新,正确处理分组卷积- 确保分组卷积的权重尺度与标准卷积保持一致
分组卷积示例:
# 深度可分离卷积: groups = input_channels
dwconv = nn.Conv2d(64, 64, 3, groups=64)
# fan_out = 3 * 3 * 64 = 576
# fan_out //= 64 = 9
# std = sqrt(2.0 / 9) ≈ 0.47
# 标准卷积: groups = 1
conv = nn.Conv2d(64, 64, 3, groups=1)
# fan_out = 3 * 3 * 64 = 576
# std = sqrt(2.0 / 576) ≈ 0.059
适用场景
✅ 推荐使用的场景
-
包含分组卷积的模型
- MobileNet系列
- EfficientNet系列
- 任何使用深度可分离卷积的模型
-
视觉模型
- CNN分类网络
- 目标检测模型
- 语义分割网络
-
Transformer变体
- Vision Transformer
- 卷积-注意力混合模型
-
需要稳定训练的模型
- 深层网络
- 复杂架构的模型
❌ 不推荐使用的场景
-
特殊层类型为主的模型
- 主要包含Embedding, LSTM, GRU等
- 需要特定初始化策略的层
-
已有成熟初始化方案的模型
- 预训练模型微调
- 已经验证过其他初始化方法的模型
使用示例
示例1:ResNet风格的模型
class ResNetBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)
self.bn2 = nn.BatchNorm2d(out_channels)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
# BatchNorm会自动正确初始化,无需手动设置
示例2:MobileNet风格的模型
class MobileNetBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
# 深度可分离卷积
self.dwconv = nn.Conv2d(in_channels, in_channels, 3, 1, 1, groups=in_channels)
self.pwconv = nn.Conv2d(in_channels, out_channels, 1)
self.norm = nn.LayerNorm(out_channels)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups # 对深度卷积很重要
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
扩展与自定义
添加新的层类型支持
def _init_weights(self, m):
"""扩展版本:支持更多层类型"""
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
# BatchNorm通常不需要手动初始化,但如果需要:
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Embedding):
# 为Embedding层添加支持
nn.init.normal_(m.weight, mean=0, std=0.02)
参数调优
def _init_weights_custom(self, m, linear_std=0.02, conv_gain=2.0):
"""可调参数的初始化版本"""
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=linear_std)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(conv_gain / fan_out))
if m.bias is not None:
m.bias.data.zero_()
# ... 其他层类型
常见问题与解决方案
Q1: 训练时出现梯度爆炸怎么办?
A1: 可能是初始化标准差过大,尝试以下调整:
# 减小Linear层的初始化标准差
trunc_normal_(m.weight, std=.01) # 从0.02减小到0.01
# 或者调整Conv2d的gain参数
m.weight.data.normal_(0, math.sqrt(1.0 / fan_out)) # 从2.0减小到1.0
Q2: 训练时出现梯度消失怎么办?
A2: 可能是初始化过小,尝试以下调整:
# 增大初始化标准差
trunc_normal_(m.weight, std=.05) # 从0.02增大到0.05
# 或者使用更大的gain
m.weight.data.normal_(0, math.sqrt(3.0 / fan_out)) # 从2.0增大到3.0
Q3: 为什么分组卷积需要特殊处理?
A3: 分组卷积的实际连接数少于标准卷积:
# 标准卷积:每个输出神经元连接所有输入通道
# 分组卷积:每个输出神经元只连接部分输入通道
# 因此需要 fan_out //= m.groups 来修正实际的连接数
Q4: 如何验证初始化是否正确?
A4: 可以检查前向传播的激活值分布:
def check_activation_stats(model, input_tensor):
"""检查激活值统计信息"""
hooks = []
stats = []
def hook_fn(module, input, output):
if isinstance(output, torch.Tensor):
stats.append({
'mean': output.mean().item(),
'std': output.std().item(),
'layer': str(module)
})
# 注册hooks
for module in model.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
hooks.append(module.register_forward_hook(hook_fn))
# 前向传播
with torch.no_grad():
model(input_tensor)
# 清理hooks
for hook in hooks:
hook.remove()
# 打印统计信息
for stat in stats:
print(f"Layer: {stat['layer']}")
print(f"Mean: {stat['mean']:.4f}, Std: {stat['std']:.4f}")
print("-" * 50)
return stats
# 使用示例
model = YourModel()
input_tensor = torch.randn(1, 3, 224, 224)
stats = check_activation_stats(model, input_tensor)
性能对比
初始化方法对比表
初始化方法 | Linear层 | Conv2d层 | 分组卷积支持 | 归一化层 | 训练稳定性 |
---|---|---|---|---|---|
VAN方法(推荐) | trunc_normal | He + groups修正 | ✅ | 标准初始化 | ⭐⭐⭐⭐⭐ |
PyTorch默认 | Xavier uniform | Kaiming uniform | ❌ | 标准初始化 | ⭐⭐⭐ |
简单He初始化 | He normal | He normal | ❌ | 标准初始化 | ⭐⭐⭐⭐ |
Xavier初始化 | Xavier normal | Xavier normal | ❌ | 标准初始化 | ⭐⭐⭐ |
总结
这个权重初始化方法是现代深度学习的最佳实践之一,特别适合:
- 现代CNN架构:正确处理分组卷积
- 视觉任务:经过VAN项目验证
- 训练稳定性:平衡的梯度流
- 通用性:适用于大多数PyTorch模型
建议在新项目中优先考虑使用这个初始化方法,它可以显著提升训练的稳定性和收敛速度。
参考资料
更多推荐
所有评论(0)