一、GCN的原理

简单,也有很多博客在说明!
链接1:https://arxiv.org/abs/1609.02907
链接2:https://mp.weixin.qq.com/s/DJAimuhrXIXjAqm2dciTXg

二、GCN的层代码

import math
import torch
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
class GraphConvolution(Module):
    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        support = torch.mm(input, self.weight)
        output = torch.spmm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'

解释说明:

  1. class GraphConvolution(Module):继承Module类。

  2. class GraphConvolution(Module)中有两个恒常在的函数:__init__()用于初始化参数或者模块等;forward()函数属于输入变量并做运算。

  3. def __init__(self, in_features, out_features, bias=True)这个函数中:

    • super(GraphConvolution, self).__init__():是按照 GraphConvolution的父类Module的初始化方式进行初始化。
    • self.in_features = in_features:用来定义初始化变量,可以在整个class的任意一个函数内部使用。
    • self.weight = Parameter(torch.FloatTensor(in_features, out_features)):定义新的初始化变量。模型中的参数,它是Parameter()类,也是定义GCN的核心操作之一。
      在这里插入图片描述
  4. forward(self, input, adj)函数中输入变量input,adj.

    • support = torch.mm(input, self.weight) 是矩阵乘法,input * self.weight.注意到torch.mm使用范围仅限于二维矩阵。当存在batch变量的时候,也就是infut.shape=[B, N, F]三维形状的时候不使用。建议改为torch.matmul.
    • output = torch.spmm(adj, support)也是矩阵乘法。adj是我们的矩阵输入变量,具有N*N个元素,通常情况下采用稀疏矩阵来保存。spmm是稀疏矩阵的乘法:
      支持 sparse 在前,dense 在后的矩阵乘法
      两个sparse相乘或者dense在前的乘法不支持,
      当然两个dense矩阵相乘是支持的.
      mm是二维矩阵的乘法,不适合用于三维矩阵。
  5. reset_parameters(self)是参数初始化

    • self.weight.size(1)是weight的形状(in_features, out_features)中的out_features
    • math.sqrt(4)=2.0是返回平方根
    • self.weight.data.uniform_(-stdv, stdv):是指weight.data按照均匀分布,上限为-stdv,下限位stdv.
    • 此外对weight的数据初始化方法还有另外一种:init.kaiming_uniform_(self.weight)
  6. __repr__(self)返回该clas的一些介绍。比如
    在这里插入图片描述

三、GCN的搭建

import torch.nn as nn
import torch.nn.functional as F
from pygcn.layers import GraphConvolution
class GCN(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout):
        super(GCN, self).__init__()
        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nclass)
        self.dropout = dropout
    def forward(self, x, adj):
        x = F.relu(self.gc1(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc2(x, adj)
        return F.log_softmax(x, dim=1)
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐