图注意网络GAT理解及Pytorch代码实现【PyGAT代码详细注释】

news/2024/4/18 10:56:38/文章来源:https://blog.csdn.net/weixin_43629813/article/details/129278266

文章目录

    • GAT
    • 代码实现【PyGAT】
      • GraphAttentionLayer【一个图注意力层实现】
      • 用上面实现的单层网络测试
      • 加入Multi-head机制的GAT
      • 对数据集Cora的处理
      • csr_matrix()处理稀疏矩阵
      • encode_onehot()对label编号
      • build graph
      • 邻接矩阵构造
    • GAT的推广

GAT

题:Graph Attention Networks
摘要:
提出了图形注意网络(GAT) ,这是一种基于图结构数据的新型神经网络结构,利用掩蔽的自我注意层来解决基于图卷积或其近似的先前方法的缺点。通过叠加层,节点能够参与其邻域的特征,我们能够(隐式地)为邻域中的不同节点指定不同的权重,而不需要任何代价高昂的矩阵操作(如反演) ,或者依赖于预先知道图的结构。通过这种方法,我们同时解决了基于谱的图形神经网络的几个关键问题,并使我们的模型容易地适用于归纳和转导问题。我们的 GAT 模型已经实现或匹配了四个已建立的转导和归纳图基准的最新结果: Cora,Citeseer 和 Pubmed 引用网络数据集,以及protein-protein interaction dataset(其中测试图在训练期间保持不可见)。

Paper with code 网址,可找到对应论文和github源码,原论文使用TensorFlow实现,本篇主要对Pytorch版本的 PyGAT附详细注释帮助理解和测试。

GitHUb: keras版本实现
Pytorch版本实现 PyGAT

在这里插入图片描述
在这里插入图片描述

截图及下文代码注释参考自视频:GAT详解及代码实现

视频中的eij的实现与源码不同,视频中是先拼接两个W,再与a乘
源码在_prepare_attentional_mechanism_input()函数中先分别与a乘,再拼接

代码实现【PyGAT】

在PyGAT :

  • layers.py中定义Simple GAT layer实现(GraphAttentionLayer)和Sparse version GAT layer实现(SpGraphAttentionLayer)。
  • models.py 实现两个版本加入Multi-head机制
  • trains.py 使用model定义的GAT构建模型进行训练,使用cora数据集

GraphAttentionLayer【一个图注意力层实现】

class GraphAttentionLayer(nn.Module):"""Simple GAT layer, similar to https://arxiv.org/abs/1710.10903"""def __init__(self, in_features, out_features, dropout, alpha, concat=True):super(GraphAttentionLayer, self).__init__()self.dropout = dropoutself.in_features = in_features#结点向量的特征维度self.out_features = out_features#经过GAT之后的特征维度self.alpha = alpha#dropout参数self.concat = concat#LeakyReLU参数# 定义可训练参数,即论文中的W和aself.W = nn.Parameter(torch.empty(size=(in_features, out_features)))nn.init.xavier_uniform_(self.W.data, gain=1.414)# xavier初始化self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))nn.init.xavier_uniform_(self.a.data, gain=1.414)# xavier初始化# 定义leakyReLU激活函数self.leakyrelu = nn.LeakyReLU(self.alpha)def forward(self, h, adj):'''adj图邻接矩阵,维度[N,N]非零即一h.shape: (N, in_features), self.W.shape:(in_features,out_features)Wh.shape: (N, out_features)'''Wh = torch.mm(h, self.W) # 对应eij的计算公式e = self._prepare_attentional_mechanism_input(Wh)#对应LeakyReLU(eij)计算公式zero_vec = -9e15*torch.ones_like(e)#将没有链接的边设置为负无穷attention = torch.where(adj > 0, e, zero_vec)#[N,N]# 表示如果邻接矩阵元素大于0时,则两个节点有连接,该位置的注意力系数保留# 否则需要mask设置为非常小的值,因为softmax的时候这个最小值会不考虑attention = F.softmax(attention, dim=1)# softmax形状保持不变[N,N],得到归一化的注意力全忠!attention = F.dropout(attention, self.dropout, training=self.training)# dropout,防止过拟合h_prime = torch.matmul(attention, Wh)#[N,N].[N,out_features]=>[N,out_features]# 得到由周围节点通过注意力权重进行更新后的表示if self.concat:return F.elu(h_prime)else:return h_primedef _prepare_attentional_mechanism_input(self, Wh):# Wh.shape (N, out_feature)# self.a.shape (2 * out_feature, 1)# Wh1&2.shape (N, 1)# e.shape (N, N)# 先分别与a相乘再进行拼接Wh1 = torch.matmul(Wh, self.a[:self.out_features, :])Wh2 = torch.matmul(Wh, self.a[self.out_features:, :])# broadcast adde = Wh1 + Wh2.Treturn self.leakyrelu(e)def __repr__(self):return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'

用上面实现的单层网络测试

x = torch.randn(6,10)
adj=torch.tensor([[0,1,1,0,0,0],[1,0,1,0,0,0],[1,1,0,1,0,0],[0,0,1,0,1,1],[0,0,0,1,0,0,],[0,0,0,1,1,0]])
my_gat = GraphAttentionLayer(10,5,0.2,0.2)
print(my_gat(x,adj))
输出:
tensor([[-0.2965,  2.8110, -0.6680, -0.9643, -0.9882],[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],[-0.4981, -0.7515,  1.1159,  0.3546,  1.3592],[ 0.4679,  1.7208,  0.3084, -0.5331, -0.1291],[-0.4375, -0.8778,  1.1767, -0.5869,  1.5154],[-0.2164, -0.5897,  0.4988, -0.3125,  0.6423]], grad_fn=<EluBackward>)

加入Multi-head机制的GAT

用不同head捕捉不同特征,使模型有更好的拟合能力。

class GAT(nn.Module):def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):"""Dense version of GAT."""super(GAT, self).__init__()self.dropout = dropout# 加入Multi-head机制self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]for i, attention in enumerate(self.attentions):self.add_module('attention_{}'.format(i), attention)self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)def forward(self, x, adj):x = F.dropout(x, self.dropout, training=self.training)x = torch.cat([att(x, adj) for att in self.attentions], dim=1)x = F.dropout(x, self.dropout, training=self.training)x = F.elu(self.out_att(x, adj))return F.log_softmax(x, dim=1)

对数据集Cora的处理

在这里插入图片描述
数据集中两个文件,cites:比如上图11行:编号25和编号1114331的文章
content文件:如下图,每篇文章的id、features及类别
在这里插入图片描述

csr_matrix()处理稀疏矩阵

utils.py中对数据进行的处理

#数据是稀疏的,csr_matrix操作从行开始将1的位置取出来,对数据进行压缩

features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32)
labels = encode_onehot(idx_features_labels[:, -1])

在这里插入图片描述

encode_onehot()对label编号

有7个类别,通过classes_dict是7*7的对角阵把每个类别映射成不同向量,对所有label进行编号,再将编号转换为one_hot向量
在这里插入图片描述

def encode_onehot(labels):# The classes must be sorted before encoding to enable static class encoding.# In other words, make sure the first class always maps to index 0.classes = sorted(list(set(labels)))classes_dict = {c: np.identity(len(classes))[i, :] for i, c in enumerate(classes)}labels_onehot = np.array(list(map(classes_dict.get, labels)), dtype=np.int32)return labels_onehot

build graph

见注释:

# build graphidx = np.array(idx_features_labels[:, 0], dtype=np.int32)#获取所有文章ididx_map = {j: i for i, j in enumerate(idx)}#按文章数目,对id重新映射# 读取数据集中文章和文章直接的引用关系edges_unordered = np.genfromtxt("{}{}.cites".format(path, dataset), dtype=np.int32)# 根据idx_map,将文章引用关系也重新映射edges = np.array(list(map(idx_map.get, edges_unordered.flatten())), dtype=np.int32).reshape(edges_unordered.shape)adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])), shape=(labels.shape[0], labels.shape[0]), dtype=np.float32)# build symmetric adjacency matrix 生成邻接矩阵adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)features = normalize_features(features)adj = normalize_adj(adj + sp.eye(adj.shape[0]))

邻接矩阵构造

csr_matrix()只记录了(0,1)1,忽略了(1,0)1。所以需要coo_matrix()操作!才能还原出无向图的邻接矩阵!
在这里插入图片描述

本文的一些代码注释及截图还可见视频

一个拓展:

GAT的推广

GAT的推广
GAT仅仅是应用在了单层图结构网络上,我们是否可以将它推广到多层网络结构呢?

这里我们假设一个有N层网络的结构,每层网络都定义了相同的节点,但是节点之间的关系有所差异。举一个简单的例子,假设有一个用户关系型网络,每层网络的节点都被定义成了网络中的用户,网络的第一层视图的关系可以定义为,两个用户之间是否具有好友关系;网络的第二层视图可以定义为,你评论过我的动态;网络的第三层视图可以定义为你转发过我的动态;第四层关系可以定义为,你at过我等等。

通过这样的定义我们就完成了一个多层网络的构建,他们共享相同的节点,但又分别具有不同的邻边,如果我们分别处理每一层视图视图,然后将他们得出的节点表示单纯相加的话,就可能会失去不同视图之间的协作关系,降低分类(预测)的精度。

基于以上观点,我们提出了一种新的方法:首先在每一层单视图中应用GAT进行学习,并计算出每层视图的节点表示。之后再不同视图之间引入attention机制来让网络自行学习不同视图的权重。之后根据学习的权重,将各个视图加权相加得到全局节点表示并进行后续的诸如节点表示,链接预测等任务。

同时,因为不同视图共享同样的节点,即使每一层视图都表示了不同的节点关系,最终得到的每一层的节点嵌入表示应具有一定的相关性。基于以上理论,我们在每层GAT的网络参数间引入正则化项来约束参数,使其向互相相近的方向学习。大致的网络流程图如下:

这部分来源于 链接:https://www.jianshu.com/p/d5d366ba1a57 来源:简书

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_75939.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

操作指南:如何高效使用Facebook Messenger销售(二)

上一篇文章我们介绍了使用Facebook Messenger作为销售渠道的定义、好处及注意事项&#xff0c;本节我们将详细介绍怎么将Facebook Messenger销售与SaleSmartly&#xff08;ss客服&#xff09;结合&#xff0c;实现一站式管理多主页配图来源&#xff1a;SaleSmartly&#xff08;…

Meta开放小模型LLaMA,性能超过GPT-3

论文地址&#xff1a;https://research.facebook.com/file/1574548786327032/LLaMA--Open-and-Efficient-Foundation-Language-Models.pdf 介绍 LLaMA&#xff0c;是Meta AI最新发布的一个从7B到65B参数的基础语言模型集合。在数以万亿计的token上训练模型&#xff0c;并表明…

王道操作系统课代表 - 考研计算机 第二章 进程与线程 究极精华总结笔记

本篇博客是考研期间学习王道课程 传送门 的笔记&#xff0c;以及一整年里对 操作系统 知识点的理解的总结。希望对新一届的计算机考研人提供帮助&#xff01;&#xff01;&#xff01; 关于对 “进程与线程” 章节知识点总结的十分全面&#xff0c;涵括了《操作系统》课程里的全…

TPS74525PQWDRVRQ1典型应用TPS62992QRYTRQ1汽车用稳压器 规格参数

TPS74525PQWDRVRQ1线性稳压器 IC 2.5V 500MA 6WSON明佳达电子【概述】TPS745/TPS745-Q1可调节500mA LDO稳压器具有极低的静态电流&#xff0c;并可提供快速的线路和负载瞬态性能。TPS745/TPS745-Q1具有130mV的超低压差&#xff08;500mA电流&#xff09;&#xff0c;这有助于提…

MyBatis学习笔记(九) —— 自定义映射resultMap

9、自定义映射resultMap 9.1、resultMap处理字段和属性的映射关系 若字段名和实体类中的属性名不一致&#xff0c;则可以通过resultMap设置自定义映射 resultType 是一个具体的类型。 resultMap 是resultMap的标签。 id 是处理主键和属性的映射关系&#xff1b; result 是处…

HOT100--(3)无重复字符的最长子串

点击查看题目详情 大思路&#xff1a; 创建哈希表&#xff0c;元素类型为<char, int>&#xff0c;分别是字符与其对应下标 用哈希表来存储未重复的子串&#xff0c;若有重复则记录下当前子串最大值maxhashsize 并且开始以相同方法记录下一子串 遍历完成以后&#xff0c…

【项目精选】进销存管理系统的设计与实现(视频+源码+论文)

点击下载源码 1.1研究背景和意义 目前&#xff0c;许多的中小企业普遍存在一个问题&#xff1a;企业的决策者看到的进销存资料及相关报表都是比较繁杂&#xff0c;让本应该一目了然的结果因信息的分散使得产生的结果无法保持一致和完整&#xff0c;造成企业在进销存管理上问题很…

T3 出行云原生容器化平台实践

作者&#xff1a;林勇&#xff0c;就职于南京领行科技股份有限公司&#xff0c;担任云原生负责人&#xff0c;也是公司容器化项目的负责人。主要负责 T3 出行云原生生态相关的所有工作&#xff0c;如服务容器化、多 Kubernetes 集群建设、应用混部、降本增效、云原生可观测性基…

有趣的 Kotlin 0x10:操作符 ..<

操作符 …< ..< 操作符是 Kotlin 在 1.7.20 版本中引入的不包含尾部元素的左闭右开区间操作符。之前我们使用的比较多的操作符可能是 .. 和 until&#xff0c;两者均表示区间&#xff0c;前者是闭区间&#xff0c;后者则表示不包含末端元素的左闭右开区间。 OptIn(Expe…

【微服务】Ribbon实现负载均衡

目录 1.什么是负载均衡 2.自定义负载均衡 3.基于Ribbon实现负载均衡 Ribbon⽀持的负载均衡策略 4.负载均衡原理 源码跟踪 LoadBalancerIntercepor LoadBalancerClient 5.负载均衡策略IRule 总结 1.什么是负载均衡 通俗的讲&#xff0c; 负载均衡就是将负载&#xff…

Java奠基】运算符的讲解与使用

目录 运算符与表达式的使用 算术运算符 隐式转换与强制转换 自增自减运算符 赋值运算符 关系运算符 逻辑运算符 三元运算符 运算符与表达式的使用 运算符是指&#xff1a;对字面量或者变量进行操作的符号。 表达式是指&#xff1a;用运算符把字面量或者变量连接起来&…

使用Geth搭建多节点私有链

使用Geth搭建多节点私有链 步骤 1.编辑初始化配置文件genesis.json {"config": {"chainId": 6668,"homesteadBlock": 0,"eip150Block": 0,"eip150Hash": "0x000000000000000000000000000000000000000000000000000000…

【工具插件类教学】UnityPackageManager私人定制资源工具包

目录 一.UnityPackageManager的介绍 二.package包命名 三.包的布局 四.生成清单文件 五.制作package内功能 六.为您的软件包撰写文档 1.信息的结构 2.文档格式 七.提交上传云端仓库 1.生成程序集文件 2.上传至云端仓库 八.下载使用package包 1.获取包的云端路径 …

Vue3 企业级项目实战:认识 Spring Boot

Vue3 企业级项目实战 - 程序员十三 - 掘金小册Vue3 Element Plus Spring Boot 企业级项目开发&#xff0c;升职加薪&#xff0c;快人一步。。「Vue3 企业级项目实战」由程序员十三撰写&#xff0c;2744人购买https://s.juejin.cn/ds/S2RkR9F/ 越来越流行的 Spring Boot Spr…

NJ+SCU42做Modbus RTU从站

NJSCU42做Modbus RTU从站实验时间&#xff1a;2023.2.28 硬件设备&#xff1a;NJ501-1300&#xff0c;CJ1W-SCU42 软件&#xff1a;Sysmac Studio&#xff0c;Commix串口调试助手 案例简介&#xff1a;发送Modbus RTU命令读取NJ里的数据 1. 系统概述 264 ​ 本次实验使用C…

回暖!“数”说城市烟火气背后

“人间烟火气&#xff0c;最抚凡人心”。在全国各地政策支持以及企业的积极生产运营下&#xff0c;经济、社会、生活各领域正加速回暖&#xff0c;“烟火气”在城市中升腾&#xff0c;信心和希望正在每个人心中燃起。 发展新阶段&#xff0c;高效统筹经济发展和公共安全&#…

[文件操作] File 类的用法和 InputStream, OutputStream 的用法

能吃是不是件幸福的事呢 文章目录前言1. 文件的相关定义2. 文件类型3. Java对文件系统的操作3.1 对文件的基础操作3.2 读文件3.3 写文件前言 从这章开始,我们就开始学文件操作相关的知识了~ 1. 文件的相关定义 1.文件的定义可以从狭义和广义两个方面解释. 狭义: 指硬盘上的文…

进程、线程、协程详解

目录 前言&#xff1a; 一、进程 进程的概念 进程内存空间 二、线程 线程的定义 内核线程 用户线程 内核线程和用户线程的比较 线程的状态 三、协程 协程的定义 协程序相对于线程优势 运用场景 四、线程、协程、进程切换比较 前言&#xff1a; 有时候无法…

原生JS实现拖拽排序

拖拽&#xff08;这两个字看了几遍已经不认识了&#xff09; 说到拖拽&#xff0c;应用场景不可谓不多。无论是打开电脑还是手机&#xff0c;第一眼望去的界面都是可拖拽的&#xff0c;靠拖拽实现APP或者应用的重新布局&#xff0c;或者拖拽文件进行操作文件。 先看效果图&am…

人力资源管理系统

技术&#xff1a;Java、JSP等摘要&#xff1a;在当今的信息化社会&#xff0c;为了更有效率地工作&#xff0c;人们充分利用现在的电子信息技术&#xff0c;在办公室架设起办公服务平台&#xff0c;将人力资源相关信息统一起来管理&#xff0c;帮助管理者有效组织降低成本和加速…