前言

本文学习自Unist的论文实现,源码地址:https://github.com/tsinghua-fib-lab/UniST/tree/main,本内容为个人理解整理。


一、提示生成

1.Prompt_ST

代码如下:

class Prompt_ST(nn.Module):
    """
    Prompt ST with spatial prompt and temporal prompt
    spatial prompt: multiscale convolutional kernels
    temporal prompt: closeness and period
    """
    def __init__(self, num_memory_spatial, num_memory_temporal, memory_dim, in_channels, conv_num, args=None):
        super().__init__()

        self.spatial_prompt = Sptial_prompt(num_memory_spatial, memory_dim, in_channels, conv_num, args=args)
        self.temporal_prompt = Temporal_prompt(num_memory_temporal, memory_dim, args=args)```

在原函数中的调用代码如下:

self.st_prompt = Prompt_ST(args.num_memory_spatial, args.num_memory_temporal, embed_dim, self.args.his_len, args.conv_num, args=args)

其中num_memory_spatial为默认 512;num_memory_temporal默认为 512;args.his_len默认 6;
conv_num默认 3。

2.Memory

模型的初始化参数:
n u m m e m o r y num_{memory} nummemory:memory 槽位数,记作 M。决定有多少个“记忆原型”;
m e m o r y d i m memory_{dim} memorydim:每个槽位向量维度,记作 C。和输入特征最后一维一致,在Unist中就是C;
a r g s args args:配置对象(可选),这里主要是保存到 self.args,本类内部基本没直接用。

模型的前向参数:
x x x:查询特征,形状 [N, C],其中 N 是样本数,C 等于数据特征维度;
T y p e Type Type:字符串标签参数(比如 ‘closeness’/‘period’),当前实现里不参与计算;
s h a p e shape shape:辅助传参,当前实现里也不参与计算。

这个 Memory 类本质上是一个可学习的“键值记忆库”,给一个输入特征 x(query),它会去一堆可训练的“记忆槽”里按相似度取东西,再把取到的结果返回。

里面有两个关键变量 keyMatrix ,形状为 [M, C],记录每个记忆槽的 key(用于匹配相似度); memMatrix,形状为 [M, C],记录每个记忆槽的 value(真正被取出来加权求和的内容)。

参数与结构定义

首先是初始化部分,对一下变量进行定义:
num_memory(M): 记忆槽数量;
memory_dim( C ): 每个槽的维度;
memMatrix: [M, C],值矩阵(Value);
keyMatrix: [M, C],键矩阵(Key);
x_proj: 线性映射;
顺便初始化权重self.initialize_weights()。

class Memory(nn.Module):
def __init__(self, num_memory, memory_dim, args=None):
        super().__init__()

        self.args = args

        self.num_memory = num_memory
        self.memory_dim = memory_dim

        self.memMatrix = nn.Parameter(torch.zeros(num_memory, memory_dim))  # M,C
        self.keyMatrix = nn.Parameter(torch.zeros(num_memory, memory_dim))  # M,C

        self.x_proj = nn.Linear(memory_dim, memory_dim)
        
        self.initialize_weights()

        print("model initialized memory")

初始化权重initialize_weights()

截断正态分布初始化 value memory [memMatrix] 和 key memory [keyMatrix];

截断正态分布:先按正态分布采样,再把超出某个区间(通常是均值两侧若干个标准差,这段代码里 std=0.02)的值丢掉或重采样,只保留中间范围。在 PyTorch 通过 trunc_normal_ 调用,常见是把极端值截掉。从而避免初始化时出现离群大值,让参数尺度更稳定、训练早期更平滑。

然后用self.apply()递归遍历子模块,调用 _init_weights,把线性层按规则初始化(线性层 Xavier、bias=0;LN weight=1、bias=0)。

self.apply(fn) 是 nn.Module 的递归遍历工具。它会把当前模块以及所有子模块都遍历一遍,并对每个模块调用一次 fn(module)。常见用于统一初始化、批量替换属性等

    def initialize_weights(self):
        torch.nn.init.trunc_normal_(self.memMatrix, std=0.02)
        torch.nn.init.trunc_normal_(self.keyMatrix, std=0.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

按层初始化参数函数_init_weights()(用于函数initialize_weights中)

首先检查通过isinstance(m, nn.Linear)检查模块 m 是不是全连接层(Linear)。接着用torch.nn.init.xavier_uniform_(m.weight) 给线性层权重做 Xavier 均匀初始化,让前后层方差更平衡,训练更稳定。如果m有偏置项,则设置为0。

如果m是LayerNorm 层,则把 LayerNorm 的偏置初始化为 0,缩放参数初始化为 1。

torch.nn.init.xavier_uniform_(m.weight) 是在把某个层的权重 m.weight 用 Xavier(也叫 Glorot)均匀分布初始化。从而在前向传播/反向传播时,各层信号的方差不要越传越爆或越传越小,训练更稳、收敛更快(尤其是比较深的网络)。
它具体怎么“随机”?它会把权重填成一个均匀分布: W ∼ U ( − a , a ) W∼U(−a,a) WU(a,a),其中 a = g a i n ⋅ 6 f a n _ i n + f a n _ o u t a=gain·\sqrt{\frac{6}{fan\_in+fan\_out}} a=gainfan_in+fan_out6 ,fan_in为输入通道/输入特征数(每个神经元接收多少输入)、fan_out为输出通道/输出特征数(每个神经元输出到多少单元)、gain是按激活函数做的缩放系数(默认是 1.0),PyTorch 会根据 m.weight 的形状自动算 fan_in / fan_out(线性层、卷积层都支持)。
常见于 Linear(全连接)、Conv(卷积),加上 tanh / sigmoid 这类激活函数。如果是 ReLU/LeakyReLU,很多时候更常用 Kaiming/He 初始化(kaiming_uniform_),因为它是按 ReLU 的统计特性推出来的,更匹配。

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

前向检索forward()

这段检索本质是一次 Key-Value Memory Attention,把输入特征映射到一组可学习原型上,再从原型库里读出一个增强表示。先把 query 变“适合匹配”的形状,然后计算 query 和所有 keys 的相似度,softmax 变成注意力权重,最后按权重把 memory values 加权求和取出来。公式可写成: o u t = s o f t m a x ( t a n h ( W x ) K T ) V out = softmax(tanh(Wx)K^T)V out=softmax(tanh(Wx)KT)V。具体实现如下:

    def forward(self,x,Type='',shape=None):

        # dot product
        assert x.shape[-1]==self.memMatrix.shape[-1]==self.keyMatrix.shape[-1], "dimension mismatch"
        x_query = torch.tanh(self.x_proj(x))
        att_weight = F.linear(input=x_query, weight=self.keyMatrix)  # [N,C] by [M,C]^T --> [N,M]
        att_weight = F.softmax(att_weight, dim=-1)  # NxM
        out = F.linear(att_weight, self.memMatrix.permute(1, 0))  # [N,M] by [M,C]  --> [N,C]
        loss_top = 0.0

        return dict(out=out, att_weight=att_weight, loss=loss_top)

3.Temporal_prompt

模型的初始化参数:
n u m m e m o r y num_{memory} nummemory:temporal memory 的槽位数,在UniST中为512;
m e m o r y d i m memory_{dim} memorydim:时间 prompt 的特征维度;Transformer 和 memory 都用这个维度,在UniST中,用embed_dim进行实例化;
a r g s args args:整个配置对象。

模型的前向参数:
x c x_c xc:代表近邻时间序列特征,在UniST中,用x_closeness进行实例化,形状为 [ B ∗ H ∗ W , T , C ] [B*H*W, T, C] [BHW,T,C]
x p x_p xp:代表周期序列特征,在UniST中,用x_period进行实例化,形状为 [ B ∗ T _ p r e d ∗ H ∗ W , P , C ] [B*T\_pred*H*W, P, C] [BT_predHW,P,C]

该类的作用是把两类时间信息 x_c(closeness)和 x_p(period)各自编码成时间提示向量,再用 memory 做检索增强,输出 hc/hp 给后续 decoder prompt 注入。具体实现如下:

初始化部分,通过Memory类提供一个可学习记忆库,把输入时间特征映射到“原型模式”上,增强泛化;定义单层结构的时间序列编码器encoder_layer = nn.TransformerEncoderLayer(…);用于分别处理 closeness 与 period的两条独立分支变量self.c_encoder 和 self.p_encoder,还有self.args。

class Temporal_prompt(nn.Module):
    """ closeness and period
    """
    def __init__(self, num_memory, memory_dim, args=None):
        super().__init__()
        self.temporal_memory = Memory(num_memory, memory_dim, args=args)
        encdoer_layer = nn.TransformerEncoderLayer(d_model=memory_dim, nhead=4, dim_feedforward=memory_dim,batch_first = True)
        self.c_encoder = nn.TransformerEncoder(encoder_layer=encdoer_layer, num_layers=1)
        self.p_encoder = nn.TransformerEncoder(encoder_layer=encdoer_layer, num_layers=1)
        self.memory_dim = memory_dim
        self.args = args
        
    def forward(self,x_c, x_p):
    ...

对于forward前向部分,首先对x_c和x_p做 Transformer 编码,再在时间维做平均池化。形状由[N, T, C] 变成 [N, C]。
池化的作用是什么?

时间维池化的例子:假设 x_c 形状是 [N,T,C]=[2,3,4],其中一个样本的 3 个时间步是:
t1: [1, 2, 3, 4]
t2: [2, 4, 6, 8]
t3: [3, 6, 9, 12]
mean(dim=1) 是沿时间维取平均:[(1+2+3)/3, (2+4+6)/3, (3+6+9)/3, (4+8+12)/3]= [2, 4, 6, 8]
变换前为[2,3,4],变换后为[2,4],T=3 这个维度就被汇聚掉了,留下每个样本一个时间摘要向量。

        # closeness
        hc = self.c_encoder(x_c).mean(dim=1)
        shape_c = hc.shape

        # period
        hp = self.p_encoder(x_p).mean(dim=1)
        shape_p = hp.shape

接着把刚得到的 closeness 向量 hc、period 向量 hp 送进同一个 memory 模块,做一次基于记忆槽的检索、重构,得到增强后的表示hc_output和hp_output。

最后从 memory 输出里取增强后的向量out、和memory 分支的辅助损失loss,并组成字典返回。返回的向量out(hc和hp)形状为hc 形状是 [ B ∗ H ∗ W , C ] [B*H*W, C] [BHW,C]、hp 形状是 [ B ∗ T _ p r e d ∗ H ∗ W , C ] [B*T\_pred*H*W, C] [BT_predHW,C]
这个loss在memory里被定义为0,自始至终都是0,另外loss这个变量也没有用到

从memory可知,记忆的计算是通过自注意力机制计算相似度,然后用这个相似度去加权计算得到。hc和hp就是和所有 memory 槽比较得到的检索向量。

        hc_output = self.temporal_memory(hc,Type='closeness',shape=shape_c)
        hp_output = self.temporal_memory(hp,Type='period',shape=shape_p)
        
        hc, loss_c = hc_output['out'], hc_output['loss']
        hp, loss_p = hp_output['out'], hp_output['loss']

        return dict(hc=hc, hp=hp, loss = loss_c+loss_p)

4.Sptial_prompt

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐