SAGEConv

论文名称:Inductive Representation Learning on Large Graphs

论文链接:https://arxiv.org/pdf/1706.02216.pdf

现在存在方法具有内在transductive,不能generalize未见到的节点,。我们提出GraphSAGE是基于inductive开发的,对于未见到数据具备泛化能力。在训练的过程中,对邻居的节点采用抽样的方式,而不是对所有的节点进行训练。

GraphSAGE学习的主要方法是收集局部的邻居信息,例如度特征或者邻居节点的属性特征。接下来,首先描述推理算法,假设参数已经学习,如何生成节点的Embedding。然后,我们在讲一下如何通过随机梯度下降法和后向传播算法学习模型的参数。

1、Embedding产生算法(前向传播算法)

假设我们已经学习K 聚合函数(AGGREGATEk,k∈{1,…,K}AGGREGATE_k,k\in\{1,\dots,K\}AGGREGATEk,k{1,,K})的参数和一系列的权重矩阵Wk,∀k∈{1,…,K}W^k,\forall k \in\{1, \ldots, K\}Wk,k{1,,K}, 其中K指搜索的深度,这些参数主要用于前向传播。

在这里插入图片描述

Algorithm 1是在一次迭代、深度搜索并aggregate局部邻居的信息,随着迭代次数的增加,能够获取到更深层次的信息。

G=(V,E)\mathcal{G}=(\mathcal{V},\mathcal{E})G=(V,E)表示整张图,

xv,∀v∈V\text{x}_v, \forall v\in\mathcal{V}xvvV表示节点的特征。

kkk表示当前step下,每个节点v∈Vv\in\mathcal{V}vV 收集邻居节点特征表示huk−1,∀u∈N(v)\text{h}_u^{k-1},\forall u\in\mathcal{N(v)}huk1,uN(v),生成单一的节点表示hN(v)k−1\text{h}_{N(v)}^{k-1}hN(v)k1. 注意本次迭代aggregate取决于前一次迭代的输出。其中bad case k=0k=0k=0是节点的特征输入。将汇总的邻居向量hN(v)k−1\text{h}_{\mathcal{N(v)}}^{k-1}hN(v)k1和当前的节点特征hvk−1h_v^{k-1}hvk1进行拼接, 进行全连接层和非线性激活函数σ\sigmaσ的转换,生成的结果作为下一次迭代的输入。最终输出的特征表示为zv≡hvK,∀v∈V\mathbf{z}_{v} \equiv \mathbf{h}_{v}^{K}, \forall v \in \mathcal{V}zvhvK,vV

在minibatch的设置中,对邻居节点和边进行采样。相对于全部节点的向量的计算,GraphSAGE只是对必要的节点minibatch集合B\mathcal{B}B进行计算。

在这里插入图片描述

主要的思想就是抽样出需要的节点进行计算,Algorithm的Line 2-7描述了抽样的过程。

从Line1-6看出: 每个Bk\mathcal{B}^kBk包含节点v∈Bk+1v\in \mathcal{B}^{k+1}vBk+1的 表示。

Line9-15 描述聚合过程,Nk(u)\mathcal{N}_k(u)Nk(u)采用独立的均匀采样。

Relation to the Weisfeiler-Lehman Isomorphism Test. GraphSAGE的灵感来自同构图检验的经典算法。在Algorithm中,我们 (i)设K=∣V∣K=|V|K=V (ii) 边权重是相等的。(iii) 使用Hash函数作为aggregator。如果两个子图输出{zv,∀v∈V}\{\text{z}_v, \forall v\in\mathcal{V}\}{zv,vV}是相同的,我们认为两个子图是同构的。当然,我们的目标学习节点的表示,不是测试是否同构。

Neighborhood defination 均匀采样固定大小的邻居节点的数量, 即N(v)\mathcal{N(v)}N(v)是固定的,每次迭代均匀采样不同的样本。如果不采样,一个Batch的大小为O(∣V∣)O{(|\mathcal{V}|)}O(V). 采样后, GraphSAGE的复杂度固定在O(∏i=1KSi)O\left(\prod_{i=1}^{K} S_{i}\right)O(i=1KSi), 其中Si,i∈{i,⋯ ,K}S_i, i\in\{i,\cdots,K\}Si,i{i,,K}, KKK用户自定义的。在实际应用中,一般K=2K=2K=2S1⋅S2≤500S_{1} \cdot S_{2} \leq 500S1S2500

2、GraphSAGE参数学习

使用graph-based loss function学习节点的表示,zu,∀u∈V\text{z}_u, \forall u\in\mathcal{V}zu,uV, 学习权重矩阵Wk,∀k∈{1,⋯ ,K}W^k, \forall k\in\{1,\cdots, K\}Wk,k{1,,K}, 采用随机梯度下降的方法。graph-based loss function会使得相近的节点有相同的表示,同时兼顾不同的节点学习表示不相同的。
JG(zu)=−log⁡(σ(zu⊤zv))−Q⋅Evn∼Pn(v)log⁡(σ(−zu⊤zvn))(1) J_{\mathcal{G}}\left(\mathbf{z}_{u}\right)=-\log \left(\sigma\left(\mathbf{z}_{u}^{\top} \mathbf{z}_{v}\right)\right)-Q \cdot \mathbb{E}_{v_{n} \sim P_{n}(v)} \log \left(\sigma\left(-\mathbf{z}_{u}^{\top} \mathbf{z}_{v_{n}}\right)\right)\tag{1} JG(zu)=log(σ(zuzv))QEvnPn(v)log(σ(zuzvn))(1)
其中,节点v和节点u是在固定长度随机游走过程共现的。σ\sigmaσ是激活函数,PnP_nPn是负采样的分布, QQQ是定义负样本的数量, 损失函数输入表示zu\text{z}_uzu包含来自邻居节点的特征。

3、Aggregator Architectures

节点的邻居和文本、图像不一样,他们的无序、对称的,测试如下三个aggregator functions:

Mean aggregator

用以下公式替换Algorithm1中Line4和Line5
hvk←σ(W⋅MEAN⁡({hvk−1}∪{huk−1,∀u∈N(v)})(2) \mathbf{h}_{v}^{k} \leftarrow \sigma\left(\mathbf{W} \cdot \operatorname{MEAN}\left(\left\{\mathbf{h}_{v}^{k-1}\right\} \cup\left\{\mathbf{h}_{u}^{k-1}, \forall u \in \mathcal{N}(v)\right\}\right)\right.\tag{2} hvkσ(WMEAN({hvk1}{huk1,uN(v)})(2)
计算当前节点hvk−1\mathbf{h}_v^{k-1}hvk1和邻居节点拼接起来, 计算均值,这种操作可以将不同深度的节点进行"skip connnection"。

LSTM aggregator

LSTM不具备对称性,简单地将邻居节点处理成无序的序列作为输入。

Pooling aggregator
 AGGREGATE kpool =max⁡({σ(Wpool huik+b),∀ui∈N(v)})(3) \text { AGGREGATE }_{k}^{\text {pool }}=\max \left(\left\{\sigma\left(\mathbf{W}_{\text {pool }} \mathbf{h}_{u_{i}}^{k}+\mathbf{b}\right), \forall u_{i} \in \mathcal{N}(v)\right\}\right)\tag{3}  AGGREGATE kpool =max({σ(Wpool huik+b),uiN(v)})(3)

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐