Huawei's Open-Source AI Framework MindSpore in Practice: Running Vision Transformer Image Classification End to End
I've recently been digging into Vision Transformer (ViT), and the idea of applying the Transformer to image classification is genuinely clever. MindSpore happens to have a complete tutorial, so I followed it end to end and wrote up the whole process.
If you're interested in MindSpore, you can follow the MindSpore community.

The figure below shows the full ViT architecture: the input image is split into patches, processed by the Transformer encoder, and finally a classification head produces the result. The pipeline is straightforward; let's implement it step by step.

1 Environment Setup and Data Preparation
1.1 Environment Configuration
First make sure Python and MindSpore are installed locally. The tutorial is best run on a GPU; on a CPU it will be painfully slow.
The dataset is a subset of ImageNet and is downloaded automatically on the first run:
from download import download
dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = "./"
path = download(dataset_url, path, kind="zip", replace=True)
After the download finishes, the directory structure looks like this:
./dataset/
    ├── ILSVRC2012_devkit_t12.tar.gz
    ├── train/
    ├── infer/
    └── val/
1.2 Data Preprocessing
The preprocessing is fairly standard: resize, random crop, normalization, and so on:
import os
import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms
data_path = './dataset/'
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)
trans_train = [
    transforms.RandomCropDecodeResize(size=224,
                                      scale=(0.08, 1.0),
                                      ratio=(0.75, 1.333)),
    transforms.RandomHorizontalFlip(prob=0.5),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]
dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)
The mean and std here are the standard ImageNet statistics. They are multiplied by 255 because MindSpore's Normalize operates on raw pixel values in the 0-255 range rather than on floats scaled to [0, 1].
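To see why the scaling works out: normalizing a raw 0-255 pixel with the scaled statistics is equivalent to first dividing by 255 and then applying the usual ImageNet mean/std. A quick numpy check (just an illustrative sketch, not part of the tutorial):
import numpy as np

pixel = np.array([120.0, 64.0, 200.0])                     # one RGB pixel, raw 0-255 values
mean_01, std_01 = np.array([0.485, 0.456, 0.406]), np.array([0.229, 0.224, 0.225])

a = (pixel - mean_01 * 255) / (std_01 * 255)               # what Normalize computes with the scaled values
b = (pixel / 255 - mean_01) / std_01                       # the "textbook" ImageNet normalization
print(np.allclose(a, b))                                   # True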
2 How the ViT Model Works
2.1 The Core of the Transformer: Self-Attention
To understand ViT, you first need to understand the core mechanism of the Transformer: Self-Attention. In short, it lets the model learn the relationships between different positions in the input sequence.
The Self-Attention computation works as follows:
- The input vectors pass through three different linear transformations to produce Q (Query), K (Key), and V (Value)
- The dot products between Q and K give the attention scores
- These scores are used to compute a weighted sum over V
In formulas:
$$
\begin{cases}
q_i = W_q \cdot x_i \\
k_i = W_k \cdot x_i \\
v_i = W_v \cdot x_i
\end{cases}
$$
Then the attention scores are computed:
$$a_{i,j} = \frac{q_i \cdot k_j}{\sqrt{d}}$$
After Softmax normalization, the final output is:
$$output_i = \sum_j \text{softmax}(a_{i,j}) \cdot v_j$$
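The three formulas above fit into a few lines of numpy. This is only an illustrative sketch of the math (random weights, a single head, no batching), not part of the tutorial code:
import numpy as np

n, d = 4, 8                                        # sequence length, embedding dimension
x = np.random.randn(n, d)
W_q, W_k, W_v = (np.random.randn(d, d) for _ in range(3))

q, k, v = x @ W_q, x @ W_k, x @ W_v                # q_i = W_q·x_i, k_i = W_k·x_i, v_i = W_v·x_i (row-wise)
scores = q @ k.T / np.sqrt(d)                      # a_{i,j} = q_i·k_j / sqrt(d)
weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)   # softmax over j
output = weights @ v                               # output_i = sum_j softmax(a_{i,j})·v_j
print(output.shape)                                # (4, 8)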

The figure above details the Self-Attention computation: the input sequence X is linearly transformed into the Q, K, and V matrices, attention scores are computed, weights are obtained via Softmax, and a weighted sum produces the output. This mechanism lets the model dynamically attend to different parts of the input.
2.2 Implementing Multi-Head Attention
Multi-head attention splits the input into several "heads", each of which computes attention independently; the results are then concatenated. This lets the model look at the input from different perspectives:
from mindspore import nn, ops
class Attention(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = ms.Tensor(head_dim ** -0.5)
        self.qkv = nn.Dense(dim, dim * 3)
        self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)
        self.out = nn.Dense(dim, dim)
        self.out_drop = nn.Dropout(p=1.0-keep_prob)
        self.attn_matmul_v = ops.BatchMatMul()
        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)
        self.softmax = nn.Softmax(axis=-1)
    def construct(self, x):
        b, n, c = x.shape
        qkv = self.qkv(x)
        qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))
        qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = ops.unstack(qkv, axis=0)
        
        attn = self.q_matmul_k(q, k)
        attn = ops.mul(attn, self.scale)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        
        out = self.attn_matmul_v(attn, v)
        out = ops.transpose(out, (0, 2, 1, 3))
        out = ops.reshape(out, (b, n, c))
        out = self.out(out)
        out = self.out_drop(out)
        return out
The key points in this code:
- qkv = self.qkv(x) produces the Q, K, and V matrices in a single projection
- The reshape and transpose operations reorganize the data into the multi-head layout
- Finally the outputs of all heads are merged back together (see the shape-check sketch below)
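An easy way to sanity-check the shape juggling is to run the cell on a dummy input. The sizes below (batch 2, 197 tokens, dim 768, 12 heads) match ViT-B/16 but are otherwise arbitrary; this is just a sketch using the Attention class defined above:
import numpy as np
import mindspore as ms

dummy = ms.Tensor(np.ones((2, 197, 768)), ms.float32)   # (batch, num_patches + cls token, dim)
attn_block = Attention(dim=768, num_heads=12)
print(attn_block(dummy).shape)                           # expected: (2, 197, 768), same shape in and out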
2.3 Feed Forward and Residual Connections
Besides the attention mechanism, the Transformer also needs a Feed Forward network and residual connections:
from typing import Optional
class FeedForward(nn.Cell):
    def __init__(self,
                 in_features: int,
                 hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None,
                 activation: nn.Cell = nn.GELU,
                 keep_prob: float = 1.0):
        super(FeedForward, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dense1 = nn.Dense(in_features, hidden_features)
        self.activation = activation()
        self.dense2 = nn.Dense(hidden_features, out_features)
        self.dropout = nn.Dropout(p=1.0-keep_prob)
    def construct(self, x):
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.dense2(x)
        x = self.dropout(x)
        return x
class ResidualCell(nn.Cell):
    def __init__(self, cell):
        super(ResidualCell, self).__init__()
        self.cell = cell
    def construct(self, x):
        return self.cell(x) + x
The residual connection is simple: the input is added directly to the output, which helps avoid vanishing gradients in deep networks.
2.4 The Complete TransformerEncoder
Combining attention, Feed Forward, and residual connections gives the TransformerEncoder:
class TransformerEncoder(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_layers: int,
                 num_heads: int,
                 mlp_dim: int,
                 keep_prob: float = 1.,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: nn.Cell = nn.LayerNorm):
        super(TransformerEncoder, self).__init__()
        layers = []
        for _ in range(num_layers):
            normalization1 = norm((dim,))
            normalization2 = norm((dim,))
            attention = Attention(dim=dim,
                                  num_heads=num_heads,
                                  keep_prob=keep_prob,
                                  attention_keep_prob=attention_keep_prob)
            feedforward = FeedForward(in_features=dim,
                                      hidden_features=mlp_dim,
                                      activation=activation,
                                      keep_prob=keep_prob)
            layers.append(
                nn.SequentialCell([
                    ResidualCell(nn.SequentialCell([normalization1, attention])),
                    ResidualCell(nn.SequentialCell([normalization2, feedforward]))
                ])
            )
        self.layers = nn.SequentialCell(layers)
    def construct(self, x):
        return self.layers(x)
One detail worth noting: ViT places LayerNorm before the attention and Feed Forward sub-layers (pre-norm), which differs from the post-norm ordering of the original Transformer, but experiments show this works better here.
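The difference between the two orderings is easiest to see side by side. A minimal sketch (plain functions standing in for the real sub-layers; not tutorial code), where the pre-norm variant is exactly what the ResidualCell + SequentialCell composition above builds:
def post_norm_block(x, attention, feedforward, norm1, norm2):
    # original Transformer ordering: sub-layer, residual add, then LayerNorm
    x = norm1(x + attention(x))
    return norm2(x + feedforward(x))

def pre_norm_block(x, attention, feedforward, norm1, norm2):
    # ViT ordering: LayerNorm first, then the sub-layer, then the residual add
    x = x + attention(norm1(x))
    return x + feedforward(norm2(x))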
3 ViT's Key Idea: Turning an Image into a Sequence

The figure above shows the full ViT pipeline: the original image is split into patches, embedded, combined with positional encodings and the CLS token, processed by the Transformer encoder, and finally the CLS token output is used for classification.
3.1 Patch Embedding
The cleverest part of ViT is converting an image into a sequence. The image is cut into small blocks (patches), and each patch is flattened into a one-dimensional vector:
class PatchEmbedding(nn.Cell):
    MIN_NUM_PATCHES = 4
    def __init__(self,
                 image_size: int = 224,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 input_channels: int = 3):
        super(PatchEmbedding, self).__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.conv = nn.Conv2d(input_channels, embed_dim, 
                             kernel_size=patch_size, stride=patch_size, has_bias=True)
    def construct(self, x):
        x = self.conv(x)
        b, c, h, w = x.shape
        x = ops.reshape(x, (b, c, h * w))
        x = ops.transpose(x, (0, 2, 1))
        return x
A convolution with kernel size and stride equal to the patch size implements the patch split, which is more efficient than slicing by hand. For a 224×224 image with 16×16 patches, this yields 14×14 = 196 patches.
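The patch count is easy to confirm with a dummy forward pass through the PatchEmbedding cell defined above (ViT-B/16 sizes, sketch only):
import numpy as np
import mindspore as ms

patch_embed = PatchEmbedding(image_size=224, patch_size=16, embed_dim=768, input_channels=3)
img = ms.Tensor(np.ones((1, 3, 224, 224)), ms.float32)
print(patch_embed.num_patches)        # 196  (= 14 * 14)
print(patch_embed(img).shape)         # expected: (1, 196, 768), a sequence of 196 patch embeddings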
3.2 Positional Encoding and the Classification Token
After the image is split into patches, positional information and a classification token still need to be added (the init helper used here is defined in section 3.3):
# inside the ViT class's __init__
self.cls_token = init(init_type=Normal(sigma=1.0),
                      shape=(1, 1, embed_dim),
                      dtype=ms.float32,
                      name='cls',
                      requires_grad=True)
self.pos_embedding = init(init_type=Normal(sigma=1.0),
                          shape=(1, num_patches + 1, embed_dim),
                          dtype=ms.float32,
                          name='pos_embedding',
                          requires_grad=True)
The classification token borrows BERT's idea: a special token is prepended to the sequence, and its final output is used for classification. The positional encoding tells the model where each patch sits in the image.
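The bookkeeping is simple: prepending the cls token turns 196 patch embeddings into 197 tokens, and the positional embedding is added element-wise with broadcasting over the batch. A small numpy illustration (shapes only, all zeros; not tutorial code):
import numpy as np

batch, num_patches, dim = 2, 196, 768
patches = np.zeros((batch, num_patches, dim))
cls_token = np.zeros((1, 1, dim))
pos_embedding = np.zeros((1, num_patches + 1, dim))

cls_tokens = np.tile(cls_token, (batch, 1, 1))        # replicate the cls token for every sample
x = np.concatenate((cls_tokens, patches), axis=1)     # (2, 197, 768): cls token sits at position 0
x = x + pos_embedding                                 # broadcast over the batch dimension
print(x.shape)                                        # (2, 197, 768)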
3.3 The Complete ViT Model
Putting all the components together gives the complete ViT model:
from mindspore.common.initializer import Normal, initializer
from mindspore import Parameter
def init(init_type, shape, dtype, name, requires_grad):
    initial = initializer(init_type, shape, dtype).init_data()
    return Parameter(initial, name=name, requires_grad=requires_grad)
class ViT(nn.Cell):
    def __init__(self,
                 image_size: int = 224,
                 input_channels: int = 3,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 num_layers: int = 12,
                 num_heads: int = 12,
                 mlp_dim: int = 3072,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: Optional[nn.Cell] = nn.LayerNorm,
                 pool: str = 'cls',
                 num_classes: int = 1000) -> None:
        super(ViT, self).__init__()
        self.patch_embedding = PatchEmbedding(image_size=image_size,
                                              patch_size=patch_size,
                                              embed_dim=embed_dim,
                                              input_channels=input_channels)
        num_patches = self.patch_embedding.num_patches
        self.cls_token = init(init_type=Normal(sigma=1.0),
                              shape=(1, 1, embed_dim),
                              dtype=ms.float32,
                              name='cls',
                              requires_grad=True)
        self.pos_embedding = init(init_type=Normal(sigma=1.0),
                                  shape=(1, num_patches + 1, embed_dim),
                                  dtype=ms.float32,
                                  name='pos_embedding',
                                  requires_grad=True)
        self.pool = pool
        self.pos_dropout = nn.Dropout(p=1.0-keep_prob)
        self.norm = norm((embed_dim,))
        self.transformer = TransformerEncoder(dim=embed_dim,
                                              num_layers=num_layers,
                                              num_heads=num_heads,
                                              mlp_dim=mlp_dim,
                                              keep_prob=keep_prob,
                                              attention_keep_prob=attention_keep_prob,
                                              drop_path_keep_prob=drop_path_keep_prob,
                                              activation=activation,
                                              norm=norm)
        self.dropout = nn.Dropout(p=1.0-keep_prob)
        self.dense = nn.Dense(embed_dim, num_classes)
    def construct(self, x):
        x = self.patch_embedding(x)
        cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))
        x = ops.concat((cls_tokens, x), axis=1)
        x += self.pos_embedding
        x = self.pos_dropout(x)
        x = self.transformer(x)
        x = self.norm(x)
        x = x[:, 0]  # take the cls token's output
        if self.training:
            x = self.dropout(x)
        x = self.dense(x)
        return x
The whole flow is: image → patch embedding → prepend the cls token and add positional encodings → Transformer encoder → classification head.
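A quick end-to-end sanity check of that pipeline, using the ViT class defined above with its default ViT-B/16 hyperparameters (sketch only, dummy input):
import numpy as np
import mindspore as ms

net = ViT()                                              # defaults: 12 layers, 12 heads, dim 768, 1000 classes
dummy_batch = ms.Tensor(np.ones((2, 3, 224, 224)), ms.float32)
print(net(dummy_batch).shape)                            # expected: (2, 1000) class logits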
4 Training and Validation in Practice
4.1 Training Configuration
Before training, set up the loss function, optimizer, and related pieces:
from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from mindspore import train
# hyperparameters
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()
# build the model
network = ViT()
# load pretrained weights
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"
vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)
# learning-rate schedule
lr = nn.cosine_decay_lr(min_lr=float(0),
                        max_lr=0.00005,
                        total_step=epoch_size * step_size,
                        step_per_epoch=step_size,
                        decay_epoch=10)
# optimizer
network_opt = nn.Adam(network.trainable_params(), lr, momentum)
Because a pretrained model is loaded, the learning rate is kept small. The cosine decay schedule makes training more stable.
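cosine_decay_lr simply returns a Python list with one learning-rate value per step, so the schedule can be inspected before training. A tiny example with made-up step counts (not the tutorial's settings):
demo_lr = nn.cosine_decay_lr(min_lr=0.0, max_lr=0.00005,
                             total_step=20, step_per_epoch=4, decay_epoch=5)
print(len(demo_lr))              # 20 values, one per training step
print(demo_lr[0], demo_lr[-1])   # starts at max_lr and decays towards min_lr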
4.2 Loss Function
A cross-entropy loss with label smoothing is used:
class CrossEntropySmooth(LossBase):
    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = ops.OneHot()
        self.sparse = sparse
        self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
        self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
    def construct(self, logit, label):
        if self.sparse:
            label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss
network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)
Label smoothing helps prevent overfitting and improves generalization.
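Concretely, with smooth_factor=0.1 and 1000 classes the one-hot target is softened: the true class gets 0.9 and the remaining 0.1 is spread evenly over the other 999 classes, which is exactly what on_value and off_value above encode. A quick numpy illustration (class index 236 chosen arbitrarily):
import numpy as np

num_classes, smooth_factor, true_class = 1000, 0.1, 236
target = np.full(num_classes, smooth_factor / (num_classes - 1))   # off_value ≈ 0.0001001
target[true_class] = 1.0 - smooth_factor                           # on_value = 0.9
print(target[true_class], target[0], target.sum())                 # 0.9  ~0.0001  ~1.0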
4.3 Starting Training
# checkpoint settings
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)
# initialize the model
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, 
                       metrics={"acc"}, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, 
                       metrics={"acc"}, amp_level="O0")
# start training
model.train(epoch_size,
            dataset_train,
            callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],
            dataset_sink_mode=False)
During training you will see output like this:
epoch: 1 step: 125, loss is 1.903618335723877
Train epoch time: 99857.517 ms, per step time: 798.860 ms
epoch: 2 step: 125, loss is 1.448015570640564
Train epoch time: 95555.111 ms, per step time: 764.441 ms
The loss decreases steadily, so training is proceeding normally.

The figure above shows the training process: the loss curve on the left, the accuracy curve on the right, and a table summarizing the training configuration and final results below. The model converges stably and reaches decent performance.
4.4 Model Validation
After training, check the results:
# validation data preprocessing
dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)
trans_val = [
    transforms.Decode(),
    transforms.Resize(224 + 32),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]
dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)
# evaluation metrics
eval_metrics = {'Top_1_Accuracy': train.Top1CategoricalAccuracy(),
                'Top_5_Accuracy': train.Top5CategoricalAccuracy()}
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, 
                       metrics=eval_metrics, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, 
                       metrics=eval_metrics, amp_level="O0")
# run evaluation
result = model.eval(dataset_val)
print(result)
The result:
{'Top_1_Accuracy': 0.75, 'Top_5_Accuracy': 0.928}
Top-1 accuracy is 75% and Top-5 accuracy is 92.8%, which is respectable.
5 Inference
5.1 Preparing the Inference Data
dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)
trans_infer = [
    transforms.Decode(),
    transforms.Resize([224, 224]),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]
dataset_infer = dataset_infer.map(operations=trans_infer,
                                  input_columns=["image"],
                                  num_parallel_workers=1)
dataset_infer = dataset_infer.batch(1)
5.2 Inference and Visualizing the Results
import cv2
import numpy as np
from PIL import Image
from scipy import io
def index2label():
    """获取ImageNet类别标签"""
    metafile = os.path.join(data_path, "ILSVRC2012_devkit_t12/data/meta.mat")
    meta = io.loadmat(metafile, squeeze_me=True)['synsets']
    
    nums_children = list(zip(*meta))[4]
    meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]
    
    _, wnids, classes = list(zip(*meta))[:3]
    clssname = [tuple(clss.split(', ')) for clss in classes]
    wnid2class = {wnid: clss for wnid, clss in zip(wnids, clssname)}
    wind2class_name = sorted(wnid2class.items(), key=lambda x: x[0])
    
    mapping = {}
    for index, (_, class_name) in enumerate(wind2class_name):
        mapping[index] = class_name[0]
    return mapping
# run inference
for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
    image = image["image"]
    image = ms.Tensor(image)
    prob = model.predict(image)
    label = np.argmax(prob.asnumpy(), axis=1)
    mapping = index2label()
    output = {int(label): mapping[int(label)]}
    print(output)
Inference result:
{236: 'Doberman'}

The model correctly identifies a Doberman, so inference works as expected.
6 Summary and Reflections
6.1 Strengths of ViT
A few strengths of ViT stood out in this exercise:
- Clean architecture: compared with the intricate convolutional designs of CNNs, ViT's architecture is more uniform and concise
- Scalability: the Transformer's parallelism makes it easy to scale the model up
- Transferability: after pretraining on large datasets, it transfers well to downstream tasks
6.2 Pitfalls in Practice
- High compute requirements: ViT is memory-hungry on the GPU, so the batch size cannot be set too large
- Data-hungry: compared with CNNs, ViT depends more heavily on large-scale pretraining data
- Positional encoding matters: removing it clearly degrades performance
6.3 Highlights of the Implementation
The MindSpore implementation has several nice aspects:
- Modular design: each component is well encapsulated and easy to understand and modify
- Automatic mixed precision: mixed-precision training can be enabled simply via the amp_level argument
- Flexible data processing: the preprocessing pipeline is designed very flexibly
The whole run went smoothly; the code quality is good and the comments are clear. For anyone who wants to understand the principles and implementation of ViT, this tutorial is a solid starting point.
Of course, truly mastering ViT takes more paper reading and more experiments. This is just a start; next I might fine-tune on my own dataset or implement some ViT variants.