QSD-Transformer Module

Paper: QUANTIZED SPIKE-DRIVEN TRANSFORMER
Paper link: https://arxiv.org/pdf/2501.13492
Venue: ICLR
The full code is provided at the end of this article.


1. Purpose

QSD-Transformer is a quantized spike-driven Transformer module that targets the deployment of spiking neural networks (SNNs) on resource-constrained devices. By quantizing 32-bit weights down to low bit-widths (2-4 bits), it achieves substantial reductions in energy consumption and model size while maintaining high accuracy. It addresses the large parameter counts and high computational complexity of existing SNN Transformers; for example, Spikformer v2 requires 173M parameters and 1384MB of memory. Through a bi-level optimization strategy consisting of the information-enhanced LIF neuron (IE-LIF) and a fine-grained distillation scheme (FGD), QSD-Transformer effectively mitigates the spike information distortion (SID) problem introduced by quantization. Experiments show that on ImageNet, QSD-Transformer reaches 80.3% top-1 accuracy while delivering a 6.0× reduction in power consumption and an 8.1× reduction in model size, offering a new solution for efficient neuromorphic computing on edge devices.
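For intuition on the compression figures above: storing a weight in 4 bits instead of 32 bits by itself shrinks weight storage by 32/4 = 8×, which is on the order of the reported 8.1× model-size reduction; the power savings come on top of this, because binary spike activity replaces dense floating-point multiply-accumulate operations.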

Figure 1. Overall architecture of the QSD-Transformer.

2. Core Innovations

1. Quantized Spike-Driven Self-Attention (Q-SDSA)

This is the core technique of QSD-Transformer. In short, the weights that would normally be stored in 32-bit floating point are compressed down to 2-4 bits, much like compressing a high-resolution photo into a small file. A conventional attention mechanism is computationally expensive; Q-SDSA exploits the event-driven nature of spiking neural networks and replaces costly floating-point computation with binary 0/1 spike operations, greatly reducing computational complexity. However, this compression introduces information loss, just as compressing an image degrades its quality.
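To make the compression concrete, here is a minimal sketch (not the paper's exact quantizer; an LSQ-style learned-step version appears in the full code later in this article) that maps full-precision weights onto a signed 4-bit grid of 16 integer levels:

import torch

# Minimal sketch of low-bit weight quantization (illustrative only).
w = torch.randn(8)                              # full-precision (32-bit) weights
s = w.abs().max() / 7                           # step size for a signed 4-bit range [-8, 7]
w_q = torch.clamp(torch.round(w / s), -8, 7)    # 4-bit integer codes (16 levels)
w_hat = w_q * s                                 # de-quantized weights used in compute
print(w, w_hat, sep='\n')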

2. Information-Enhanced LIF Neuron (IE-LIF)

This is the key to counteracting the information loss. During training, the IE-LIF neuron uses multi-bit values to preserve rich information, like painting with a full set of colors; at inference time it switches back to simple binary 0/1 spikes, like sketching in black and white, so that both accuracy and efficiency are obtained. It also includes a rectification mechanism that adjusts the information distribution, ensuring the quantized model does not lose too much important information.
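A minimal sketch of this dual behavior (the full IE_LIF class appears in the code section below): during training the membrane potential is discretized onto the levels {0, 1, 2, 3, 4}, while at inference the neuron emits plain 0/1 spikes:

import torch

# Minimal sketch of the IE-LIF idea (see the IE_LIF class below).
u = torch.randn(6) * 3                                   # membrane potentials
train_spikes = torch.floor(torch.clamp(u, 0, 4) + 0.5)   # multi-bit spikes in {0, 1, 2, 3, 4}
infer_spikes = (u > 0).float()                           # binary 0/1 spikes at inference
print(train_spikes, infer_spikes, sep='\n')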

3. Fine-Grained Distillation Scheme (FGD)

This is a "teacher-student" training method. A powerful but heavyweight conventional network acts as the teacher, while the lightweight QSD-Transformer acts as the student. Through fine-grained knowledge transfer, the student learns the teacher's core representational behavior but realizes it in a far simpler and more efficient way. This ensures that the compressed model still performs well.
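A minimal sketch of the distillation signal (the FineGrainedDistillation class in the code below implements the same idea): instead of matching raw features, the student is trained to match the teacher's pairwise similarity structure:

import torch
import torch.nn.functional as F

# Minimal sketch of the fine-grained distillation signal.
student = F.normalize(torch.randn(2, 16, 64), dim=-1)    # (B, N, C) student features
teacher = F.normalize(torch.randn(2, 16, 64), dim=-1)    # (B, N, C) teacher features
loss = F.mse_loss(student @ student.transpose(1, 2),     # (B, N, N) similarity matrices
                  teacher @ teacher.transpose(1, 2))
print(loss.item())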

3. Code

As a lightweight and efficient spiking neural network module, QSD-Transformer has demonstrated strong performance across several computer vision tasks (a reference implementation follows the list below):

1. Image classification
On ImageNet it reaches 80.3% top-1 accuracy with a model size of only 6.8M parameters and a 6× reduction in power consumption, making it well suited for real-time image recognition on mobile devices such as phones and tablets.

2. Object detection
In object detection experiments on COCO, QSD-Transformer outperforms existing spiking neural network methods by 5.8%, accurately recognizing and localizing multiple objects in an image, which suits scenarios such as autonomous driving and security surveillance.

3. Semantic segmentation
It performs well on the ADE20K segmentation benchmark, producing precise pixel-level segmentation, and can be applied to fields that require fine-grained segmentation such as medical image analysis and satellite imagery processing.

4. Transfer learning
It achieves the best performance on CIFAR-10/100 and on neuromorphic datasets, showing strong generality: the module can be adapted quickly to new task domains, reducing the cost of training from scratch.

5. Edge devices
Thanks to its very low power consumption and storage requirements, QSD-Transformer is particularly suitable for deployment on resource-constrained edge devices such as IoT sensors, drones, and robots, enabling on-device intelligent processing.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.autograd import Function
from timm.models.layers import DropPath, trunc_normal_

class ReLUX(nn.Module):
    """限制ReLU激活函数"""
    def __init__(self, thre=8):
        super(ReLUX, self).__init__()
        self.thre = thre

    def forward(self, input):
        return torch.clamp(input, 0, self.thre)

relu4 = ReLUX(thre=4)

class MultiSpike(Function):
    """多位脉冲函数"""
    @staticmethod
    def forward(ctx, input, lens):
        ctx.save_for_backward(input)
        ctx.lens = lens
        return torch.floor(relu4(input) + 0.5)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        temp1 = 0 < input
        temp2 = input < ctx.lens
        return grad_input * temp1.float() * temp2.float(), None

class IE_LIF(nn.Module):
    """信息增强LIF神经元"""
    def __init__(self, lens=4, spike=MultiSpike):
        super().__init__()
        self.lens = lens
        self.spike = spike
        self.training_mode = True  # multi-bit spikes during training, binary spikes at inference

    def forward(self, inputs):
        if self.training_mode:
            # Training: multi-bit (graded) spikes
            return self.spike.apply(4 * inputs, self.lens) / 4
        else:
            # Inference: binary spikes
            return (inputs > 0).float()

class MPRF(nn.Module):
    """膜电位修正函数"""
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.alpha = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        """
        Rectify the membrane potential distribution
        Args:
            x: membrane potential (B, N, C)
        Returns:
            rectified membrane potential
        """
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True) + 1e-6
        x_norm = (x - mean) / std
        return x_norm * self.gamma + self.alpha

class LSQQuantizer(nn.Module):
    """LSQ量化器"""
    def __init__(self, bit=4, all_positive=False):
        super().__init__()
        self.bit = bit
        self.all_positive = all_positive
        if all_positive:
            self.thd_neg = 0
            self.thd_pos = 2 ** bit - 1
        else:
            self.thd_neg = - 2 ** (bit - 1)
            self.thd_pos = 2 ** (bit - 1) - 1
        self.s = nn.Parameter(torch.ones(1))

    def forward(self, x):
        # Scale by the learned step size, round, and clamp to the bit-width range
        x_q = torch.clamp(torch.round(x / self.s), self.thd_neg, self.thd_pos)
        x_dq = x_q * self.s
        if self.training:
            # Straight-through estimator: the forward pass uses the quantized value,
            # gradients flow to the full-precision input
            x_dq = x + (x_dq - x).detach()
        return x_dq

class QuantizedLinear(nn.Module):
    """量化线性层"""
    def __init__(self, in_features, out_features, bias=True, bit=4):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        self.quantizer = LSQQuantizer(bit=bit)

        # Weight initialization
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x):
        w_q = self.quantizer(self.weight)
        return F.linear(x, w_q, self.bias)

class QuantizedConv2d(nn.Module):
    """量化卷积层"""
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, bias=True, bit=4):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
        self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None
        self.quantizer = LSQQuantizer(bit=bit)

        # Weight initialization
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x):
        w_q = self.quantizer(self.weight)
        return F.conv2d(x, w_q, self.bias, self.stride, self.padding)

class Q_SDSA(nn.Module):
    """量化脉冲驱动自注意力"""
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., bit=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        # Quantized Q/K/V projections (1x1 convolutions)
        self.q_conv = QuantizedConv2d(dim, dim, 1, bias=qkv_bias, bit=bit)
        self.k_conv = QuantizedConv2d(dim, dim, 1, bias=qkv_bias, bit=bit)
        self.v_conv = QuantizedConv2d(dim, dim, 1, bias=qkv_bias, bit=bit)

        # IE-LIF spiking neurons
        self.q_lif = IE_LIF()
        self.k_lif = IE_LIF()
        self.v_lif = IE_LIF()

        # Membrane potential rectification
        self.q_mprf = MPRF(dim)
        self.k_mprf = MPRF(dim)
        self.v_mprf = MPRF(dim)

        # Output projection
        self.proj = QuantizedLinear(dim, dim, bit=bit)
        self.proj_lif = IE_LIF()

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """
        Forward pass
        Args:
            x: input features (B, H, W, C)
        Returns:
            output features (B, H, W, C)
        """
        B, H, W, C = x.shape

        # Switch to channels-first layout for the convolutions
        x = x.permute(0, 3, 1, 2)  # (B, C, H, W)

        # Q/K/V projections
        q = self.q_conv(x)  # (B, C, H, W)
        k = self.k_conv(x)
        v = self.v_conv(x)

        # Membrane potential rectification (channels-last layout)
        q = q.permute(0, 2, 3, 1)  # (B, H, W, C)
        k = k.permute(0, 2, 3, 1)
        v = v.permute(0, 2, 3, 1)

        q = self.q_mprf(q)
        k = self.k_mprf(k)
        v = self.v_mprf(v)

        # IE-LIF activation (spike generation)
        q_s = self.q_lif(q)
        k_s = self.k_lif(k)
        v_s = self.v_lif(v)

        # Reshape to multi-head attention layout
        q_s = q_s.reshape(B, H*W, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k_s = k_s.reshape(B, H*W, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v_s = v_s.reshape(B, H*W, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # Spike-driven attention (no softmax is applied)
        attn = (q_s @ k_s.transpose(-2, -1)) * self.scale
        attn = self.attn_drop(attn)

        # Attention output
        x = (attn @ v_s).transpose(1, 2).reshape(B, H*W, C)

        # Output projection
        x = self.proj(x)
        x = self.proj_lif(x)
        x = self.proj_drop(x)

        # Reshape back to (B, H, W, C)
        x = x.reshape(B, H, W, C)

        return x

class QSDTransformerBlock(nn.Module):
    """QSD-Transformer块"""
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0.,
                 attn_drop=0., drop_path=0., bit=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Q_SDSA(dim, num_heads=num_heads, qkv_bias=qkv_bias,
                          attn_drop=attn_drop, proj_drop=drop, bit=bit)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)

        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            QuantizedLinear(dim, mlp_hidden_dim, bit=bit),
            IE_LIF(),
            nn.Dropout(drop),
            QuantizedLinear(mlp_hidden_dim, dim, bit=bit),
            IE_LIF(),
            nn.Dropout(drop)
        )

    def forward(self, x):
        """前向传播"""
        # Attention branch
        x = x + self.drop_path(self.attn(self.norm1(x)))

        # MLP branch
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

class FineGrainedDistillation(nn.Module):
    """细粒度蒸馏损失"""
    def __init__(self):
        super().__init__()

    def compute_similarity_matrix(self, x):
        """计算相似性矩阵"""
        # x: (B, N, C)
        x_norm = F.normalize(x, p=2, dim=-1)
        sim = torch.bmm(x_norm, x_norm.transpose(1, 2))
        return sim

    def forward(self, student_qkv, teacher_qkv):
        """
        Compute the fine-grained distillation loss
        Args:
            student_qkv: student Q/K/V features [(B,N,C), (B,N,C), (B,N,C)]
            teacher_qkv: teacher Q/K/V features [(B,N,C), (B,N,C), (B,N,C)]
        Returns:
            distillation loss
        """
        loss = 0.0
        for s_feat, t_feat in zip(student_qkv, teacher_qkv):
            s_sim = self.compute_similarity_matrix(s_feat)
            t_sim = self.compute_similarity_matrix(t_feat)
            loss += F.mse_loss(s_sim, t_sim)

        return loss / len(student_qkv)

class QSDTransformer(nn.Module):
    """QSD-Transformer主模型"""
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., bit=4):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.patch_size = patch_size

        # Patch embedding
        self.patch_embed = QuantizedConv2d(in_chans, embed_dim, patch_size, patch_size, bit=bit)
        num_patches = (img_size // patch_size) ** 2

        # Positional embedding
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Transformer blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            QSDTransformerBlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], bit=bit
            ) for i in range(depth)
        ])

        # Classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = QuantizedLinear(embed_dim, num_classes, bit=bit) if num_classes > 0 else nn.Identity()

        # Fine-grained distillation loss
        self.distillation = FineGrainedDistillation()

        # Weight initialization
        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (QuantizedLinear, QuantizedConv2d)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        """Feature extraction"""
        # Patch embedding
        x = self.patch_embed(x)  # (B, C, H, W)
        B, C, H, W = x.shape
        x = x.reshape(B, C, H*W).transpose(1, 2)  # (B, N, C)

        # Positional embedding
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Reshape to the (B, H, W, C) block layout
        x = x.reshape(B, H, W, C)

        # Transformer blocks
        for block in self.blocks:
            x = block(x)

        # Normalized token features and global average pooling
        x = x.reshape(B, H*W, C)
        x = self.norm(x)
        tokens = x           # (B, N, C) token features, reused for distillation
        x = x.mean(dim=1)    # global average pooling

        return x, tokens

    def forward(self, x, teacher_features=None):
        """Forward pass"""
        x, tokens = self.forward_features(x)
        x = self.head(x)

        if teacher_features is not None and self.training:
            # Simplified feature-level distillation: `teacher_features` is expected to be a
            # (B, N, C) token-feature tensor from a full-precision teacher. In the paper's
            # FGD scheme the Q/K/V activations of each block are distilled; collecting those
            # would require forward hooks and is omitted in this simplified version.
            distill_loss = self.distillation([tokens], [teacher_features])
            return x, distill_loss

        return x

    def set_inference_mode(self):
        """设置推理模式"""
        for module in self.modules():
            if isinstance(module, IE_LIF):
                module.training_mode = False

# Test code
if __name__ == '__main__':
    # Build the model
    model = QSDTransformer(
        img_size=224,
        patch_size=16,
        embed_dim=384,
        depth=6,
        num_heads=6,
        num_classes=1000,
        bit=4
    )

    # Create a dummy input batch
    batch_size = 2
    x = torch.randn(batch_size, 3, 224, 224)

    # Forward pass (training mode, multi-bit spikes)
    output = model(x)

    # Switch to inference mode (binary spikes)
    model.set_inference_mode()
    model.eval()
    with torch.no_grad():
        output_inference = model(x)

    # Print results
    print('Input size:', x.size())
    print('Training output size:', output.size())
    print('Inference output size:', output_inference.size())
    print('Number of parameters:', sum(p.numel() for p in model.parameters()) / 1e6, 'M')
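
The distillation path in the listing above is a simplified, feature-level stand-in for the paper's full FGD scheme. The sketch below shows how it could be invoked during training, assuming the classes and imports defined above; `teacher_tokens` is a hypothetical placeholder for (B, N, C) token features produced by a full-precision teacher network, and the 0.1 distillation weight is an arbitrary example value.

# Hypothetical sketch of a training step with the simplified distillation path.
model = QSDTransformer(img_size=224, patch_size=16, embed_dim=384, depth=6,
                       num_heads=6, num_classes=1000, bit=4)
model.train()
x = torch.randn(2, 3, 224, 224)
labels = torch.randint(0, 1000, (2,))
teacher_tokens = torch.randn(2, (224 // 16) ** 2, 384)   # stand-in for teacher token features
logits, distill_loss = model(x, teacher_features=teacher_tokens)
loss = F.cross_entropy(logits, labels) + 0.1 * distill_loss
loss.backward()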

Full code (GitCode): https://gitcode.com/2301_80107842/research
