学习日记17:GNNExplainer
GNN处理图数据实力很强,但是目前还没有能解释GNN预测结果的方法。文章提出了一个可以对任何GNN模型的预测结果进行解释的模型GNNExplainer,它可以选出对单个实例做出判断重要的节点和特征,并可以扩展到为整个实例类生成一致且简洁的解释。模型的基本想法是最大化GNN的预测和可能的子图结构分布之间的互信息。尽管GNN具有优势,但它们缺乏透明度,因为它们不容易对其预测做出人类可以理解的解释。目前
摘要
GNN处理图数据实力很强,但是目前还没有能解释GNN预测结果的方法。文章提出了一个可以对任何GNN模型的预测结果进行解释的模型GNNExplainer,它可以选出对单个实例做出判断重要的节点和特征,并可以扩展到为整个实例类生成一致且简洁的解释。模型的基本想法是最大化GNN的预测和可能的子图结构分布之间的互信息。
介绍
尽管GNN具有优势,但它们缺乏透明度,因为它们不容易对其预测做出人类可以理解的解释。目前主流的神经网络解释方法主要有两种:1)代理模型近似法:用简单、可解释的 “代理模型”(如线性模型、决策规则)局部近似复杂神经网络的预测行为,再通过分析代理模型反推原模型的决策依据。2)计算过程分析法:直接对神经网络的内部计算或输入特征进行溯源,识别对预测有重要影响的元素。但这两种方法都不适合GNN,因为他们并没有将图的结构考虑进去,而这正是GNN成功的关键。
GNNExplainer的输入是一个训练好的GNN模型和一个预测,输出的结果是一系列对这个预测最有影响的节点和节点特征。
关于图神经网络的公式化解释
图神经网络的步骤为:1)根据节点之间的连接状态,计算每一个节点对的信息。2)对于每一个节点,它聚合该节点的邻居的所有信息。3)最后,根据节点汇聚的信息和节点当前状态计算出节点的下一个状态。
对GNN的预测结果进行解释的一个关键是找到节点的计算图G(C)(也就是预测这个节点需要哪些节点的信息),使用A(C)这个只含0,1的矩阵来表示计算图。使用计算图以及他们的节点特征来计算给定实例的预测结果。也就是说,在给定GNN模型后,只需要考虑模型的结构信息和节点特征。形式上,GNNExplainer生成预测结果y的解释为只需G(s)和X(s),s表示子图。
GNNExplainer
单实例说明
模型的目标是从计算图中选出一些节点,从特征集中选出一些特征,这些信息在预测结果中起到重要作用。我们使用互信息MI来形式化重要性的概念,并将GNNEXPLAINER表示为以下优化框架:
上式中,H(Y)表示预测结果的不确定性,相减的部分表示在不知道任何情况下对y的预测的不确定性减去知道了G和X后预测的不确定性。在模型确定的情况下H(Y)固定,即最大化互信息就是最小化知道了G和X后预测的不确定性。
确定计算图的方法就是将一个节点从节点从计算图中移除,如果对y的预测出现大幅下降那么这个节点就应该被归入计算图中(边也是如此计算)。同时也要规定计算图节点的最大数量,这意味着GNNEXPLAINER的目标是通过获取与预测提供最高互信息的边来消除G(C)的噪声。
GNNEXPLAINER的优化框架:直接优化子图不可行,因为一个图的子图可能有成千上万个,为了解决这个问题,将上面提到的AC转为AS,与AC不同,AS矩阵中不再规定必须是0或1,而是表达这个边重要信息程度的权重,保持,可以防止子图中出现原图中不存在的边,通过这个方法就可以使用梯度下降解决问题。在计算中一般采用子图的分布来计算。
在凸性假设下,Jensen不等式给出了如下上界:
在实际应用中,由于神经网络的复杂性,凸性假设不成立。然而,在实验中,我们发现用正则化最小化这个目标通常会导致与高质量解释相对应的局部最小值。
通过掩码学习生成分数邻接矩阵,首先生成一个掩膜矩阵M,将其经过sigmoid函数后保证其数值在0,1之间后与AC相乘得到AS,之后通过梯度下降优化M,优化后对AS进行阈值处理得到最终GS。
图结构和节点特征信息的联合学习
为了识别到重要的特征,模型同时训练了一个特征掩码F,F只有0,1来决定那哪些特征重要。引入F后,目标函数表示为
二进制 F 无法直接求导。为此,通过蒙特卡洛估计(从节点特征的经验边际分布中采样)和重参数化技巧如,将离散掩码转化为可微分的连续变量,实现梯度反向传播。
此外,加入一些重要正则化,如:元素 - wise 熵正则化:鼓励特征掩码 F 向离散的 0/1 值靠拢(减少 0.5 左右的模糊权重),提升特征重要性的区分度;掩码元素和正则化: penalize 掩码 F 中 “1” 的数量,避免保留过多特征,确保解释的紧凑性;计算图有效性约束,需保证神经消息能沿 GS 流向目标节点,否则即使某条边或特征单独重要,若无法参与消息传递,也会被排除,最终使 GS 倾向于成为连通子图 。
通过图形原型实现多实例解释
实现多实例解释包含两个阶段:
首先是选取参考点,该节点需能代表类的典型特征。选择方式通常为:计算类 c 所有节点的嵌入均值,选取嵌入最接近均值的节点作为 vc;之后将同属于c类的节点与vc对其,不管是特征维度还是子图结构。寻找子图间节点的对应关系,使不同子图的结构与特征维度可相互匹配 ;
对齐之后,采用基于中位数的聚合策略生成原型图A,可直观展示同类实例共有的关键结构。
更多推荐
所有评论(0)