【论文下饭】Temporal Graph Network for Deep Learning on Dynamic Graphs
文章目录1 介绍2 背景3 Temporal Graph Network3.1 核心模块MemoryMessage Function(msg)Message Aggregator(agg)Memory Updater(mem)Embedding(emb)3.2 训练4 相关工作5 实验5.1 性能表现(实验结果)5.2 模块选择MemoryEmbedding MoudleMessage Aggre
文章目录
综述
Representation Learning for Dynamic Graphs: A Survey
知识点
Transductive\inductive
Paper: Temporal graph network for deep learning on dynamic graphs
Cite: Rossi E, Chamberlain B, Frasca F, et al. Temporal graph networks for deep learning on dynamic graphs[J]. arXiv preprint arXiv:2006.10637, 2020.
中文参考:
内容比较好:TGN:Temporal Graph Networks for Deep Learning on Dynamic Graphs
格式比较好:TGN: TEMPORAL GRAPH NETWORKS FOR DEEP LEARNING ON DYNAMIC GRAPHS论文笔记
1 介绍
图表示学习 取得了一系列的成功。
图 普遍地用于 关系 和 交互 系统的建模,比如说 社交网络、生物网络。
在这些网络上,普遍使用GNN。GNN通过 信息传递机制 聚合邻居的信息,得到该节点的嵌入向量。之后,便可用于节点分类、图分类、边集预测任务上。
大部分在 图上的深度学习算法 都有一个前提假定——图是静态的。
但是,大多数现实生活的交互系统,比如说 社交网络、生物网络,图都是动态的。
通过忽略动态图的时序特征,使用 静态图深度学习方法,也是可以的。但是,这一般是次优解。因为,在某些情况下,模型会忽略掉一些动态图的关键特征。
在动态图上的研究也是今年才兴起的,大多数研究都局限于 离散时间图(discrete-time dynamic graph)。当动态图是连续的(边可以在任何时间出现)、进化的(点可以连续地加入到图钟)时,上面提到的方法都不适合。
直到最近,有很多方法提出 说 支持 连续时间图(continuous-time dynamic graph)。
本文的贡献
- 提出了适用于 连续时间图 的 generic inductive framework Temproal Graph Networks(TGNs)。在本文之前的很多方法,都可以看作是TGNs的一个特例。
- 提出了一个高效的训练策略,使得模型能够从时序数据中实现高效的并行处理。
- 做了很多详细的 消融实验,分析了 本文模型各个组件的性能。
- 本文的模型在许多(both transductive and inductive)任务上取得了SOTA表现,并且 速度比之前的方法快。
2 背景
在静态图上的深度学习
主要讲的就是GNN(这里跳过)
动态图
动态图有两种分类。
-
Discrete-time dynamic graphs 离散时间动态图DTDG
它是一系列的 一段时间内的 动态图的快照(snapshots)。 -
Continuos-time dynamic graphs 连续时间动态图CTDG
它是动态图更加一般性的表示。由一系列的时间组成(timed lists of events)。这些事件包括 边的增加/删除、节点的添加/删除、节点/边的特征变化。
本文在正文部分 使用节点/边的添加 作为例子。(节点/边删除 在附录讨论)
3 Temporal Graph Network
TGN的介绍。encoder-decoder模型。
3.1 核心模块
Memory
对于每个节点 i i i,在时刻 t t t,都有一个向量 s i ( t ) s_i(t) si(t)表示为该节点的记忆单元。它记录了节点 i i i在 [ 0 , t ] [0,t] [0,t]时间内记忆。它代表了节点的“历史”。
有了记忆模块,TGNs就可以记录 每个节点在图中 长期的依赖关系。
- 当新的节点被加入时,它的记忆单元初始化一个零向量。
- 每个事件过后,相关节点的记忆单元就会被更新。
也可以使用全局的记忆单元,来记录整个图的变化,但是为了简单起见,这留作未来的工作。
Message Function(msg)
对于一个 交互事件 e i j ( t ) e_{ij}(t) eij(t),有两个方向的信息。
m i ( t ) = m s g s ( s i ( t − ) , s j ( t − ) , Δ t , e i j ( t ) ) m j ( t ) = m s g d ( s j ( t − ) , s i ( t − ) , Δ t , e i j ( t ) ) \begin{aligned} &m_i(t) = \mathrm{msg_s} (s_i(t^-),s_j(t^-),\Delta t,e_{ij}(t) ) \\ &m_j(t) = \mathrm{msg_d} (s_j(t^-),s_i(t^-),\Delta t,e_{ij}(t)) \end{aligned} mi(t)=msgs(si(t−),sj(t−),Δt,eij(t))mj(t)=msgd(sj(t−),si(t−),Δt,eij(t))
其中, s i ( t − ) s_i(t^-) si(t−)为节点 i i i在 t t t时刻前的记忆单元。 m s g msg msg为可以学习的信息传递函数,比如说MLP。
在本文中,使用 i d e n t i t y ( i d ) identity(id) identity(id)作为信息传递函数 m s g msg msg。(即,简单地对 m s g msg msg的输入进行concate)
Message Aggregator(agg)
在使用批次(batch)处理时,有一些节点的 m s g msg msg会被多次使用。
出于性能考虑,本文提出了 信息聚合机制。
m ˉ i ( t ) = a g g ( m i ( t 1 ) , . . . , m i ( t b ) ) \bar m_i(t) = \mathrm{agg}(m_i(t_1),...,m_i(t_b)) mˉi(t)=agg(mi(t1),...,mi(tb))
其中, t 1 , . . . , t b t_1,...,t_b t1,...,tb为节点 i i i 在相同批次中的时间序列。 a g g \mathrm{agg} agg 是聚合函数,比如说RNN或attention机制。
简单地说,就是把 同一批次中所有节点 i i i的 m s g msg msg聚合到一起。
在本文中,使用
most recent message
只保留最近的信息。mean message
计算所有信息的平均值。
Memory Updater(mem)
对于每个事件涉及到的节点,需要更新其记忆单元:
s i ( t ) = m e m ( m ˉ i ( t ) , s i ( t − ) ) s_i(t) = \mathrm{mem}(\bar m_i(t),s_i(t^-)) si(t)=mem(mˉi(t),si(t−))
其中, m e m \mathrm{mem} mem可以学习的更新函数。比如说,循环神经网络LSTM、GRU。
在本文中,使用 G R U \mathrm{GRU} GRU作为记忆更新函数 m e m \mathrm{mem} mem。
Embedding(emb)
向量嵌入模块 可以生成 每个节点 i i i在 t t t时刻的 时序嵌入向量 z i ( t ) z_i(t) zi(t)。
记忆过期问题
节点 i i i的记忆单元更新,当且仅当 存在事件 包含节点 i i i。当某个节点 i i i的记忆单元长时间得不到更新时,节点 i i i的记忆单元就可以被认为 过期(stale) 了。
比如,在社交网络中,某个用户 长时间 不使用该平台 后,又再次使用。
嵌入向量计算的统一形式如下:
z i ( t ) = e m b ( i , t ) = ∑ j ∈ n i k ( [ 0 , t ] ) h ( s i ( t ) , s j ( t ) , e i j , v i ( t ) , v j ( t ) ) z_i(t) = \mathrm{emb}(i,t) = \sum_{j \in n^k_i([0,t])}h(s_i(t),s_j(t),e_{ij},\bold{v}_i(t),\bold{v}_j(t)) zi(t)=emb(i,t)=j∈nik([0,t])∑h(si(t),sj(t),eij,vi(t),vj(t))
其中, h h h是可以学习的函数。它可以有多种实现方式,比如说:Identity(id)
e m b ( i , t ) = s i ( t ) \mathrm{emb}(i,t) = s_i(t) emb(i,t)=si(t),直接使用节点的记忆单元。
Time projection(time)
e m b ( i , t ) = ( 1 + Δ t w ) ∘ s i ( t ) \mathrm{emb}(i,t) = (1+\Delta t ~\bold{w}) \circ s_i(t) emb(i,t)=(1+Δt w)∘si(t),其中, w \bold{w} w 是可以学习的参数, Δ t \Delta t Δt 是距上一次交互的时间间隔, ∘ \circ ∘是element-wise
向量乘积。(该方法使用于Joide模型中(Kumar etal., 2019))。
Temporal Graph Attention(attn)
Firstly proposed in TGAT(Xu et al., 2020)
L L L 层的图注意力机制,可以利用节点 i i i的 L L L跳的时序邻居信息,计算(节点 i i i)的嵌入向量。
节点 i i i,在 t t t时刻,第 l l l层的输入是 h i ( l − 1 ) ( t ) h_i^{(l-1)}(t) hi(l−1)(t),节点 i i i的邻居表示 { h 1 l − 1 ( t ) , . . . , h N l − 1 ( t ) } \{h_1^{l-1}(t),...,h_N^{l-1}(t) \} {h1l−1(t),...,hNl−1(t)},特征为 e i 1 ( t 1 ) , . . . , e i N ( t N ) e_{i1}(t_1),...,e_{iN}(t_N) ei1(t1),...,eiN(tN)。
注意:因为训练是按批次(batch)进行的,特征 e e e的发生时刻可能不同。
h i ( l ) ( t ) = M L P ( l ) ( h i ( l − 1 ) ( t ) ∥ h ~ i ( l ) ( t ) ) , h ~ i ( l ) ( t ) = M u l t i H e a d A t t e n t i o n ( l ) ( q ( l ) ( t ) , K ( l ) ( t ) , V ( l ) ( t ) ) , q ( l ) ( t ) = h i ( l − 1 ) ( t ) ∥ ϕ ( 0 ) , K ( l ) ( t ) = V ( l ) ( t ) ) = C ( l ) ( t ) , C ( l ) ( t ) = [ h 1 ( l − 1 ) ( t ) ∥ e i 1 ( t 1 ) ∥ ϕ ( t − t 1 ) , . . . , h N ( l − 1 ) ( t ) ∥ e i N ( t N ) ∥ ϕ ( t − t N ) ] \begin{aligned} &\bold{h}_i^{(l)}(t) = \mathrm{MLP}^{(l)}(\bold{h}_i^{(l-1)}(t) \parallel \tilde{h}_i^{(l)}(t)), \\ &\tilde{\bold{h}}_i^{(l)}(t) = \mathrm{MultiHeadAttention}^{(l)}(\bold{q}^{(l)}(t),\bold{K}^{(l)}(t),\bold{V}^{(l)}(t)),\\ &\bold{q}^{(l)}(t) = \bold{h}_i^{(l-1)}(t) \parallel \phi(0),\\ &\bold{K}^{(l)}(t) = \bold{V}^{(l)}(t)) = \bold{C}^{(l)}(t), \\ &\bold{C}^{(l)}(t) = [\bold{h}_1^{(l-1)}(t) \parallel \bold{e}_{i1}(t_1) \parallel \phi(t-t_1),...,\bold{h}_N^{(l-1)}(t) \parallel \bold{e}_{iN}(t_N) \parallel \phi(t-t_N)] \end{aligned} hi(l)(t)=MLP(l)(hi(l−1)(t)∥h~i(l)(t)),h~i(l)(t)=MultiHeadAttention(l)(q(l)(t),K(l)(t),V(l)(t)),q(l)(t)=hi(l−1)(t)∥ϕ(0),K(l)(t)=V(l)(t))=C(l)(t),C(l)(t)=[h1(l−1)(t)∥ei1(t1)∥ϕ(t−t1),...,hN(l−1)(t)∥eiN(tN)∥ϕ(t−tN)]
其中, ϕ \phi ϕ 是一个 通用的时序编码器(generic time encoding), ∥ \parallel ∥ 是concate操作,最后得到的嵌入向量为 z i ( t ) = e m b ( i , t ) = h i ( L ) ( t ) \bold{z}_i(t) = \mathrm{emb}(i,t) = \bold{h}_i^{(L)}(t) zi(t)=emb(i,t)=hi(L)(t)。 q ( l ) ( t ) \bold{q}^{(l)}(t) q(l)(t)是可以是节点 i i i或节点 i i i的 L − 1 L-1 L−1跳邻居。 K ( l ) ( t ) \bold{K}^{(l)}(t) K(l)(t) 和 V ( l ) ( t ) \bold{V}^{(l)}(t) V(l)(t)是节点 i i i的邻居。
简单地说,就是一个多头注意力机制,重点在 C \bold{C} C中,把时序特征一并输入。
特别地,与TGAT中提到的不同的是,在第 0 0 0层时,本文考虑了节点本身的特征 v ( t ) \bold{v}(t) v(t)(node-wise temporal features),即 h j ( 0 ) ( t ) = s j ( t ) + v j ( t ) h_j^{(0)}(t) = s_j(t)+\bold{v}_j(t) hj(0)(t)=sj(t)+vj(t)。这使得模型可以同时利用 现有的记忆 s j ( t ) s_j(t) sj(t)和时序节点特征 v j ( t ) \bold{v}_j(t) vj(t)。
Temporal Graph Sum(sum)
简单、快速的聚合方法。
h i ( l ) ( t ) = W 2 ( l ) ( h i ( l − 1 ) ( t ) ∥ h ~ i ( l ) ( t ) ) , h ~ i ( l ) ( t ) = R e L u ( ∑ j ∈ n i ( [ 0 , t ] ) W 1 ( l ) ( h j ( l − 1 ) ( t ) ∥ e i j ∥ ϕ ( t − t j ) ) ) \begin{aligned} &\bold{h}_i^{(l)}(t) = \bold{W}_2^{(l)}(\bold{h}_i^{(l-1)}(t) \parallel \tilde{\bold{h}}_i^{(l)}(t)),\\ &\tilde{\bold{h}}_i^{(l)}(t) = \mathrm{ReLu}(\sum_{j \in n_i([0,t])} \bold{W}_1^{(l)} (\bold{h}_j^{(l-1)}(t) \parallel \bold{e}_{ij} \parallel \phi(t-t_j))) \end{aligned} hi(l)(t)=W2(l)(hi(l−1)(t)∥h~i(l)(t)),h~i(l)(t)=ReLu(j∈ni([0,t])∑W1(l)(hj(l−1)(t)∥eij∥ϕ(t−tj)))
同样地, ϕ \phi ϕ 是一个 通用的时序编码器(generic time encoding),最后得到的嵌入向量为 z i ( t ) = e m b ( i , t ) = h i ( L ) ( t ) \bold{z}_i(t) = \mathrm{emb}(i,t) = \bold{h}_i^{(L)}(t) zi(t)=emb(i,t)=hi(L)(t)。
图向量嵌入模块 通过聚合邻居的记忆信息,缓和了 (记忆)过期问题,使得TGN可以计算最新的嵌入向量信息。
temporal graph attention
使得模型能够 寻找 包含重要特征和时序信息 的邻居节点。
3.2 训练
TGN可以用于许多任务上,比如说 边集预测(自监督)或者 节点分类(半监督)。
我们使用 连接预测作为例子:提供一系列时间排序的交互,我们的目标是从过去的观察中,预测未来可能出现的交互。
交互(interactions):也就是 边。
之前提到的训练策略在 记忆相关模块中 存在问题——不能直接影响loss,也就是说接收不到梯度。
为了解决这个问题,记忆单元(memory)必须在预测交互之前 更新。但是,这就造成了信息泄露(information leakage)。
为了避免这个问题,本文提出了一个额外的模块
Raw Message Store
,用于存储 batch b b b 的交互信息,原来的Message
模块 用于存储 batch b − 1 b-1 b−1 的交互信息。
这样,通过添加 缓存,解决了信息泄露的问题。
需要注意的是,我们的 batch size
不能选的太大。经过 速度和颗粒度(granularity)的权衡,作者认为batch size=200
is good。
因为,当前的 预测 用到的是 上一个batch的交互信息。如果batch太大(极端地说,整个数据集),所有的预测 用到的都是 初始的零向量记忆单元。
4 相关工作
早期的工作 都集中于DTDGs上。
比如说,
- 聚合 图的快照 然后使用静态的方法
- 把 图的快照 整合成张量并分解它
- 编码 每张快照 产生一系列嵌入向量。
另一条编码DTDGs的主线工作:先在初始的快照上使用 随机游走,然后对于子序列快照修改游走行为。
时空图(spatio-temporal graphs)是动态图的特例。因为 时空图的 拓扑逻辑是固定的。
CTDGs。
- 根据 连续时间(continuos time) 限定 随机游走 的转移概率。
- CTDGs的序列模型。对每个事件 e i j e_{ij} eij使用RNN来更新来更新 源点 和 宿点的表示(representations)。
许多的架构用的都是基于RNN的node-wise
记忆单元。由于缺少GNN的信息聚合机制,使得这些记忆单元可能出现 (记忆)过期问题,同时 它的计算也是十分耗时的。
最新的CTDGs学习模型,都可以看作本文框架TGN的一个特例。
5 实验
数据集
Wikipedia
Reddit
Twitter
任务:edge prediction(预测两个节点未来出现连接的概率)。同时研究了transductive
和inductive
设置下的情况。
在
transductive
任务中,预测的连接 在训练时 出现过;在inductive
任务中,预测的连接在 训练中 没有出现过。
本文使用的解码器是一个简单的MLP。
Baselinesstrong baselines:
CTDNE
Jodie
DyRep
TGAT
GAE
VGAE
DeepWalk
Node2Vec
GAT
GraphSAGE
5.1 性能表现(实验结果)
总结:最好的模型TGN-atten
很强,而且很快(比TGAT
快30倍)。
5.2 模块选择
Memory
比较的模块:
TGN-no-mem 没有使用记忆模块
TGN-attn 最好的模型
现象:
TGN-attn
比TGN-no-mem
慢3倍。TGN-attn
比TGN-no-mem
准确率提升4%。
结论:
- 记忆单元 能够帮助 存储节点的长期信息。
- 采样更多的邻居信息 可以 达到同样(带记忆)的效果。(但花费更多的时间)
Embedding Moudle
比较的模块:
TGN-id(DyRep)
TGN-time(Jodie)
TGN-attn
TGN-sum
现象:
TGN-id
优于TGN-time
。- graph-base的方法(
TGN-attn
,TGN-sum
) 比 graph-less的方法TGN-id
高出一大截。 TGN-attn
仅比TGN-sum
高一点点。
结论:
- 使用图的最近的信息,选择哪些邻居是最关键的,是很重要的影响因素。
Message Aggregator
比较的模块:
TGN-mean
TGN-attn
现象:
TGN-mean
比TGN-attn
好一点。TGN-mean
比TGN-attn
慢3倍。
Number of layers
比较的模块:
TGN-2l
TGN-attn
比较使用到的GNN层数,因为在TGAT中,两层比一层 好了不止10%。
现象:
TGN-2l
仅比TGN-attn
高一点。
结论:
- 由于使用到了记忆单元,
TGN-attn
仅使用1层就可以达到不错的效果。 - 当使用1-hop邻居的记忆单元时,我们 间接地 使用了 比
1-hop
更远的信息。
6 结论
- TGN:通用的 时间连续图 深度学习框架。之前的一些工作,都可以看作是本文框架的一个特例。
- TGN可以做到SOTA。
- 本文对每个模块做了详细的消融实验。记忆模块(可以存储 长期信息),嵌入模块(生成最新的节点嵌入向量)很重要。
更多推荐
所有评论(0)