自然语言处理NLP——图神经网络与图注意力模型(GNN、GCN、GAT)

news/2024/4/29 22:22:23/文章来源:https://blog.csdn.net/weixin_51426083/article/details/128340275

目录

系列文章目录

一、图神经网络

1.图与图嵌入

2.GNN动机

2.1 CNN的缺陷与非结构性数据

2.2 图嵌入的缺陷

3.GNN详解

3.1 GNN简介

3.2 GNN模型

3.3 GNN框架

3.4 GNN局限与优化

二、图卷积神经网络

1.卷积

2.GCN详解

2.1 GCN动机

2.2 GCN简介

2.3 GCN思想与模型

2.4 GCN核心公式解析

2.5 GCN优势与局限

三、图注意力网络

1.Attention机制

2.GAT简介

3.核心公式解析

3.1 图注意力层的输入输出

3.2 特征提取与注意力机制

3.3 输出特征 

3.4 multi-head attention 

4.论文结果与分析

4.1 数据集与比较方法

4.2 实验模型设置

4.3 实验结果与分析

4.4 论文结论补充

四、GAT复现

1.Tensorflow GAT(代码解析)

1.1 代码结构与参数设置

1.2 数据加载与特征预处理

1.3 GAT模型

1.4 GAT网络

1.5 训练与损失函数

2.Pytorch GAT(代码解析)

3.代码实践与结果分析

五、GAT补充

1.GAT的推广

2.GAT优化

2.1 论文简介

2.2 动机

2.3 方法论 

2.4 优化效果

六、实验结论与参考

1.实验结论

2.参考资料


系列文章目录

本系列博客重点在自然语言处理NLP的概念原理与代码实践(有问题欢迎在评论区讨论指出,或直接私信联系我)。

第一章 自然语言处理NLP——GSDMM用于短文本聚类_@李忆如的博客-CSDN博客

第二章 自然语言处理NLP——图神经网络与图注意力模型(GNN、GCN、GAT)


梗概

   本篇博客从图神经网络(GNN)的动机与模型,图卷积网络(GCN)的模型详解与公式推导引入,重点介绍图注意力网络(GAT)的目标函数推导,模型解析,并使用不同框架复现GAT论文实验,对比分析结论与论文结果,并在最后补充一定的GAT推广与优化(内附Python代码与数据集)。


一、图神经网络

    由于GAT本质上是一种图神经网络,故在本部分先做图神经网络的引入。

    图神经网络(Graph Neural Network,GNN)是指使用神经网络来学习图结构数据,提取和发掘图结构数据中的特征和模式,满足聚类、分类、预测、分割、生成等图学习任务需求的算法总称,本章针对图神经网络的相关概念、动机、模型、应用做介绍。

1.图与图嵌入

    图是一种数据结构,常见的图结构包括节点和边,在机器学习(深度学习)中,部分图存在属性、权重等,且多种真实数据可用图(矩阵构图)进行表示,样例如图1所示,GNN是深度学习在图结构上的一个分支

图1 数据的图表示样例

    图嵌入是一种将图数据(通常为高维稠密的矩阵)映射为低微稠密向量的过程,能够很好地解决图数据难以高效输入机器学习算法的问题,样例如图2所示。如果有更多的信息被表示出来,那么下游的任务将会获得更好的表现。在嵌入的过程中存在着一种共识:向量空间中保持连接的节点彼此靠近

图2 图嵌入样例

    图是一种易于理解的表示形式,除此之外图嵌入的优势总结于表1

表1 图嵌入的优势

1.在graph上直接进行机器学习具有一定的局限性

2.图嵌入能够压缩数据

3.向量计算比直接在图上操作更加的简单、快捷。

2.GNN动机

    本部分针对为什么需要GNN以及相关方法的不足做简介。

2.1 CNN的缺陷与非结构性数据

    CNN的核心特点在于:局部连接,权重共享和多层叠加,这些特点在图问题中同样非常适用,因为图结构是最典型的局部连接结构,其次,共享权重可以减少计算量,另外,多层结构是处理分级模式的关键。

    但是,CNN只能在欧几里得数据,比如二维图片和一维文本数据上进行处理,而这些数据只是图结构的特例,样例如图3所示。对于一般的图结构,CNN很难使用(效果不佳)

图3 不同空间图样例

    详细地说,现实世界中并不是所有的事物都可以表示成一个序列或者一个网格,例如社交网络、知识图谱、复杂的文件系统等,也就是说很多事物都是非结构化的,样例如图4所示:

图4 非结构化数据样例

    分析:由图4所示,相比于简单的文本和图像,这种网络类型的非结构化的数据非常复杂,处理它的难点总结于表2:

表2 非结构化数据处理难点

1.图的大小是任意的,拓扑结构复杂,没有像图像一样的空间局部性

2. 图没有固定的节点顺序,或者说没有一个参考节点

3. 图经常是动态图,而且包含多模态的特征

    那么对于这类数据我们该如何建模呢?能否将深度学习进行扩展使得能够建模该类数据呢?这些问题促使了图神经网络的出现与发展。

2.2 图嵌入的缺陷

    图嵌入大致可以划分为三个类别:矩阵分解、随机游走和深度学习方法。常见模型有DeepWalk,Node2Vec等,然而,这些方法方法有两种严重的缺点,首先就是节点编码中权重未共享,导致权重数量随着节点增多而线性增大,另外就是直接嵌入方法缺乏泛化能力,意味着无法处理动态图以及泛化到新的图。GNN与图嵌入对比如图5所示:

图5 GNN与图嵌入对比

3.GNN详解

3.1 GNN简介

    图神经网络是一种特殊的图表示方法,使用神经网络来对图节点进行编码,将图结构embedding为计算机可处理的向量矩阵,与传统NN比较,在节点、边、推理三方面都有优化,总结如下:

(1)节点

①CNN和RNN等都需要节点的特征按照一定的顺序进行排列。

②但对于图结构,并没有天然的顺序。所以,GNN采用*在每个节点上分别传播的方式进行学习,由此忽略了节点的顺序,相当于GNN的输出会随着输入的不同而不同。

(2)边(图结构的边表示节点之间的依存关系)

①传统的神经网络不是显式地表达中这种依存关系,而是通过不同节点特征来间接地表达节点之间的关系,这些依赖信息只是作为节点的特征。

②GNN 可以通过图形结构进行传播,而不是将其作为节点特征的一部分,通过邻居节点的加权求和来更新节点的隐藏状态。

(3)推理

①推理是高级人工智能的一个非常重要的研究课题,人脑中的推理过程几乎都基于从日常经验中提取的图形。标准神经网络已经显示出通过学习数据分布来生成合成图像和文档的能力,同时它们仍然无法从大型实验数据中学习推理图。然而,GNN 探索从场景图片和故事文档等非结构性数据生成图形,可以成为进一步高级 AI 的强大神经模型。

3.2 GNN模型

    对于图神经网络模型,基于不同方法有不同的构成,在此做简述。相比较于神经网络最基本的网络结构全连接层(MLP),特征矩阵乘以权重矩阵,图神经网络多了一个邻接矩阵。计算形式很简单,三个矩阵相乘再加上一个非线性变换,如图6所示:

图6 GNN计算形式

    因此一个比较常见的图神经网络的应用模式如下图7所示,输入一个图,经过多层图卷积等各种操作以及激活函数,最终得到各个节点的表示,以便于进行节点分类、链接预测、图与子图的生成等等任务。

图7 常见图神经网络应用模式

3.3 GNN框架

图8 GNN流程图

    对于GNN中f和g的参数的学习,常使用目标信息来进行监督学习,可以将损失函数定义为式6:

式6 GNN损失函数

    其中,p表示监督节点的数目,tioi分别表示节点的真实值和预测值。损失函数的学习基于梯度下降策略,步骤总结如表3:

表3  损失函数学习流程

1、状态迭代更新:状态hvt按照式1更新T轮,直到接近式3的定点解

2、计算权重W的梯度:权重W的梯度从loss计算得到

3、更新权重:根据2中的梯度更新权重W

3.4 GNN局限与优化

    结合GNN原理、模型、框架分析,总结经典GNN的局限如表4所示:

表4 经典GNN局限

1.对不动点使用迭代的方法来更新节点的隐藏状态,效率不高。

2.原始GNN 在迭代中使用相同参数,模型难以学习到更加深的特征表达。而且,节点隐藏层的更新是顺序流程。

3.一些边上可能会存在某些信息特征不能被有效地考虑进去。此外,如何学习边的隐藏状态也是一个重要问题。

4.如果我们需要学习节点的向量表示而不是图的表示,则不适合使用固定点,因为固定点中的表示分布将在值上非常平滑并且用于区分每个节点的信息量较少。

    根据表4描述,经典GNN还存在局限,故之后出现了许多基于图神经网络的算法,主要是从图类型、传播类型和训练方法三个方面来对图神经网络进行优化,三种类型的各种GNN变体总结如图9所示:

图9 GNN三类优化算法总结

二、图卷积神经网络

    图卷积神经网络(GCN)为图神经网络的“开山之作”,它首次将图像处理中的卷积操作简单的用到图结构数据处理中来。由于GAT是在GCN上做了几大优化,实现过程有类似,故在本部分先做图卷积神经网络引入。

    Tips:GCN推导涉及大量数学理论,本部分仅作核心部分介绍。

1.卷积

    在泛函分析中,卷积是通过两个函数f和g生成第三个函数的一种数学运算,其本质是一种特殊的积分变换,表征函数f与g经过翻转和平移的重叠部分函数值乘积对重叠长度的积分。对于图像(数据)处理,实质上卷积是对信号进行滤波,想法来自于图像,之后引进到图中。然而,当图像有固定的结构时,图就复杂得多。以CNN为例,从图像到图形的卷积思想如图10所示:

图10 图像到图形的卷积思想样例

2.GCN详解

2.1 GCN动机

    GCN动机与GNN动机类似。对于CNN与RNN系列的神经网络算法,对于欧式空间的数据做图像识别、自然语言处理等任务可以取得不错的效果。但现实生活中,其实有很多很多不规则的数据结构,典型的就是图结构,或称拓扑结构,如社交网络、化学分子结构、知识图谱等等,样例如图4所示。

    图的结构一般来说是十分不规则的,可以认为是无限维的一种数据,所以它没有平移不变性。每一个节点的周围结构可能都是独一无二的,这种结构的数据,就让传统的CNN、RNN瞬间失效。为了处理这类数据,涌现出了许多方法,GCN就是其中一种经典方法。

2.2 GCN简介

    类似CNN,GCN是一种卷积神经网络,它可以直接在图(对象)上工作,并利用图的结构信息。GCN精妙地设计了一种从图数据中提取特征的方法,从而我们可以使用这些特征去对图数据进行:节点分类、图分类、边预测,还可以顺便得到图的嵌入表示。其中仅有一小部分节点有标签(半监督学习),其中,在图上进行节点分类的样例如图11所示:

图11 图节点分类样例

2.3 GCN思想与模型

    经典GCN为半监督图卷积神经网络,基本思想为对于每个节点,我们从它的所有邻居节点处(包括自身)获取其特征信息。假设我们使用average()函数。我们将对所有的节点进行同样的操作。最后,我们将这些计算得到的平均值输入到神经网络中。

    如图12,我们有一个引文网络的简单实例。其中每个节点代表一篇研究论文,同时边代表的是引文。我们在这里有一个预处理步骤。在这里我们不使用原始论文作为特征,而是将论文转换成向量(通过使用NLP嵌入,例如tf-idf)。NLP嵌入,例如TF-IDF)。

图12 GCN用于引文网络简单示例

    GCN的主要思想:我们以绿色节点为例。首先,我们取其所有邻居节点的平均值,包括自身节点。然后,将平均值通过神经网络。请注意,在GCN中,我们仅仅使用一个全连接层。在这个例子中,我们得到2维向量作为输出(全连接层的2个节点)如图13所示:

图13 GCN设计样例

    在实际操作中,我们可以使用比average函数更复杂的聚合函数。我们还可以将更多的层叠加在一起,以获得更深的GCN。其中每一层的输出会被视为下一层的输入。

    因此,根据GCN的简介与思想,GCN的模型如图14所示:

图14 GCN模型

    分析:如图14可见,GCN虽然在隐藏层进行了复杂的数学推导与计算,但输入输出是简单规则的。对于GCN而言,一个拥有C个input channel的graph作为输入,经过中间的hidden layers,得到F个output channel的输出。

2.4 GCN核心公式解析

(1)层与层的传播

    假设我们有一批图数据,其中有N个节点,每个节点都有自己的特征,我们设这些节点的特征组成一个N×d维的矩阵X,然后各个节点之间的关系也会形成一个N×N维的矩阵A,也称为邻接矩阵。X和A便是我们模型的输入。

(2)分类

    经过上述核心公式的求解,GCN最终得到的结果是一个经过l层特征加强后得到的Z=H^(l)各个节点的特征向量。也就是说,通过若干层GCN后每个节点的特征从X变成了Z^N x C,其中C为待分类的类别数量。

(3)训练与参数更新 

'''-------------------------------------------------------------------------------------''''''tf.train.AdamOptimizer利用梯度的一阶矩估计和二阶矩估计动态调整每个参数的学习率。Adam的优点主要在于经过偏置校正后,每一次迭代学习率都有个确定范围,使得参数比较平稳.'''# self.lr为事先设置好的梯度下降中的学习率self.optimizer = tf.train.AdamOptimizer(self.lr)'''由tf源代码可以知道optimizer.minimize()实际上包含了两个步骤,即optimizer.compute_gradients和optimizer.apply_gradients,前者用于计算梯度,后者用于使用计算得到的梯度来更新对应的变量。''''''computer_gradients(loss, val_list):●loss: 需要被优化的Tensor;这里的loss为self.loss+self.l2最终返回的是元组列表,即[(gradient, variable),...]。例:x = 50, w = 10, y = x*w;结果是[(50,10),(10,50)]列表中第一个元组中第一个元素是y对w求导的结果,第二个元素是w。列表中第二个元组中第一个元素是y对x求导的结果,第二个元素是x。'''# self.loss是通过tf.losses.softmax_cross_entropy计算得到的损失函数的Tensor# l2是一个正则化项gradients = self.optimizer.compute_gradients(self.loss+self.l2)'''self.optimizer.apply_gradients的作用是将compute_gradients()返回的值作为输入参数对变量进行更新。使用tf.clip_by_value来修正梯度:输入一个张量grad,把grad中的每一个元素的值都压缩在-5和5之间。小于-5的让它等于-5,大于5的元素的值等于5。'''capped_gradients = [(tf.clip_by_value(grad, -5., 5.), var) for grad, var in gradients if grad is not None]self.train_op = self.optimizer.apply_gradients(capped_gradients)'''那为什么minimize()会分开两个步骤呢?原因是因为在某些情况下我们需要对梯度做一定的修正,例如为了防止梯度消失(gradient vanishing)或者梯度爆炸(gradient explosion),我们需要事先干预一下以免程序出现Nan的尴尬情况;有的时候也许我们需要给计算得到的梯度乘以一个权重或者其他乱七八糟的原因,所以才分开了两个步骤。''''''多次执行self.train_op后,即可训练成功''''''---------------------------------------------------------------------------------------'''

2.5 GCN优势与局限

    结合GCN原理、模型、框架分析,总结经典GNN的局限如表5所示:

表5 GCN的优势与局限

    对于GCN存在的不足,后续有许多优化算法,GAT就是其中之一。

三、图注意力网络

1.Attention机制

    GAT为图注意力网络,关键机制为Attention机制,故在此先对Attention机制做一定引入与解析。

    Attention机制的中文名叫“注意力机制”,它的主要作用是让神经网络把“注意力”放在一部分输入上,即:区分输入的不同部分对输出的影响。这里,我们从增强字/词的语义表示这一角度来理解一下Attention机制。

    我们知道,一个字/词在一篇文本中表达的意思通常与它的上下文有关。比如:光看“鹄”字,我们可能会觉得很陌生(甚至连读音是什么都不记得吧),而看到它的上下文“鸿鹄之志”后,就对它立马熟悉了起来。因此,字/词的上下文信息有助于增强其语义表示。同时,上下文中的不同字/词对增强语义表示所起的作用往往不同。比如在上面这个例子中,“鸿”字对理解“鹄”字的作用最大,而“之”字的作用则相对较小。为了有区分地利用上下文字信息增强目标字的语义表示,就可以用到Attention机制。

    Attention机制主要涉及到三个概念:Query、Key和Value。在上面增强字的语义表示这个应用场景中,目标字及其上下文的字都有各自的原始Value,Attention机制将目标字作为Query、其上下文的各个字作为Key,并将Query与各个Key的相似性作为权重,把上下文各个字的Value融入目标字的原始Value中。

    如图16所示,Attention机制将目标字和上下文各个字的语义向量表示作为输入,首先通过线性变换获得目标字的Query向量表示、上下文各个字的Key向量表示以及目标字与上下文各个字的原始Value表示,然后计算Query向量与各个Key向量的相似度作为权重,加权融合目标字的Value向量和各个上下文字的Value向量,作为Attention的输出,即:目标字的增强语义向量表示。

图16 Attention机制架构样例

    除了基础的Attention机制,还有Self-Attention与Multi-head Attention两种机制。简介如下:

    Self-Attention:对于输入文本,我们需要对其中的每个字分别增强语义向量表示,因此,我们分别将每个字作为Query,加权融合文本中所有字的语义信息,得到各个字的增强语义向量。在这种情况下,Query、Key和Value的向量表示均来自于同一输入文本,因此,该Attention机制也叫Self-Attention。

    Multi-head Self-Attention:为了增强Attention的多样性,文章作者进一步利用不同的Self-Attention模块获得文本中每个字在不同语义空间下的增强语义向量,并将每个字的多个增强语义向量进行线性组合,从而获得一个最终的与原始字向量长度相同的增强语义向量。

    两种改进的Attention架构如图17所示:

 图17 Self-Attention与Multi-head Attention的架构样例

2.GAT简介

    根据表5中总结了GCN的两大局限,GAT引入了图17里的注意力机制,很好地解决了这两个缺点。

    Graph Attention Network(GAT)提出了用注意力机制对邻近节点特征加权求和。 邻近节点特征的权重完全取决于节点特征,独立于图结构。GAT和GCN的核心区别在于如何收集并累和距离为1的邻居节点的特征表示。 图注意力模型GAT用注意力机制替代了GCN中固定的标准化操作。具体来说,GAT借鉴了Transformer的idea,引入masked self-attention机制,在计算图中的每个节点的表示的时候,会根据邻居节点特征的不同来为其分配不同的权值。本质上, GAT只是将原本GCN的标准化函数替换为使用注意力权重的邻居节点特征聚合函数

    因此,GAT的优点总结如表6:

表6 经典GAT优点

1.训练GCN无需了解整个图结构,只需知道每个节点的邻居节点

2.计算速度快,可以在不同的节点上进行并行计算

3.既可以用于Transductive Learning,又可以用于Inductive Learning,可以对未见过的图结构进行处理

3.核心公式解析

    GAT本质上是引入注意力机制的图神经网络,所以核心还是图注意力层,在本部分对相关核心公式做解析。

3.1 图注意力层的输入输出

3.2 特征提取与注意力机制

图18 GAT特征提取与注意力机制过程

3.3 输出特征 

3.4 multi-head attention 

图19 GAT中multi−headattention样例

    分析:如图19所示,由节点在其邻域上的multi−headattention(具有K=3个头)的图示。不同的箭头样式和颜色表示独立的注意力计算,来自每个头的聚合特征被连接或平均以获得

    至此,GAT图注意力层的核心公式解析完成,对于GAT分类任务,过程与GCN的分类过程十分相似,均是采用softmax函数+交叉熵损失函数+梯度下降法来完成的,具体流程前文已有详述。

4.论文结果与分析

    论文作者对GAT模型与各种强大的基线和以前的方法进行了比较评估,在四个既定的基于图的基准任务(直推(transductive learning)和归纳(Inductive learning)式)上,在这些任务中都达到或符合最先进的性能。本节总结了实验设置、结果,以及对GAT模型提取的特征表示的简要定性分析。

4.1 数据集与比较方法

    对于直推式学习,论文利用三个标准引文网络基准数据集Cora、Citeseer和Pubmed,并密切遵循Yang等人的直推式实验设置,比较方法选择Kipf&Welling中规定的强基线和最先进的方法。

    对于归纳式学习,论文利用蛋白质-蛋白质相互作用(PPI)数据集,比较了Hamilton等人提出的四种不同的有监督GraphSAGE归纳方法。

    此外,对于这两项任务,我们提供了每节点共享多层感知器(MLP)分类器的性能(它根本不包含图结构)。

    论文使用的数据集及其信息总结如图20:

图20 GAT使用的数据集及其信息总结

4.2 实验模型设置

    对于直推式学习任务,实验设置如表7所示:

表7 transductive learning实验设置

  • 两层GAT模型
  • 在Cora数据集上优化网络结构的超参数,应用到Citeseer数据集
  • 第一层8head,F'=8,ELU作为非线性函数
  • 第二层为分类层,一个attention head特征数C,后跟 softmax函数,为了应对小训练集,正则化(L2)
  • 两层都采用0.6的dropout,相当于计算每个node位置的卷积时都是随机的选取了一部分近邻节点参与卷积

    对于归纳式学习任务,实验设置如表8所示:

表8 inductive learning实验设置

  • 三层GAT模型
  • 前两层 K=4, F1=256,ELU作为非线性函数
  • 最后一层用来分类K=6,F=121,激活函数为sigmoid
  • 该任务中,训练集足够大不需要使用正则化和dropout

    两个模型都使用Glorot初始化,并使用Adam SGD优化器对训练节点进行训练以最小化交叉熵,Pubmed的初始学习率为0.01,所有其他数据集的初始学习率为0.005。在这两种情况下,在验证节点的交叉熵损失和准确度(直推式)或微F1-score(归纳式)分数上使用提前停止策略,共100个epochs

4.3 实验结果与分析

    对于直推式任务,论文在100次运行后报告了测试节点上的平均分类精度(带标准偏差),并将Kipf&Welling和Monti et al.中已报告的指标用于最先进的技术。Cora、Citeser和Pubmed的分类准确度结果汇总如图21所示:

    对于归纳式任务,论文报告了两个看不见的测试图节点上的微平均F1-score,10次运行后平均,并将Hamilton等人中已报告的指标用于其他技术。PPI数据集微平均F1-score结果汇总如图22所示:

图21 Cora、Citeser和Pubmed的分类准确度结果汇总

图22 PPI数据集微平均F1-score结果汇总

    分析:如图21与图22,结合相关工作讨论分析,论文的结果成功地证明了在所有四个数据集中实现或匹配的最新性能符合预期

4.4 论文结论补充

    对于GAT可做一定的可视化,定性地研究学习到的特征表示的有效性。为此,论文提供了t-SNE的可视化通过在Cora数据集上预训练的GAT模型的第一层提取的转换特征表示,如图23所示。该表示在投影的2D空间中表现出明显的聚类。注意,这些集群对应于数据集的七个标签,验证了模型在Cora的七个主题类中的区分能力。此外,还可视化了归一化注意系数的相对强度(所有八个注意头的平均值)。

图23 Cora数据集上预训练的GAT模型第一个隐藏层的计算特征表示的t-SNE图

四、GAT复现

    在介绍完图神经网络(GNN)、图卷积神经网络(GCN)与图注意力网络(GAT)的原理、模型实现、优缺点后,在本部分对GAT的模型进行复现,并用其在Cora、Citeser数据集下进行实践,完成引文分类等实际任务。

    对于GAT的复现有多种方法,可以基于Tensorflow、Pytorch、keras等,官方代码地址总结如表9,本部分以Tensorflow与Pytorch框架对GAT进行复现与分析。

表9 GAT官方实现总结

框架

地址

Tensorflow

GitHub - PetarV-/GAT: Graph Attention Networks 

Pytorch

GitHub - gordicaleksa/pytorch-GAT 

Keras

GitHub - danielegrattarola/keras-gat 

1.Tensorflow GAT(代码解析)

    本部分尝试使用Tensorflow GAT进行论文实验复现与分析,包括核心代码的解析与复现结果的比较分析。

1.1 代码结构与参数设置

    将官方Tensorflow GAT导入Pycharm,查看代码结构与参数设置定义于GAT/execute_cora.py,如图24所示:

图24 Tensorflow GAT 代码结构与参数设置

1.2 数据加载与特征预处理

    Tensorflow GAT将数据加载与预处理部分定义在GAT/utils/process.py。其中,默认使用Cora数据集做引文分类,预处理部分GCN一致,最终载入的数据adj为邻接矩阵,表示2708篇文章之间的索引关系。feature表示1433个单词在2708篇文章中是否存在

1.3 GAT模型

    GAT 网络本身是通过堆叠多个 Graph Attention Layer 层构成的, 所以模型核心为Graph Attention Layer的实现,定义在layers.py

    其中,Graph Attention Layer 的定义在三(3)的核心公式解析中有详解,核心步骤是按生成对应的Attention系数。在代码实现中,核心函数为attn_head(),定义与关键解析如图25所示(Carbon | Create and share beautiful images of your source code渲染):

图25 attn_head()定义与关键解析

    如图25,已对Graph Attention Layer 层中核心代码与关键代码与流程做了简介,接下来再对图25中的重点代码做一定解释与补充。

图26 layer卷积

    如图26所示,作者首先对原始节点特征seq利用卷积核大小为1的1D卷积模拟投影变换得到了seq_fts,投影变换后的维度为out_sz。注意,这里投影矩阵W是所有节点共享,所以1D卷积中的多个卷积核也是共享的。

    而对于tensorflow中concld(卷积)的实现,补充如图27所示:

图27 卷积的实现过程可视化

图28 投影变换卷积

    如图28所示,投影变换后得到的seq_fts继续使用卷积核大小为1的1D卷积处理,得到节点本身的投影f_1 和其邻居的投影f_2。注意这里两个投影的参数是分开的,即有两套投影参数,分别对应上面两个conv1d 中的参数。

图29 logits矩阵

    如图29所示,将f_2转置之后与f_1叠加,通Tensorflow的广播机制得到logits,就是一个注意力矩阵,相关公式如式17所示:

式17 注意力矩阵定义

图30 注意力系数求解

    如图30所示,代码中的coefs即注意力系数,通过求解,即logits进行softmax归一化

    但注意力系数求解代码中增加了一项bias_mat,因为的logits存储了任意两个节点之间的注意力值,但是,归一化只需要对每个节点的所有邻居的注意力进行(k∈Ni)。所以,引入了bias_mat就是softmax的归一化对象约束在每个节点的邻居上,如式18的红色部分。

式18 约束注意力系数

    接下来对于bias_mat的实现做简单解析,它的实现定义在GAT/utils/process.py中,定义与核心解析如图31所示(Carbon | Create and share beautiful images of your source code渲染):

图31 bias_mat实现及核心解析

    分析:由图31所示,GAT对于bias_mat的实现使用了很大的负数,将原始邻居矩阵进行adj_to_bias,然后,将bias_mat和注意力矩阵相加,进而将非节点邻居进行 mask。

图32 节点更新

    如图32所示,根据GAT最终输出预测,最后将mask之后的注意力矩阵coefs与变换后的特征矩阵seq_fts相乘,即可得到更新后的节点表示vals。

1.4 GAT网络

    在GAT模型定义与构建成功后,就要开始GAT网络的搭建,其代码定义于GAT/models/gat.py,如图33所示,本质上是堆叠 Graph Attention Layer

图33 Tensorflow GAT网络定义

1.5 训练与损失函数

    Tensorflow GAT在GAT/models/base_gattn.py定义了训练与损失函数,训练的核心是最小化损失函数与L2 loss,函数定义如图34所示:

图34 Tensorflow GAT训练与损失函数定义

    至此,Tensorflow GAT的关键代码与算法流程就解析完成了。 

2.Pytorch GAT(代码解析)

    Pytorch GAT与Tensorflow GAT基本流程均符合GAT流程,在函数的实现与选择上有一定区别,与框架有关,故本部分仅对Pytorch核心代码(注意力层)进行解析,对于Pytorch GAT的注意力层代码实现如图35所示:

图35 Pytorch GAT的注意力层代码实现

3.代码实践与结果分析

    本部分分别使用Tensorflow GAT与Pytorch GAT对CoraCiteseer数据集进行引文分类并进行结果的分析与比较,GAT进行引文分类流程如图36所示,训练过程如图37所示:

图36 GAT进行引文分类流程

图37 GAT部分训练过程

    训练完成后,分别使用Tensorflow GAT与Pytorch GAT对Cora、Citeseer进行引文分类,以Pytorch GAT对Cora分类为例,结果如图38所示:

 图38 GAT分类结果样例

    Tensorflow GAT与Pytorch GAT对Cora、Citeseer进行引文分类的复现结果及与论文结论数据汇总如表10所示,对比如图39所示:

表10 GAT分类复现与论文结果数据汇总

方法

Cora

Citeseer

论文

83.0 ± 0.7%

72.5 ± 0.7%

Tensorflow

82.45

71.96

Pytorch

84.39

72.97

图39 GAT分类复现与论文结果对比

    分析:根据表10与图39可见,复现GAT与论文GAT在Cora、Citeseer数据集下得到的结果基本一致(F1-分类指标等),有些许差异与环境、参数设置等有关。

    其中,对于训练过程的训练及测试指标做可视化可以感受GAT模型建立的过程,同时观察分类正确率、损失函数等随epoch的变换,以tensorflow GAT的训练过程可视化为例,如图40所示:

图40 tensorflow GAT的训练过程可视化(右图为Cora数据集)

    分析:如图40所示,GAT的训练过程可视化后较直观,可以看到随epoch增加,分类正确率不断增加,直至稳定。

五、GAT补充

1.GAT的推广

    GAT仅仅是应用在了单层图结构网络上,我们是否能将它推广到多层网络结构呢?

    这里我们假设一个有N层网络的结构,每层网络都定义了相同的节点,但是节点之间的关系有所差异。这样,我们就完成了一个多层网络的构建,他们共享相同的节点,但又分别具有不同的邻边,如果我们分别处理每一层视图,然后将他们得出的节点表示单纯相加的话,就可能会失去不同视图之间的协作关系,降低分类(预测)的精度。

    基于以上观点,这里提出了一种新的方法:首先在每一层单视图中应用GAT进行学习,并计算出每层视图的节点表示。之后在不同视图之间引入attention机制来让网络自行学习不同视图的权重。之后根据学习的权重,将各个视图加权相加得到全局节点表示并进行后续的诸如节点表示,链接预测等任务。

    同时,因为不同视图共享同样的节点,即使每一层视图都表示了不同的节点关系,最终得到的每一层的节点嵌入表示应具有一定的相关性。基于以上理论,我们在每层GAT的网络参数间引入正则化项来约束参数,使其向互相相近的方向学习。大致的网络流程图如图41所示:

图41 多层网络GAT流程图

2.GAT优化

    参考论文:HOW ATTENTIVE ARE GRAPH ATTENTION NETWORKS?(ICLR 2022)

    针对传统GAT的原理、模型、实现进行分析,总结其局限如表11所示:

表11 传统GAT的局限

1.GAT在聚合多阶邻居时的不足

2.GAT最好加入self loop

(仅仅使用邻居消息聚合来训练出节点embedding的方式往往会引入大量噪音)

3.GAT的训练对参数初始化比较敏感

4.注意力计算时必不可少的LeakyReLU

    因此,GAT还存在优化的空间,本部分主要引用ICLR 2022的一篇论文思想对于经典GAT提出几个优化思路。

2.1 论文简介

    论文认为GAT是static attention,仅实现了对节点重要度的静态ranking,而未实现对不同query给出不同key的设想;故提出GATv2,通过调整LeakyReLU和linear unit计算顺序,实现dynamic attention,即对不同query能给出不同key

2.2 动机

    GAT已成为图神经网络发展历程中的标志性架构,但论文观察发现,GAT的attention对于相同的keys实现的其实是ranking。

    假设有Dictionary Lookup,问题与使用GAT所得的attention scores如图42所示:

图42 Dictionary Lookup问题与GAT所得attention scores

    分析:如图42,可以看到,对于不同的query,key的scores排序实际是一样的(静态的)。这限制了GAT的表达能力。

    而论文认为,attention的初衷应该是:给定不同的query,能找到不同的key(即不同query,ranking结果应该不同,动态的)。

2.3 方法论 

2.4 优化效果

    在Dictionary Lookup方面,用GATv2去解决图42的问题,优化前后对比如图43所示:

图43 优化前后GAT所得attention scores

    分析:由图43所示,对于上文中二部图问题,使用改进后的GAT能有效实现dynamic attention

    在Robustness to Noise方面,GATv2与经典GAT对比如图44所示:

图44 优化前后GAT抵抗噪声效果对比

    分析:由图44所示,dynamic attention(GATv2)能更好抵抗噪声。 

六、实验结论与参考

1.实验结论

(1)现实生活中有很多不规则的数据结构,典型的就是图结构,如社交网络、分子结构、知识图谱等,对非结构化的数据,传统NN方法表现不佳,需要使用图神经网络

(2)GAT本质上是一种加入了注意力机制的图卷积神经网络,在此之前的一些经典图神经网络缺陷如表12所示:

表12 经典图神经网络缺陷

GNN缺陷

1.对不动点使用迭代的方法来更新节点的隐藏状态,效率不高。

2.原始GNN 在迭代中使用相同参数,模型难以学习到更加深的特征表达。而且,节点隐藏层的更新是顺序流程。

3.一些边上可能会存在某些信息特征不能被有效地考虑进去。此外,如何学习边的隐藏状态也是一个重要问题。

4.如果我们需要学习节点的向量表示而不是图的表示,则不适合使用固定点,因为固定点中的表示分布将在值上非常平滑并且用于区分每个节点的信息量较少。

GCN缺陷

GCN对于同阶的邻域上分配给不同的邻居的权重是完全相同的,这一点限制了模型对于空间信息的相关性的捕捉能力

GCN结合临近节点特征的方式和图的结构依依相关,这局限了训练所得模型在其他图结构上的泛化能力

(3)GAT可以在Tensorflow、Pytorch、Keras等多种框架下完成多种实际任务,如引文分类等,同时GAT可以推广到多层网络结构

(4)GAT是一种高效的算法,但仍存在不适合聚合多阶邻居、需要加入self loop、对参数初始化比较敏感、注意力计算时必不可少LeakyReLU等问题,仍存在优化空间。

2.参考资料

1.【图表示学习】pytorch实现图注意力网络GAT_BQW_的博客-CSDN博客_gat pytorch实现

2.图神经网络(三)—GAT-pytorch版本代码详解_Arvin Ou的博客-CSDN博客_gat pytorch

3.【图结构】之图神经网络GCN详解_張張張張的博客-CSDN博客_gcn结构

4.【图结构】之图注意力网络GAT详解_張張張張的博客-CSDN博客_gat公式

5.GAT 算法原理介绍与源码分析_珍妮的选择的博客-CSDN博客_gat算法

6.Graph Attention Network (GAT) 的Tensorflow版代码解析_酒酿小圆子~的博客-CSDN博客

7.图神经网络GAT tensorflow代码分析_十三言的博客-CSDN博客_gat代码 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_239739.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Qt之悬浮球菜单

一、概述 最近想做一个炫酷的悬浮式菜单,考虑到菜单展开和美观,所以考虑学习下Qt的动画系统和状态机内容,打开QtCreator的示例教程浏览了下,大致发现教程中2D Painting程序和Animated Tiles程序有所帮助,如下图所示&a…

论文投稿指南——中文核心期刊推荐(自然科学总论)

【前言】 🚀 想发论文怎么办?手把手教你论文如何投稿!那么,首先要搞懂投稿目标——论文期刊 🎄 在期刊论文的分布中,存在一种普遍现象:即对于某一特定的学科或专业来说,少数期刊所含…

Nginx学习笔记2【尚硅谷】

host文件修改时,可以更改用户组权限或者复制到某个有权限的位置修改完再复制替换之前的文件。 在server{}中,listenserver_name两个加一起是唯一的。 代理服务器就是一个网关。 配置Nginx反向代理: 注意:在写proxy_pass时&#xf…

化学试剂Biotin-PEG-COOH,Biotin-PEG-acid,生物素-聚乙二醇-羧基

英文名称:Biotin-PEG-COOH,Biotin-PEG-acid 中文名称:生物素-聚乙二醇-羧基 生物素-PEG-COOH是一种含有生物素和羧酸的线性杂双功能PEG试剂。它是一种有用的带有PEG间隔基的交联或生物结合试剂。生物素能以高特异性和亲和力与亲和素和链霉亲…

MySQL实现主从复制(Windows)的明细操作步骤

文章目录一、教学视频地址二、设计思路三、具体步骤一、教学视频地址 视频地址:视频链接 二、设计思路 准备两个5.7版本的MySQL,一个用作主数据库,另一个用作从数据库。 把主数据库做为写入数据库,从数据库作为读数据库。 三…

linux篇【12】:计算机网络<后序>

一.tcp接入线程池(使用线程池) 1.tcp初步接入线程池 我们设置了对应的任务是死循环,那么线程池提供服务,就显得有不太合适。我们给线程池抛入的任务都是短任务 因为他并没有访问任何类内成员,所以可以把执行方法提到…

seo综合查询,怎么看网站在移动端权重高低

移动权重就是指在手机、IPAD等的流量,数值越大流量越多。 未来百度流量一定会更倾向于移动端,移动端搜索将是百度搜索引擎的主要阵地。这一点和用户上网习惯有关系,因为移动网络无处不在。 那么怎么看网站在移动端权重高低?最…

了解学习node中著名的co模块原理,生成器+promise实现async+await

***内容预警*** 新手内容,菜鸟必看,大佬请绕道 首先 co 是一个npm第三方模块,我们需要npm install 之后才能使用它。 作为一个菜鸟我相信你肯定没有用过这个模块,但是据说这个模块很有名,那么我们就有必要来了解一下它…

为什么企业要注重数据安全?六大优势分析

数据加密是将数据从可读格式转换为编码格式。两种最常见的加密方法是对称加密和非对称加密。这些名称是指是否使用相同的密钥进行加密和解密: ●对称加密密钥:这也称为私钥加密。用于编码的密钥与用于解码的密钥相同,使其最适合个人用户和封…

使用Docker搭建Nacos的持久化和集群部署

1. 准备 1.1 mysql安装 下载镜像 docker pull mysql/mysql-server:5.7 在宿主机中相关目录,用于挂载容器的相关数据 mkdir -p /data/mysql/{conf,data} 编写my.cnf配置文件,在/data/mysql/conf目录中 (或下载 直接上传即可) my.cnf.txt - 蓝奏云 / …

BIT.4 Linux进程控制

目录进程创建fork函数初识写实拷贝fork常规用法fork调用失败的原因补充知识进程终止进程退出场景进程常见退出方法exit函数与_exit函数return 退出补充知识进程等待进程等待必要性进程等待的方法wait方法waitpid方法wait / waitpid 阻塞代码WIFEXITEDwait / waitpid 非阻塞代码…

客快物流大数据项目(九十八):ClickHouse的SQL函数

文章目录 ClickHouse的SQL函数 一、​​​​​​​​​​​​​​类型检测函数

Redis集群之AKF架构原理

当我们搭建集群之前,先要想明白需要解决哪些问题,搞清楚这个之前先回想一下单节点、单实例、单机有哪些问题? 单点故障:只有一台Redis的话,如果出现故障,那么整个服务都不可用缓存容量:单台Red…

【Django】第一课 基于Django超市订单管理系统开发

概念 django服务器开发框架是一款基于Python编程语言用于web服务器开发的框架,采用的是MTV架构模式进行分层架构。 项目搭建 打开pycharm开发软件,打开开发软件的内置dos窗口操作命令行 在这里指定项目存放的磁盘路径,并使用创建django项…

54三数之和55 56有无重复元素的全排列

54 三数之和 首先想到的就是之前的两数之和,只要在外层遍历一遍,对每个元素用之前的两数之和的哈希做法,就刚好是O(n^2) 但是有坑的地方在于需要去重,并且输出的三元组也是需要顺序的!!然后我用set去重和重…

史上最强,这份在各大平台获百万推荐的Java核心手册实至名归

又逢“金九银十”,年轻的毕业生们满怀希望与忐忑,去寻找、竞争一个工作机会。已经在职的开发同学,也想通过社会招聘或者内推的时机争取到更好的待遇、更大的平台。 然而,面试人群众多,技术市场却相对冷淡,…

flutter 环境搭建

一、简介 Flutter 是谷歌开发的一款开源、免费的,基于 Dart 语言的U1框架,可以快速在i0S和Android上构建高质量的原生应用。 它最大的特点就是跨平台和高性能。Dart是由谷歌,在2011 年开发的计算机编程语言,它可以被用于Web、服务器、移动应…

服务注册配置中心Nacos

文章目录一. 前言二. 下载安装1. 下载安装包2. Windows环境安装3. Linux环境安装1. 单击模式启动2. 集群模式启动3. 远程web控制4. 注册为系统服务三. 基本使用1. 添加依赖2. 服务注册3. 配置实例集群属性4. 实例权重负载均衡5. 环境隔离6. 临时实例与非临时实例四. Nacos配置管…

python常用模块

time模块 常用操作 1.直接获取时间 time.time() #获取结果是秒数,即从1970年1月1日8:00起计#1671856010.9592516 2.获取结构化时间 time.localtime() #获取本地时间,中国为东八区,为上海时间 time.gmtime() …

3.2 Static Terrestrial Laser Scanners 静态地基激光扫描仪

本章节介绍的静态地基激光扫描系统指的是那些在一个固定位置的位置上对周边场景地物特征进行扫描的设备。该类型设备的扫描测量机制是,通过激光测距仪进行斜距测量,与此同时通过水平和竖直两个方向上同步运动的角度编码器来记录角度变化值(如…