Shrink an AI model 8× and make it smarter? This trick could let your phone run large models!
This article presents a Quantized Spike-Driven Transformer module (QSD-Transformer) that tackles the challenge of deploying spiking neural networks on resource-constrained devices. By quantizing 32-bit weights down to 2-4 bits, the module achieves a substantial reduction in energy consumption and model size. Its core innovations are: 1) a quantized spike-driven self-attention mechanism (Q-SDSA) that greatly reduces computational complexity; 2) an information-enhanced LIF neuron (IE-LIF) that keeps a rich information representation during training and switches to an efficient binary mode at inference; and 3) a fine-grained distillation scheme (FGD) that lets the lightweight quantized model learn from a powerful full-precision teacher.
The QSD-Transformer Module
Paper: "QUANTIZED SPIKE-DRIVEN TRANSFORMER"
Paper link: https://arxiv.org/pdf/2501.13492
Venue: ICLR
QQ group for in-depth discussion: 1051849847
Same handle 【大嘴带你水论文】 across all platforms; detailed walkthrough videos are posted regularly on Bilibili
The full code is provided at the end of this article.
1. Purpose
QSD-Transformer is a quantized spike-driven Transformer module designed to address the challenge of deploying spiking neural networks (SNNs) on resource-constrained devices. By quantizing 32-bit weights to low bit-widths (2-4 bits), it achieves a substantial reduction in energy consumption and model size while maintaining high performance. It targets the large parameter counts and high computational complexity of existing SNN Transformers; Spikformer v2, for example, requires 173M parameters and 1384MB of memory. Through a bi-level optimization strategy that combines the information-enhanced LIF neuron (IE-LIF) with a fine-grained distillation scheme (FGD), QSD-Transformer effectively mitigates the spike information distortion (SID) caused by quantization. Experiments show that on ImageNet it reaches 80.3% top-1 accuracy while delivering a 6.0× reduction in power consumption and an 8.1× reduction in model size, providing a new solution for efficient neuromorphic computing on edge devices.
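As a quick sanity check on the compression claim (a back-of-envelope sketch, not a reproduction of the paper's measurements): storing weights in 4 bits instead of 32 bits shrinks weight storage by roughly 32/4 = 8×, which lines up with the reported 8.1× model-size reduction. The tiny helper below only counts weight bits and ignores everything else (activations, quantization scales, metadata).
def weight_memory_mb(num_params, bits):
    """Illustrative helper (not from the paper): weight storage in megabytes."""
    return num_params * bits / 8 / 1e6

params = 6.8e6                           # parameter count reported for QSD-Transformer
fp32_mb = weight_memory_mb(params, 32)   # ~27.2 MB at 32-bit
int4_mb = weight_memory_mb(params, 4)    # ~3.4 MB at 4-bit
print(f'32-bit: {fp32_mb:.1f} MB, 4-bit: {int4_mb:.1f} MB, ratio: {fp32_mb / int4_mb:.0f}x')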
Figure 1. Overall architecture of the QSD-Transformer
2. Core Innovations
1. Quantized Spike-Driven Self-Attention (Q-SDSA):
This is the core technique of the QSD-Transformer. Put simply, weights that would normally be stored in 32 bits are compressed to 2-4 bits, much like compressing a high-resolution photo into a small file. Conventional attention is computationally expensive, whereas Q-SDSA exploits the properties of spiking neural networks and replaces complex floating-point computation with binary 0/1 spikes, greatly reducing computational complexity. This compression does, however, introduce information loss, just as compressing an image degrades its quality.
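To make the "0/1 spikes instead of floating-point multiplications" point concrete, here is a tiny self-contained sketch (a toy example for illustration, not the paper's actual Q-SDSA): when the query and key tensors are binary, their dot products simply count the channels where both spike, so no real multiplications are required.
import torch

# Toy illustration: binary spike tensors turn Q @ K^T into co-firing counts.
N, d = 4, 8                                   # 4 tokens, 8 channels
q = (torch.rand(N, d) > 0.5).float()          # binary spike "queries"
k = (torch.rand(N, d) > 0.5).float()          # binary spike "keys"

attn_matmul = q @ k.t()                                   # ordinary matrix multiply
attn_count = (q.unsqueeze(1) * k.unsqueeze(0)).sum(-1)    # count channels where both are 1
print(torch.equal(attn_matmul, attn_count))               # True: identical results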
2. Information-Enhanced LIF Neuron (IE-LIF):
This is the key technique for countering that information loss. The IE-LIF neuron is clever: during training it uses multiple graded values to retain a rich information representation, like painting with a full box of colors; at inference it switches to a simple 0/1 binary mode, like sketching in black and white, preserving accuracy while staying efficient. It also includes an automatic rectification mechanism that corrects the information distribution, so the compressed model does not lose too much of the important information.
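A tiny numeric sketch of the train/inference switch (a deliberate simplification; the full IE_LIF class appears in the code section below): during training the membrane potential is mapped to a few graded levels, while at inference the same input is collapsed to pure 0/1 spikes.
import torch

# Toy version of the IE-LIF behaviour (illustrative simplification).
membrane = torch.tensor([-0.3, 0.1, 0.4, 0.9, 1.7])

# Training: quantize to a handful of levels in [0, 1]
train_out = torch.clamp(torch.round(membrane * 4), 0, 4) / 4
# Inference: plain binary firing
infer_out = (membrane > 0).float()

print(train_out)  # tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000]) -> graded levels
print(infer_out)  # tensor([0., 1., 1., 1., 1.])                     -> pure 0/1 spikes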
3. Fine-Grained Distillation (FGD):
This is a "master teaches apprentice" style of training. A powerful but heavyweight conventional network acts as the "teacher", and the lightweight QSD-Transformer acts as the "student". Through fine-grained knowledge transfer, the student learns the teacher's core capabilities but realizes them in a simpler, more efficient way, which ensures the compressed model still performs well.
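Here is a short sketch of what this fine-grained knowledge transfer can look like in code (an illustration of the idea; the FineGrainedDistillation class in the code section below is the version used in this article): the student is trained so that the pairwise similarity structure of its features matches the teacher's, even though the two feature spaces can have different dimensions.
import torch
import torch.nn.functional as F

# Toy fine-grained distillation: match token-to-token cosine-similarity matrices.
def sim_matrix(feat):                         # feat: (B, N, C)
    feat = F.normalize(feat, dim=-1)
    return feat @ feat.transpose(1, 2)        # (B, N, N)

teacher_feat = torch.randn(2, 16, 64)         # stand-in for full-precision teacher features
student_feat = torch.randn(2, 16, 32)         # stand-in for quantized student features

loss = F.mse_loss(sim_matrix(student_feat), sim_matrix(teacher_feat))
print(loss.item())                            # minimized during training so the student mimics the teacher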
3. Applications
As a lightweight and efficient spiking neural network module, the QSD-Transformer delivers strong results across several computer-vision tasks:
1. Image classification:
It reaches 80.3% top-1 accuracy on ImageNet with a model of only 6.8M parameters and a 6× reduction in power consumption, making it well suited to real-time image recognition on mobile devices such as phones and tablets.
2. Object detection:
In object-detection experiments on COCO, the QSD-Transformer outperforms existing spiking-neural-network methods by 5.8%, accurately detecting and localizing multiple objects in an image, which suits scenarios such as autonomous driving and security surveillance.
3. Semantic segmentation:
It performs well on the ADE20K segmentation benchmark, producing pixel-level segmentation of images, and can be applied in fields that require fine-grained segmentation such as medical image analysis and satellite imagery processing.
4. Transfer learning:
It achieves top results on CIFAR-10/100 and on neuromorphic datasets, showing strong generality: the module can quickly adapt to new task domains and reduce the cost of training from scratch.
5. Edge devices:
Thanks to its very low power and storage requirements, the QSD-Transformer is particularly well suited to resource-constrained edge devices such as IoT sensors, drones, and robots, enabling on-device intelligent processing.
4. Code
A PyTorch implementation follows; the complete code is available at the gitcode link at the end of this article.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.autograd import Function
from timm.models.layers import DropPath, trunc_normal_
class ReLUX(nn.Module):
    """Clamped ReLU: ReLU with its output capped at a threshold."""
    def __init__(self, thre=8):
        super(ReLUX, self).__init__()
        self.thre = thre

    def forward(self, input):
        return torch.clamp(input, 0, self.thre)

relu4 = ReLUX(thre=4)
class MultiSpike(Function):
    """Multi-bit spike function: rounds the clamped input to integer spike levels,
    with a rectangular surrogate gradient on the backward pass."""
    @staticmethod
    def forward(ctx, input, lens):
        ctx.save_for_backward(input)
        ctx.lens = lens
        return torch.floor(relu4(input) + 0.5)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        # Surrogate gradient: pass gradients only where 0 < input < lens
        temp1 = 0 < input
        temp2 = input < ctx.lens
        return grad_input * temp1.float() * temp2.float(), None
class IE_LIF(nn.Module):
    """Information-Enhanced LIF neuron: multi-bit spikes during training,
    binary spikes at inference."""
    def __init__(self, lens=4, spike=MultiSpike):
        super().__init__()
        self.lens = lens
        self.spike = spike
        self.training_mode = True  # multi-bit spikes in training, binary spikes at inference

    def forward(self, inputs):
        if self.training_mode:
            # Training: graded spike levels in {0, 0.25, 0.5, 0.75, 1.0}
            return self.spike.apply(4 * inputs, self.lens) / 4
        else:
            # Inference: plain binary spikes
            return (inputs > 0).float()
class MPRF(nn.Module):
    """Membrane potential rectification: normalizes the membrane potential along
    the channel dimension, then applies a learnable affine transform."""
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.alpha = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        """
        Rectify the membrane potential.
        Args:
            x: membrane potential with channels in the last dimension, e.g. (B, N, C)
        Returns:
            rectified membrane potential of the same shape
        """
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True) + 1e-6
        x_norm = (x - mean) / std
        return x_norm * self.gamma + self.alpha
class LSQQuantizer(nn.Module):
    """LSQ-style weight quantizer with a learnable step size."""
    def __init__(self, bit=4, all_positive=False):
        super().__init__()
        self.bit = bit
        self.all_positive = all_positive
        if all_positive:
            # Unsigned range, e.g. [0, 15] for 4 bits
            self.thd_neg = 0
            self.thd_pos = 2 ** bit - 1
        else:
            # Signed range, e.g. [-8, 7] for 4 bits
            self.thd_neg = - 2 ** (bit - 1)
            self.thd_pos = 2 ** (bit - 1) - 1
        self.s = nn.Parameter(torch.ones(1))  # quantization step size

    def forward(self, x):
        if self.training:
            # Quantize-dequantize during training
            x_q = torch.clamp(torch.round(x / self.s), self.thd_neg, self.thd_pos)
            x_dq = x_q * self.s
            # Straight-through estimator: forward uses x_dq, gradients flow through x
            x_dq = x + (x_dq - x).detach()
            return x_dq
        else:
            # Quantize-dequantize at inference
            x_q = torch.clamp(torch.round(x / self.s), self.thd_neg, self.thd_pos)
            return x_q * self.s
class QuantizedLinear(nn.Module):
    """Linear layer with low-bit quantized weights."""
    def __init__(self, in_features, out_features, bias=True, bit=4):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        self.quantizer = LSQQuantizer(bit=bit)
        # Weight initialization
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x):
        w_q = self.quantizer(self.weight)
        return F.linear(x, w_q, self.bias)
class QuantizedConv2d(nn.Module):
    """2D convolution with low-bit quantized weights."""
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, bias=True, bit=4):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
        self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None
        self.quantizer = LSQQuantizer(bit=bit)
        # Weight initialization
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x):
        w_q = self.quantizer(self.weight)
        return F.conv2d(x, w_q, self.bias, self.stride, self.padding)
class Q_SDSA(nn.Module):
    """Quantized spike-driven self-attention."""
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., bit=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        # Quantized Q/K/V projections (1x1 convolutions)
        self.q_conv = QuantizedConv2d(dim, dim, 1, bias=qkv_bias, bit=bit)
        self.k_conv = QuantizedConv2d(dim, dim, 1, bias=qkv_bias, bit=bit)
        self.v_conv = QuantizedConv2d(dim, dim, 1, bias=qkv_bias, bit=bit)
        # IE-LIF neurons
        self.q_lif = IE_LIF()
        self.k_lif = IE_LIF()
        self.v_lif = IE_LIF()
        # Membrane potential rectification
        self.q_mprf = MPRF(dim)
        self.k_mprf = MPRF(dim)
        self.v_mprf = MPRF(dim)
        # Output projection
        self.proj = QuantizedLinear(dim, dim, bit=bit)
        self.proj_lif = IE_LIF()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """
        Forward pass.
        Args:
            x: input features (B, H, W, C)
        Returns:
            output features (B, H, W, C)
        """
        B, H, W, C = x.shape
        # Switch to the convolution layout
        x = x.permute(0, 3, 1, 2)  # (B, C, H, W)
        # Q/K/V projections
        q = self.q_conv(x)  # (B, C, H, W)
        k = self.k_conv(x)
        v = self.v_conv(x)
        # Membrane potential rectification (channels last)
        q = q.permute(0, 2, 3, 1)  # (B, H, W, C)
        k = k.permute(0, 2, 3, 1)
        v = v.permute(0, 2, 3, 1)
        q = self.q_mprf(q)
        k = self.k_mprf(k)
        v = self.v_mprf(v)
        # IE-LIF spiking activations
        q_s = self.q_lif(q)
        k_s = self.k_lif(k)
        v_s = self.v_lif(v)
        # Reshape into the multi-head attention layout
        q_s = q_s.reshape(B, H*W, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k_s = k_s.reshape(B, H*W, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v_s = v_s.reshape(B, H*W, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        # Spike-driven attention (no softmax: the inputs are spike tensors)
        attn = (q_s @ k_s.transpose(-2, -1)) * self.scale
        attn = self.attn_drop(attn)
        # Attention output
        x = (attn @ v_s).transpose(1, 2).reshape(B, H*W, C)
        # Output projection
        x = self.proj(x)
        x = self.proj_lif(x)
        x = self.proj_drop(x)
        # Restore the original spatial layout
        x = x.reshape(B, H, W, C)
        return x
class QSDTransformerBlock(nn.Module):
    """QSD-Transformer block: Q-SDSA plus a quantized spiking MLP, each with a residual."""
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0.,
                 attn_drop=0., drop_path=0., bit=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Q_SDSA(dim, num_heads=num_heads, qkv_bias=qkv_bias,
                           attn_drop=attn_drop, proj_drop=drop, bit=bit)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            QuantizedLinear(dim, mlp_hidden_dim, bit=bit),
            IE_LIF(),
            nn.Dropout(drop),
            QuantizedLinear(mlp_hidden_dim, dim, bit=bit),
            IE_LIF(),
            nn.Dropout(drop)
        )

    def forward(self, x):
        """Forward pass."""
        # Attention branch
        x = x + self.drop_path(self.attn(self.norm1(x)))
        # MLP branch
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
class FineGrainedDistillation(nn.Module):
    """Fine-grained distillation loss: matches the pairwise similarity structure
    of the student's and teacher's Q/K/V features."""
    def __init__(self):
        super().__init__()

    def compute_similarity_matrix(self, x):
        """Compute the token-to-token cosine similarity matrix."""
        # x: (B, N, C)
        x_norm = F.normalize(x, p=2, dim=-1)
        sim = torch.bmm(x_norm, x_norm.transpose(1, 2))
        return sim

    def forward(self, student_qkv, teacher_qkv):
        """
        Compute the fine-grained distillation loss.
        Args:
            student_qkv: student Q/K/V features [(B,N,C), (B,N,C), (B,N,C)]
            teacher_qkv: teacher Q/K/V features [(B,N,C), (B,N,C), (B,N,C)]
        Returns:
            distillation loss
        """
        loss = 0.0
        for s_feat, t_feat in zip(student_qkv, teacher_qkv):
            s_sim = self.compute_similarity_matrix(s_feat)
            t_sim = self.compute_similarity_matrix(t_feat)
            loss += F.mse_loss(s_sim, t_sim)
        return loss / len(student_qkv)
class QSDTransformer(nn.Module):
    """QSD-Transformer backbone."""
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., bit=4):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        # Patch embedding (quantized strided convolution)
        self.patch_embed = QuantizedConv2d(in_chans, embed_dim, patch_size, patch_size, bit=bit)
        num_patches = (img_size // patch_size) ** 2
        # Positional embedding
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)
        # Transformer blocks with a linearly increasing drop-path rate
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            QSDTransformerBlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], bit=bit
            ) for i in range(depth)
        ])
        # Classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = QuantizedLinear(embed_dim, num_classes, bit=bit) if num_classes > 0 else nn.Identity()
        # Fine-grained distillation loss
        self.distillation = FineGrainedDistillation()
        # Initialization
        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (QuantizedLinear, QuantizedConv2d)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        """Feature extraction."""
        # Patch embedding
        x = self.patch_embed(x)  # (B, C, H, W)
        B, C, H, W = x.shape
        x = x.reshape(B, C, H*W).transpose(1, 2)  # (B, N, C)
        # Positional embedding
        x = x + self.pos_embed
        x = self.pos_drop(x)
        # Back to the (B, H, W, C) layout expected by the blocks
        x = x.reshape(B, H, W, C)
        # Transformer blocks
        for block in self.blocks:
            x = block(x)
        # Global average pooling over tokens
        x = x.reshape(B, H*W, C)
        x = self.norm(x)
        x = x.mean(dim=1)
        return x

    def forward(self, x, teacher_features=None):
        """Forward pass."""
        x = self.forward_features(x)
        x = self.head(x)
        if teacher_features is not None and self.training:
            # Compute the distillation loss. Note: this is a simplified hook-up;
            # FineGrainedDistillation expects Q/K/V feature lists, so
            # teacher_features must match whatever is passed as the student side.
            distill_loss = self.distillation(x, teacher_features)
            return x, distill_loss
        return x

    def set_inference_mode(self):
        """Switch all IE-LIF neurons to binary (inference) spiking."""
        for module in self.modules():
            if isinstance(module, IE_LIF):
                module.training_mode = False
# Quick test
if __name__ == '__main__':
    # Build the model
    model = QSDTransformer(
        img_size=224,
        patch_size=16,
        embed_dim=384,
        depth=6,
        num_heads=6,
        num_classes=1000,
        bit=4
    )
    # Dummy input
    batch_size = 2
    x = torch.randn(batch_size, 3, 224, 224)
    # Forward pass (training mode)
    output = model(x)
    # Switch to inference mode
    model.set_inference_mode()
    model.eval()
    with torch.no_grad():
        output_inference = model(x)
    # Print results
    print('Input shape:', x.size())
    print('Training output shape:', output.size())
    print('Inference output shape:', output_inference.size())
    print('Parameter count:', sum(p.numel() for p in model.parameters()) / 1e6, 'M')
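One note on the distillation path: the forward pass above only wires up a placeholder (it feeds the classification output into FineGrainedDistillation, which actually expects lists of (B, N, C) features). Below is one possible way to drive the fine-grained distillation with forward hooks on the spiking Q/K/V neurons of a block; the hook targets and the random stand-in "teacher" features are an illustration under those assumptions, not the authors' training recipe.
import torch

# Sketch: collect spiking Q/K/V features from the first block via forward hooks
# and feed them to FineGrainedDistillation. Assumes the classes above are in scope.
model = QSDTransformer(img_size=224, patch_size=16, embed_dim=384,
                       depth=6, num_heads=6, num_classes=1000, bit=4)
x = torch.randn(2, 3, 224, 224)

feats = {}
def save(name):
    def hook(module, inputs, output):
        # IE_LIF outputs are (B, H, W, C); flatten the spatial dims to (B, N, C)
        feats[name] = output.reshape(output.shape[0], -1, output.shape[-1])
    return hook

attn = model.blocks[0].attn
handles = [attn.q_lif.register_forward_hook(save('q')),
           attn.k_lif.register_forward_hook(save('k')),
           attn.v_lif.register_forward_hook(save('v'))]

logits = model(x)                                         # hooks capture the spiking Q/K/V
student_qkv = [feats['q'], feats['k'], feats['v']]
teacher_qkv = [torch.randn_like(t) for t in student_qkv]  # placeholder: use a real full-precision teacher in practice
distill_loss = model.distillation(student_qkv, teacher_qkv)
for h in handles:
    h.remove()
print('fine-grained distillation loss:', distill_loss.item())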
Full code (gitcode): https://gitcode.com/2301_80107842/research