I've recently been looking into the Vision Transformer (ViT), and the idea of applying a Transformer to image classification is genuinely clever. MindSpore happens to provide a complete tutorial, so I followed it end to end and documented the whole process here.

If you are interested in MindSpore, you can follow the MindSpore open-source community.

The figure below shows the complete ViT architecture: the input image is split into patches, the patches are processed by a Transformer encoder, and a classification head produces the final prediction. The flow is straightforward, and we will implement it step by step.

[Figure: overall ViT architecture]

1 Environment Setup and Data Preparation

1.1 Environment Configuration

First, make sure Python and MindSpore are installed locally. The tutorial recommends running on a GPU; on a CPU it will be painfully slow.

The dataset is a subset of ImageNet; it is downloaded automatically on the first run:

from download import download

dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = "./"

path = download(dataset_url, path, kind="zip", replace=True)

After the download completes, the directory structure looks like this:

./dataset/
    ├── ILSVRC2012_devkit_t12.tar.gz
    ├── train/
    ├── infer/
    └── val/

1.2 Data Preprocessing

The preprocessing here is fairly standard, mainly resizing, random cropping, and normalization:

import os
import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms

data_path = './dataset/'
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)

trans_train = [
    transforms.RandomCropDecodeResize(size=224,
                                      scale=(0.08, 1.0),
                                      ratio=(0.75, 1.333)),
    transforms.RandomHorizontalFlip(prob=0.5),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)

The mean and std here are the standard ImageNet values. They are multiplied by 255 because Normalize is applied directly to raw pixel values in the 0-255 range (the images are not rescaled to [0, 1] first).

2 How the ViT Model Works

2.1 The Core of the Transformer: Self-Attention

To understand ViT, you first need to understand the core mechanism of the Transformer: Self-Attention. Put simply, it lets the model learn the relationships between different positions in the input sequence.

The Self-Attention computation works as follows:

  1. The input vectors pass through three different linear transformations to obtain Q (Query), K (Key), and V (Value)
  2. The dot products between Q and K give the attention weights
  3. These weights are used to compute a weighted sum over V

In equations:
$$\begin{cases} q_i = W_q \cdot x_i \\ k_i = W_k \cdot x_i \\ v_i = W_v \cdot x_i \end{cases}$$

The attention scores are then computed as:
$$a_{i,j} = \frac{q_i \cdot k_j}{\sqrt{d}}$$

After Softmax normalization, the final output is:
$$\text{output}_i = \sum_j \text{softmax}(a_{i,j}) \cdot v_j$$

[Figure: Self-Attention computation flow]

The figure above illustrates the Self-Attention computation in detail: the input sequence X is linearly transformed into the Q, K, and V matrices, attention scores are computed and passed through Softmax to obtain weights, and a weighted sum yields the output. This mechanism lets the model dynamically attend to different parts of the input sequence.
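
To make the computation concrete, here is a minimal NumPy sketch of the single-head self-attention described above (toy dimensions; my own illustration, not part of the tutorial code):

import numpy as np

def single_head_attention(x, w_q, w_k, w_v):
    """Single-head self-attention following the formulas above (no batching)."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v            # q_i = W_q · x_i, etc.
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)                   # a_{i,j} = q_i · k_j / sqrt(d)
    scores -= scores.max(axis=-1, keepdims=True)    # for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over j
    return weights @ v                              # output_i = sum_j softmax(a_{i,j}) · v_j

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))                         # 4 tokens, dimension 8
w_q, w_k, w_v = (rng.normal(size=(8, 8)) for _ in range(3))
print(single_head_attention(x, w_q, w_k, w_v).shape)  # (4, 8)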

2.2 Implementing Multi-Head Attention

Multi-head attention splits the input into several "heads", each of which computes attention independently; the results are then merged back together. This lets the model look at the input from several different perspectives:

from mindspore import nn, ops

class Attention(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0):
        super(Attention, self).__init__()

        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = ms.Tensor(head_dim ** -0.5)

        self.qkv = nn.Dense(dim, dim * 3)
        self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)
        self.out = nn.Dense(dim, dim)
        self.out_drop = nn.Dropout(p=1.0-keep_prob)
        self.attn_matmul_v = ops.BatchMatMul()
        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)
        self.softmax = nn.Softmax(axis=-1)

    def construct(self, x):
        b, n, c = x.shape
        qkv = self.qkv(x)
        qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))
        qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = ops.unstack(qkv, axis=0)
        
        attn = self.q_matmul_k(q, k)
        attn = ops.mul(attn, self.scale)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        
        out = self.attn_matmul_v(attn, v)
        out = ops.transpose(out, (0, 2, 1, 3))
        out = ops.reshape(out, (b, n, c))
        out = self.out(out)
        out = self.out_drop(out)

        return out

The key points in this code:

  • qkv = self.qkv(x) produces the Q, K, and V matrices with a single projection
  • the reshape and transpose operations reorganize the data into the multi-head layout
  • finally, the per-head results are merged back into a single tensor (a quick shape check follows below)
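
As a quick sanity check of the Attention cell (a sketch I added, not from the tutorial; 197 = 196 patch tokens plus 1 CLS token, as we will see later):

import numpy as np

attn_layer = Attention(dim=768, num_heads=12)
dummy = ms.Tensor(np.ones((2, 197, 768)), ms.float32)   # (batch, seq_len, dim)
print(attn_layer(dummy).shape)                           # (2, 197, 768): same shape in, same shape out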

2.3 Feed Forward and Residual Connections

Besides the attention mechanism, the Transformer also needs a Feed Forward network and residual connections:

from typing import Optional

class FeedForward(nn.Cell):
    def __init__(self,
                 in_features: int,
                 hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None,
                 activation: nn.Cell = nn.GELU,
                 keep_prob: float = 1.0):
        super(FeedForward, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dense1 = nn.Dense(in_features, hidden_features)
        self.activation = activation()
        self.dense2 = nn.Dense(hidden_features, out_features)
        self.dropout = nn.Dropout(p=1.0-keep_prob)

    def construct(self, x):
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.dense2(x)
        x = self.dropout(x)
        return x

class ResidualCell(nn.Cell):
    def __init__(self, cell):
        super(ResidualCell, self).__init__()
        self.cell = cell

    def construct(self, x):
        return self.cell(x) + x

The residual connection is simple: the input is added directly to the output, which helps avoid the vanishing-gradient problem in deep networks.
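
As a small sanity check (my own sketch, assuming the cells above are defined), a pre-norm residual block built from these pieces keeps the tensor shape unchanged; this is exactly the composition the encoder in the next subsection stacks:

import numpy as np

block = ResidualCell(nn.SequentialCell([nn.LayerNorm((768,)), FeedForward(768, 3072)]))
y = block(ms.Tensor(np.ones((2, 197, 768)), ms.float32))
print(y.shape)   # (2, 197, 768)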

2.4 The Complete TransformerEncoder

Combining the attention mechanism, the Feed Forward network, and residual connections gives us the TransformerEncoder:

class TransformerEncoder(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_layers: int,
                 num_heads: int,
                 mlp_dim: int,
                 keep_prob: float = 1.,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: nn.Cell = nn.LayerNorm):
        super(TransformerEncoder, self).__init__()
        layers = []

        for _ in range(num_layers):
            normalization1 = norm((dim,))
            normalization2 = norm((dim,))
            attention = Attention(dim=dim,
                                  num_heads=num_heads,
                                  keep_prob=keep_prob,
                                  attention_keep_prob=attention_keep_prob)

            feedforward = FeedForward(in_features=dim,
                                      hidden_features=mlp_dim,
                                      activation=activation,
                                      keep_prob=keep_prob)

            layers.append(
                nn.SequentialCell([
                    ResidualCell(nn.SequentialCell([normalization1, attention])),
                    ResidualCell(nn.SequentialCell([normalization2, feedforward]))
                ])
            )
        self.layers = nn.SequentialCell(layers)

    def construct(self, x):
        return self.layers(x)

One detail worth noting: ViT places LayerNorm before the attention and Feed Forward blocks (pre-norm), which differs from the original Transformer, but experiments show this works better.

3 ViT's Key Innovation: Turning an Image into a Sequence

[Figure: ViT image-to-sequence processing pipeline]

The figure above shows the full ViT processing pipeline: the raw image is split into patches, converted by the embedding layer, augmented with a position encoding and a CLS token, processed by the Transformer encoder, and finally the CLS token is extracted for the classification prediction.

3.1 Patch Embedding

The cleverest part of ViT is how it turns an image into a sequence. Concretely, the image is cut into small blocks (patches), and each patch is flattened into a one-dimensional vector:

class PatchEmbedding(nn.Cell):
    MIN_NUM_PATCHES = 4

    def __init__(self,
                 image_size: int = 224,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 input_channels: int = 3):
        super(PatchEmbedding, self).__init__()

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.conv = nn.Conv2d(input_channels, embed_dim, 
                             kernel_size=patch_size, stride=patch_size, has_bias=True)

    def construct(self, x):
        x = self.conv(x)
        b, c, h, w = x.shape
        x = ops.reshape(x, (b, c, h * w))
        x = ops.transpose(x, (0, 2, 1))
        return x

A convolution is used to implement the patch split, which is more efficient than slicing by hand. For a 224×224 image with 16×16 patches, this yields 14×14 = 196 patches.
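
A quick shape check (my own sketch, using the class defined above):

import numpy as np

patch_embed = PatchEmbedding(image_size=224, patch_size=16, embed_dim=768, input_channels=3)
img = ms.Tensor(np.ones((1, 3, 224, 224)), ms.float32)   # (batch, channels, H, W)
print(patch_embed.num_patches)    # 196 = (224 // 16) ** 2
print(patch_embed(img).shape)     # (1, 196, 768): 196 patch tokens, each of dimension 768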

3.2 Position Encoding and the Classification Token

After the image has been split into patches, position information and a classification token still need to be added:

# Inside ViT's __init__ (the init() helper is defined in the next subsection)
self.cls_token = init(init_type=Normal(sigma=1.0),
                      shape=(1, 1, embed_dim),
                      dtype=ms.float32,
                      name='cls',
                      requires_grad=True)

self.pos_embedding = init(init_type=Normal(sigma=1.0),
                          shape=(1, num_patches + 1, embed_dim),
                          dtype=ms.float32,
                          name='pos_embedding',
                          requires_grad=True)

The classification token borrows the idea from BERT: a special token is prepended to the sequence, and its output is used for the final classification. The position encoding tells the model where each patch sits in the image.
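
The shape bookkeeping is easy to verify with stand-in tensors (a sketch; the real parameters are created by the init() helper shown in the next subsection):

import numpy as np

patches = ms.Tensor(np.ones((1, 196, 768)), ms.float32)    # patch tokens from PatchEmbedding
cls_tok = ms.Tensor(np.zeros((1, 1, 768)), ms.float32)     # stand-in for the learnable CLS token
pos_emb = ms.Tensor(np.zeros((1, 197, 768)), ms.float32)   # stand-in for the position embedding
seq = ops.concat((cls_tok, patches), axis=1)               # prepend CLS -> (1, 197, 768)
print((seq + pos_emb).shape)                               # (1, 197, 768)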

3.3 The Complete ViT Model

Putting all the components together gives the complete ViT model:

from mindspore.common.initializer import Normal, initializer
from mindspore import Parameter

def init(init_type, shape, dtype, name, requires_grad):
    initial = initializer(init_type, shape, dtype).init_data()
    return Parameter(initial, name=name, requires_grad=requires_grad)

class ViT(nn.Cell):
    def __init__(self,
                 image_size: int = 224,
                 input_channels: int = 3,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 num_layers: int = 12,
                 num_heads: int = 12,
                 mlp_dim: int = 3072,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: Optional[nn.Cell] = nn.LayerNorm,
                 pool: str = 'cls',
                 num_classes: int = 1000) -> None:
        super(ViT, self).__init__()

        self.patch_embedding = PatchEmbedding(image_size=image_size,
                                              patch_size=patch_size,
                                              embed_dim=embed_dim,
                                              input_channels=input_channels)
        num_patches = self.patch_embedding.num_patches

        self.cls_token = init(init_type=Normal(sigma=1.0),
                              shape=(1, 1, embed_dim),
                              dtype=ms.float32,
                              name='cls',
                              requires_grad=True)

        self.pos_embedding = init(init_type=Normal(sigma=1.0),
                                  shape=(1, num_patches + 1, embed_dim),
                                  dtype=ms.float32,
                                  name='pos_embedding',
                                  requires_grad=True)

        self.pool = pool
        self.pos_dropout = nn.Dropout(p=1.0-keep_prob)
        self.norm = norm((embed_dim,))
        self.transformer = TransformerEncoder(dim=embed_dim,
                                              num_layers=num_layers,
                                              num_heads=num_heads,
                                              mlp_dim=mlp_dim,
                                              keep_prob=keep_prob,
                                              attention_keep_prob=attention_keep_prob,
                                              drop_path_keep_prob=drop_path_keep_prob,
                                              activation=activation,
                                              norm=norm)
        self.dropout = nn.Dropout(p=1.0-keep_prob)
        self.dense = nn.Dense(embed_dim, num_classes)

    def construct(self, x):
        x = self.patch_embedding(x)
        cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))
        x = ops.concat((cls_tokens, x), axis=1)
        x += self.pos_embedding

        x = self.pos_dropout(x)
        x = self.transformer(x)
        x = self.norm(x)
        x = x[:, 0]  # take the output of the classification (CLS) token
        if self.training:
            x = self.dropout(x)
        x = self.dense(x)

        return x

The whole pipeline is: image → patch embedding → add the CLS token and position encoding → Transformer encoder → classification head.
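
An end-to-end shape check (my own sketch, using the classes defined above):

import numpy as np

net = ViT(num_classes=1000)
img = ms.Tensor(np.ones((1, 3, 224, 224)), ms.float32)
logits = net(img)
print(logits.shape)   # (1, 1000): one logit per ImageNet class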

4 Training and Validation in Practice

4.1 Training Configuration

Before training, we need to set up the loss function, optimizer, and other components:

from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from mindspore import train

# Hyperparameters
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()

# Build the model
network = ViT()

# Load pretrained weights
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"
vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)

# Learning rate schedule
lr = nn.cosine_decay_lr(min_lr=float(0),
                        max_lr=0.00005,
                        total_step=epoch_size * step_size,
                        step_per_epoch=step_size,
                        decay_epoch=10)

# Optimizer
network_opt = nn.Adam(network.trainable_params(), lr, momentum)

Since we start from a pretrained model, the learning rate is kept small. The cosine annealing schedule also makes training more stable.
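
If you want to see what the schedule looks like, cosine_decay_lr returns a plain Python list with one value per training step, so it can be inspected directly (the exact numbers depend on your step_size):

print(len(lr))        # epoch_size * step_size entries
print(lr[0], lr[-1])  # starts at max_lr and decays toward min_lr over the run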

4.2 Loss Function

The loss is cross entropy with label smoothing:

class CrossEntropySmooth(LossBase):
    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = ops.OneHot()
        self.sparse = sparse
        self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
        self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        if self.sparse:
            label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss

network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)

Label smoothing helps prevent overfitting and improves generalization.
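
As a worked example of the settings used above (smooth_factor=0.1, num_classes=1000), the smoothed one-hot target looks like this (fresh variable names to avoid clobbering the ones defined earlier):

sf, nc = 0.1, 1000
print(1.0 - sf)        # on_value: 0.9 for the true class
print(sf / (nc - 1))   # off_value: ~0.0001 for each of the other 999 classes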

4.3 Start Training

# Checkpoint configuration
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)

# Initialize the model
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, 
                       metrics={"acc"}, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, 
                       metrics={"acc"}, amp_level="O0")

# Start training
model.train(epoch_size,
            dataset_train,
            callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],
            dataset_sink_mode=False)

During training you will see output like this:

epoch: 1 step: 125, loss is 1.903618335723877
Train epoch time: 99857.517 ms, per step time: 798.860 ms
epoch: 2 step: 125, loss is 1.448015570640564
Train epoch time: 95555.111 ms, per step time: 764.441 ms

The loss is steadily decreasing, which indicates that training is proceeding normally.

[Figure: training loss and accuracy curves]

The figure above shows the training process: the loss curve on the left, the accuracy curve on the right, and a table below summarizing the training configuration and final results. The model converges stably and reaches solid performance.

4.4 Model Validation

After training, let's evaluate on the validation set:

# Validation data preprocessing
dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)

trans_val = [
    transforms.Decode(),
    transforms.Resize(224 + 32),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)

# Evaluation metrics
eval_metrics = {'Top_1_Accuracy': train.Top1CategoricalAccuracy(),
                'Top_5_Accuracy': train.Top5CategoricalAccuracy()}

if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, 
                       metrics=eval_metrics, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, 
                       metrics=eval_metrics, amp_level="O0")

# Run evaluation
result = model.eval(dataset_val)
print(result)

The result:

{'Top_1_Accuracy': 0.75, 'Top_5_Accuracy': 0.928}

Top-1 accuracy is 75% and Top-5 accuracy is 92.8%, which is a decent result.

5 Inference

5.1 Inference Data Preparation

dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)

trans_infer = [
    transforms.Decode(),
    transforms.Resize([224, 224]),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_infer = dataset_infer.map(operations=trans_infer,
                                  input_columns=["image"],
                                  num_parallel_workers=1)
dataset_infer = dataset_infer.batch(1)

5.2 Inference and Result Visualization

import cv2
import numpy as np
from PIL import Image
from scipy import io

def index2label():
    """获取ImageNet类别标签"""
    metafile = os.path.join(data_path, "ILSVRC2012_devkit_t12/data/meta.mat")
    meta = io.loadmat(metafile, squeeze_me=True)['synsets']
    
    nums_children = list(zip(*meta))[4]
    meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]
    
    _, wnids, classes = list(zip(*meta))[:3]
    clssname = [tuple(clss.split(', ')) for clss in classes]
    wnid2class = {wnid: clss for wnid, clss in zip(wnids, clssname)}
    wind2class_name = sorted(wnid2class.items(), key=lambda x: x[0])
    
    mapping = {}
    for index, (_, class_name) in enumerate(wind2class_name):
        mapping[index] = class_name[0]
    return mapping

# Run inference
for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
    image = image["image"]
    image = ms.Tensor(image)
    prob = model.predict(image)
    label = np.argmax(prob.asnumpy(), axis=1)
    mapping = index2label()
    output = {int(label): mapping[int(label)]}
    print(output)

Inference result:

{236: 'Doberman'}

[Figure: inference sample image]

The model correctly identified the Doberman, so inference is working well.

6 Summary and Reflections

6.1 Strengths of ViT

A few strengths of ViT stood out during this hands-on run:

  1. Clean architecture: compared with the elaborate convolutional designs of CNNs, ViT's architecture is more uniform and simpler
  2. Strong scalability: the Transformer's parallel computation makes it easy to scale the model to larger sizes
  3. Good transferability: after pretraining on a large dataset, the model transfers well to downstream tasks

6.2 Pitfalls in Practice

  1. High compute requirements: ViT needs a lot of GPU memory, so the batch size cannot be set too large
  2. Data hungry: compared with CNNs, ViT relies much more on large-scale pretraining data
  3. Position encoding matters: removing the position encoding causes a clear drop in performance

6.3 Highlights of the Implementation

The MindSpore implementation has several nice touches:

  1. Modular design: each component is well encapsulated, making it easy to understand and modify
  2. Automatic mixed precision: mixed-precision training can be enabled simply via the amp_level argument
  3. Flexible data processing: the preprocessing pipeline is easy to customize

Overall, the walkthrough went smoothly: the code quality is good and the comments are reasonably clear. For anyone who wants to understand how ViT works and how to implement it, this tutorial is a solid starting point.

Of course, truly mastering ViT still requires reading more papers and running more experiments. This is only a start; next steps could include fine-tuning on your own dataset or implementing some ViT variants.
