代码学习 - 提示网络生成
本文学习自Unist的论文实现,源码地址:https://github.com/tsinghua-fib-lab/UniST/tree/main,本内容为个人理解整理。首先是初始化部分,对一下变量进行定义:num_memory(M): 记忆槽数量;memory_dim( C ): 每个槽的维度;memMatrix: [M, C],值矩阵(Value);keyMatrix: [M, C],键矩阵(K
文章目录
前言
本文学习自Unist的论文实现,源码地址:https://github.com/tsinghua-fib-lab/UniST/tree/main,本内容为个人理解整理。
一、提示生成
1.Prompt_ST
代码如下:
class Prompt_ST(nn.Module):
"""
Prompt ST with spatial prompt and temporal prompt
spatial prompt: multiscale convolutional kernels
temporal prompt: closeness and period
"""
def __init__(self, num_memory_spatial, num_memory_temporal, memory_dim, in_channels, conv_num, args=None):
super().__init__()
self.spatial_prompt = Sptial_prompt(num_memory_spatial, memory_dim, in_channels, conv_num, args=args)
self.temporal_prompt = Temporal_prompt(num_memory_temporal, memory_dim, args=args)```
在原函数中的调用代码如下:
self.st_prompt = Prompt_ST(args.num_memory_spatial, args.num_memory_temporal, embed_dim, self.args.his_len, args.conv_num, args=args)
其中num_memory_spatial为默认 512;num_memory_temporal默认为 512;args.his_len默认 6;
conv_num默认 3。
2.Memory
模型的初始化参数:
n u m m e m o r y num_{memory} nummemory:memory 槽位数,记作 M。决定有多少个“记忆原型”;
m e m o r y d i m memory_{dim} memorydim:每个槽位向量维度,记作 C。和输入特征最后一维一致,在Unist中就是C;
a r g s args args:配置对象(可选),这里主要是保存到 self.args,本类内部基本没直接用。
模型的前向参数:
x x x:查询特征,形状 [N, C],其中 N 是样本数,C 等于数据特征维度;
T y p e Type Type:字符串标签参数(比如 ‘closeness’/‘period’),当前实现里不参与计算;
s h a p e shape shape:辅助传参,当前实现里也不参与计算。
这个 Memory 类本质上是一个可学习的“键值记忆库”,给一个输入特征 x(query),它会去一堆可训练的“记忆槽”里按相似度取东西,再把取到的结果返回。
里面有两个关键变量 keyMatrix ,形状为 [M, C],记录每个记忆槽的 key(用于匹配相似度); memMatrix,形状为 [M, C],记录每个记忆槽的 value(真正被取出来加权求和的内容)。
参数与结构定义
首先是初始化部分,对一下变量进行定义:
num_memory(M): 记忆槽数量;
memory_dim( C ): 每个槽的维度;
memMatrix: [M, C],值矩阵(Value);
keyMatrix: [M, C],键矩阵(Key);
x_proj: 线性映射;
顺便初始化权重self.initialize_weights()。
class Memory(nn.Module):
def __init__(self, num_memory, memory_dim, args=None):
super().__init__()
self.args = args
self.num_memory = num_memory
self.memory_dim = memory_dim
self.memMatrix = nn.Parameter(torch.zeros(num_memory, memory_dim)) # M,C
self.keyMatrix = nn.Parameter(torch.zeros(num_memory, memory_dim)) # M,C
self.x_proj = nn.Linear(memory_dim, memory_dim)
self.initialize_weights()
print("model initialized memory")
初始化权重initialize_weights()
用截断正态分布初始化 value memory [memMatrix] 和 key memory [keyMatrix];
截断正态分布:先按正态分布采样,再把超出某个区间(通常是均值两侧若干个标准差,这段代码里 std=0.02)的值丢掉或重采样,只保留中间范围。在 PyTorch 通过 trunc_normal_ 调用,常见是把极端值截掉。从而避免初始化时出现离群大值,让参数尺度更稳定、训练早期更平滑。
然后用self.apply()递归遍历子模块,调用 _init_weights,把线性层按规则初始化(线性层 Xavier、bias=0;LN weight=1、bias=0)。
self.apply(fn) 是 nn.Module 的递归遍历工具。它会把当前模块以及所有子模块都遍历一遍,并对每个模块调用一次 fn(module)。常见用于统一初始化、批量替换属性等
def initialize_weights(self):
torch.nn.init.trunc_normal_(self.memMatrix, std=0.02)
torch.nn.init.trunc_normal_(self.keyMatrix, std=0.02)
# initialize nn.Linear and nn.LayerNorm
self.apply(self._init_weights)
按层初始化参数函数_init_weights()(用于函数initialize_weights中)
首先检查通过isinstance(m, nn.Linear)检查模块 m 是不是全连接层(Linear)。接着用torch.nn.init.xavier_uniform_(m.weight) 给线性层权重做 Xavier 均匀初始化,让前后层方差更平衡,训练更稳定。如果m有偏置项,则设置为0。
如果m是LayerNorm 层,则把 LayerNorm 的偏置初始化为 0,缩放参数初始化为 1。
torch.nn.init.xavier_uniform_(m.weight) 是在把某个层的权重 m.weight 用 Xavier(也叫 Glorot)均匀分布初始化。从而在前向传播/反向传播时,各层信号的方差不要越传越爆或越传越小,训练更稳、收敛更快(尤其是比较深的网络)。
它具体怎么“随机”?它会把权重填成一个均匀分布: W ∼ U ( − a , a ) W∼U(−a,a) W∼U(−a,a),其中 a = g a i n ⋅ 6 f a n _ i n + f a n _ o u t a=gain·\sqrt{\frac{6}{fan\_in+fan\_out}} a=gain⋅fan_in+fan_out6,fan_in为输入通道/输入特征数(每个神经元接收多少输入)、fan_out为输出通道/输出特征数(每个神经元输出到多少单元)、gain是按激活函数做的缩放系数(默认是 1.0),PyTorch 会根据 m.weight 的形状自动算 fan_in / fan_out(线性层、卷积层都支持)。
常见于 Linear(全连接)、Conv(卷积),加上 tanh / sigmoid 这类激活函数。如果是 ReLU/LeakyReLU,很多时候更常用 Kaiming/He 初始化(kaiming_uniform_),因为它是按 ReLU 的统计特性推出来的,更匹配。
def _init_weights(self, m):
if isinstance(m, nn.Linear):
# we use xavier_uniform following official JAX ViT:
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
前向检索forward()
这段检索本质是一次 Key-Value Memory Attention,把输入特征映射到一组可学习原型上,再从原型库里读出一个增强表示。先把 query 变“适合匹配”的形状,然后计算 query 和所有 keys 的相似度,softmax 变成注意力权重,最后按权重把 memory values 加权求和取出来。公式可写成: o u t = s o f t m a x ( t a n h ( W x ) K T ) V out = softmax(tanh(Wx)K^T)V out=softmax(tanh(Wx)KT)V。具体实现如下:
def forward(self,x,Type='',shape=None):
# dot product
assert x.shape[-1]==self.memMatrix.shape[-1]==self.keyMatrix.shape[-1], "dimension mismatch"
x_query = torch.tanh(self.x_proj(x))
att_weight = F.linear(input=x_query, weight=self.keyMatrix) # [N,C] by [M,C]^T --> [N,M]
att_weight = F.softmax(att_weight, dim=-1) # NxM
out = F.linear(att_weight, self.memMatrix.permute(1, 0)) # [N,M] by [M,C] --> [N,C]
loss_top = 0.0
return dict(out=out, att_weight=att_weight, loss=loss_top)
3.Temporal_prompt
模型的初始化参数:
n u m m e m o r y num_{memory} nummemory:temporal memory 的槽位数,在UniST中为512;
m e m o r y d i m memory_{dim} memorydim:时间 prompt 的特征维度;Transformer 和 memory 都用这个维度,在UniST中,用embed_dim进行实例化;
a r g s args args:整个配置对象。
模型的前向参数:
x c x_c xc:代表近邻时间序列特征,在UniST中,用x_closeness进行实例化,形状为 [ B ∗ H ∗ W , T , C ] [B*H*W, T, C] [B∗H∗W,T,C];
x p x_p xp:代表周期序列特征,在UniST中,用x_period进行实例化,形状为 [ B ∗ T _ p r e d ∗ H ∗ W , P , C ] [B*T\_pred*H*W, P, C] [B∗T_pred∗H∗W,P,C];
该类的作用是把两类时间信息 x_c(closeness)和 x_p(period)各自编码成时间提示向量,再用 memory 做检索增强,输出 hc/hp 给后续 decoder prompt 注入。具体实现如下:
初始化部分,通过Memory类提供一个可学习记忆库,把输入时间特征映射到“原型模式”上,增强泛化;定义单层结构的时间序列编码器encoder_layer = nn.TransformerEncoderLayer(…);用于分别处理 closeness 与 period的两条独立分支变量self.c_encoder 和 self.p_encoder,还有self.args。
class Temporal_prompt(nn.Module):
""" closeness and period
"""
def __init__(self, num_memory, memory_dim, args=None):
super().__init__()
self.temporal_memory = Memory(num_memory, memory_dim, args=args)
encdoer_layer = nn.TransformerEncoderLayer(d_model=memory_dim, nhead=4, dim_feedforward=memory_dim,batch_first = True)
self.c_encoder = nn.TransformerEncoder(encoder_layer=encdoer_layer, num_layers=1)
self.p_encoder = nn.TransformerEncoder(encoder_layer=encdoer_layer, num_layers=1)
self.memory_dim = memory_dim
self.args = args
def forward(self,x_c, x_p):
...
对于forward前向部分,首先对x_c和x_p做 Transformer 编码,再在时间维做平均池化。形状由[N, T, C] 变成 [N, C]。池化的作用是什么?
时间维池化的例子:假设 x_c 形状是 [N,T,C]=[2,3,4],其中一个样本的 3 个时间步是:
t1: [1, 2, 3, 4]
t2: [2, 4, 6, 8]
t3: [3, 6, 9, 12]
mean(dim=1) 是沿时间维取平均:[(1+2+3)/3, (2+4+6)/3, (3+6+9)/3, (4+8+12)/3]= [2, 4, 6, 8]
变换前为[2,3,4],变换后为[2,4],T=3 这个维度就被汇聚掉了,留下每个样本一个时间摘要向量。
# closeness
hc = self.c_encoder(x_c).mean(dim=1)
shape_c = hc.shape
# period
hp = self.p_encoder(x_p).mean(dim=1)
shape_p = hp.shape
接着把刚得到的 closeness 向量 hc、period 向量 hp 送进同一个 memory 模块,做一次基于记忆槽的检索、重构,得到增强后的表示hc_output和hp_output。
最后从 memory 输出里取增强后的向量out、和memory 分支的辅助损失loss,并组成字典返回。返回的向量out(hc和hp)形状为hc 形状是 [ B ∗ H ∗ W , C ] [B*H*W, C] [B∗H∗W,C]、hp 形状是 [ B ∗ T _ p r e d ∗ H ∗ W , C ] [B*T\_pred*H*W, C] [B∗T_pred∗H∗W,C]。这个loss在memory里被定义为0,自始至终都是0,另外loss这个变量也没有用到
从memory可知,记忆的计算是通过自注意力机制计算相似度,然后用这个相似度去加权计算得到。hc和hp就是和所有 memory 槽比较得到的检索向量。
hc_output = self.temporal_memory(hc,Type='closeness',shape=shape_c)
hp_output = self.temporal_memory(hp,Type='period',shape=shape_p)
hc, loss_c = hc_output['out'], hc_output['loss']
hp, loss_p = hp_output['out'], hp_output['loss']
return dict(hc=hc, hp=hp, loss = loss_c+loss_p)
4.Sptial_prompt
更多推荐


所有评论(0)