数据挖掘期末-图注意力模型

news/2024/5/11 2:10:37/文章来源:https://blog.csdn.net/Aaron503/article/details/128426214

PyGAT图注意力模型

​  PyGAT实现的分类器: https://www.aliyundrive.com/s/vfK8ndntpyc

  还在发烧,不是特别清醒,就简单写了写。用GAT进行关系预测,GAT可能是只做中间层,不过本来在GAT这一层就为了能懂就简化了很多地方了,如果再加别的,预测正确率大概率很低。尝试了直接用GAT预测边权(没用稀疏矩阵的版本),内存不够没办法跑(需要至少100G+),试了少一些节点,也基本预测不出来,所以这里只介绍GAT基本实现和GAT进行分类。

​   代码是用ANACONDA的IPYTHON虚拟环境里运行的。

​   需要安装pytorch,在conda prompt输入以下命令:

conda install pytorch torchvision torchaudio cpuonly -c pytorch

​   不用conda可以↓,选Pip,得到Pip的命令安装PyTorch。CUDA没有对应的版本就选CPU。

​   https://pytorch.org/
在这里插入图片描述

GAT

​   GAT对于一个图,按照其输入的节点特征预测输出新的节点的特征。

​   GAT是一种能够直接作用于图并且利用其结构信息的卷积神经网络,可用于网络中的半监督学习问题,学习网络中结点的特征与网络结构的信息。主要思想是对每个结点的邻居及其自身的信息作加权平均,按照其输入的节点特征输出新的节点特征。图注意力模型用注意力机制对邻近节点特征加权求和,每个节点可以根据邻节点的特征,为其分配不同的权值,将权重称为注意力系数,然后根据注意力系数进行加权求和,得到节点的新特征。

​   图注意力模型训练得到的是一个计算好节点的注意力权重的图,得到该图以后,就可以输入节点特征,得到新的节点特征。这个新的特征是什么取决于使用图注意力模型的目的。

GAT模型

  假设输入特征维度为x,输出特征维度为y。

  GAT是图注意力层+MUTIHEAD机制,MUTIHEAD机制相当于有多个图注意力层,每个层输出是n1,n2,n3…个新特征,最后再将这些特征转变为y个特征,也就是最终输出的特征

注意力系数计算

  构建图注意力层的第一步是计算注意力系数,对所有节点计算他的所有相邻节点的注意力系数。计算注意力系数公式如下:

在这里插入图片描述

•eij:节点i对邻居j的注意力系数

•W:可学习的参数矩阵

•a:可学习参数,一个向量,将多维特征转化为一个数

  得到注意力系数后,将该节点与所有邻居节点的注意力系数做归一化处理,就得到了注意力系数。这个归一化是指数归一化,并且归一化之前要使用L e a k y R e L U 进行非线性激活。归一后就得到了节点i和节点j的注意力系数。

在这里插入图片描述

聚合

  得到计算好的注意力系数,将特征进行加权求和,经激活函数激活后就得到了每个节点的新特征。

在这里插入图片描述

计算的例子

  https://zhuanlan.zhihu.com/p/412270208这里有一个图注意力模型的具体的计算的例子,可以看一下,就知道GAT是怎么由输入得到输出的了,这里的GAT代码是PyGAT的官方代码,对cora数据集分类,如果能看懂就知道GAT怎么用了,或者可以继续看下去,用了PyGAT的模型,训练的部分简化了,删除了稀疏矩阵运算的版本。

模型定义

  图注意力层的定义如下:

class GraphAttentionLayer(nn.Module):"""Simple GAT layer, similar to https://arxiv.org/abs/1710.10903"""def __init__(self, in_features, out_features, dropout, alpha, concat=True):super(GraphAttentionLayer, self).__init__()self.dropout = dropout								#dropout表示随机放弃多少邻节点,一般0.2self.in_features = in_features						#输入特征self.out_features = out_features				 	#输出特征self.alpha = alpha									#激活函数用的参数,0.2self.concat = concatself.W = nn.Parameter(torch.empty(size=(in_features, out_features)))	#参数矩阵Wnn.init.xavier_uniform_(self.W.data, gain=1.414)						#初始化self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))			#参数向量ann.init.xavier_uniform_(self.a.data, gain=1.414)						#初始化self.leakyrelu = nn.LeakyReLU(self.alpha)								#激活函数

  注意力层计算的过程在forward函数中,输入是h,原特征矩阵,输出是新特征矩阵。

    def forward(self, h, adj):Wh = torch.mm(h, self.W) # h.shape: (N, in_features), Wh.shape: (N, out_features)e = self._prepare_attentional_mechanism_input(Wh)						zero_vec = -9e15*torch.ones_like(e)attention = torch.where(adj > 0, e, zero_vec)							#邻接矩阵为0的位置表示没有边,注意力系数为0attention = F.softmax(attention, dim=1)									#指数归一化attention = F.dropout(attention, self.dropout, training=self.training)  #drop 0.2 的邻节点的注意力系数h_prime = torch.matmul(attention, Wh)									#输出特征if self.concat:return F.elu(h_prime)												#激活后的输出特征else:return h_prime

  GAT模型就是在图注意力层基础上MUTIHEAD机制的实现,但是MUTIHEAD机制不是图注意力模型必须的,只有1个图注意力层就可以算是一个图注意力模型了。
  下面看一下GAT模型的定义,MUTIHEAD机制的实现不重要,所以不看具体实现,只看参数。init中,nfeat是输入特征的维度,nhid是中间层,也就是如果有多个注意力层的情况下,注意力层的输出特征维度,nclass是最终的输出维度,因为这个GAT实现的是一个分类器,所以输出特征数就是类别数,即如果有n类,输出特征就应该有n个,每个输出值表示该样本是该类别的概率。例如每个样本都有一些自己的特征,一共有三类,把特征以及该样本和其他样本的关系输入到训练好的GAT,输出特征是0.1,0.5,0.4,就表示该样本是第一类,第二类,第三类的概率是0.1,0.5,0.4。至于为什么输出是概率,就涉及到模型训练的损失函数了,这部分内容可以看一下交叉熵相关的知识。顶着烧的我反正是没太看懂

class GAT(nn.Module):def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):def forward(self, x, adj):

模型训练

  GAT能把输入特征变换为输出特征,是需要训练的。GAT一共两个可变参数,参数矩阵W和参数a,已经在模型定义时用nn.Parameter声明成参数了,接下来只要定义训练函数,就能训练出合适的模型了。

  怎么训练出合适的模型,让模型能够得到预期的输出,取决于模型的任务。对于分类任务来说,应该用已有数据的标签来训练模型,让模型得到的样本标签能够尽量贴近真实标签。

  首先要实例化一个模型,假设实例化一个三分类GAT模型:

hidden = int(labels.max()) + 1                  #直接把中间层特征数也设置成输出特征数了
model = GAT(nfeat=features.shape[1], nhid=hidden,                        #中间层特征数,模型可以在每一层有不同的输出特征数nclass=int(labels.max()) + 1,       #类别即输出特征数dropout=0.2, nheads= 1, alpha=0.2)

  训练模型用的优化器也要实例化一个:

patience = 100 #100次LOSS没有下降,就停止训练
epochs = 1000
optimizer = optim.Adam(model.parameters(), lr=0.001,                #学习率,训练慢了可以调大一点 weight_decay=5e-4)

  单轮次的训练是这样的,就是用GAT算输出,然后用输出和正确结果计算损失,然后使用优化器进行优化,优化器会会优化参数。损失函数就是上面提到的交叉熵。其中labels是ONEHOT编码,例如有三个类别type1,type2,type3,编码就是001,010,100,至于为什么使用ONEHOT编码,这也是和交叉熵有关的。

def train(epoch):t = time.time()model.train()optimizer.zero_grad()output = model(features, adj)										# 算输出loss_train = F.nll_loss(output[idx_train], labels[idx_train])		# nll_loss损失函数acc_train = accuracy(output[idx_train], labels[idx_train])loss_train.backward()optimizer.step()													# 使用优化器优化loss_val = F.nll_loss(output[idx_train], labels[idx_train])acc_val = accuracy(output[idx_train], labels[idx_train])# 输出单轮训练结果print('Epoch: {:04d}'.format(epoch+1),'loss_train: {:.4f}'.format(loss_train.data.item()),'acc_train: {:.4f}'.format(acc_train.data.item()),'loss_val: {:.4f}'.format(loss_val.data.item()),'acc_val: {:.4f}'.format(acc_val.data.item()),'time: {:.4f}s'.format(time.time() - t))return loss_val.data.item()

  这样的训练进行多轮,直到所有轮次训练完或者损失函数算的loss值不再变换。

t_total = time.time()
loss_values = []
bad_counter = 0
best = epochs + 1
best_epoch = 0
for epoch in range(epochs):loss_values.append(train(epoch))torch.save(model.state_dict(), '{}.pkl'.format(epoch))if loss_values[-1] < best:best = loss_values[-1]best_epoch = epochbad_counter = 0else:bad_counter += 1if bad_counter == patience:breakfiles = glob.glob('*.pkl')for file in files:epoch_nb = int(file.split('.')[0])if epoch_nb < best_epoch:os.remove(file)

  train.py里用随机生成的数据进行了训练,运行一下就可以看到模型训练的过程结果,训练好的最佳参数的模型会保存到一个pkl文件里,最后的测试部分,会读入文件中的模型,测试中有一个idx_test是用于测试的数据的下标,我没有定义测试数据,所以将测试的这部分注释掉了。

  (原来的数据集中标签是一整份的,idx_train是用于训练的样本下标,idx_test是用于测试的样本下标,例如idx_train是0-4,表示labels[0:4]都是训练用的)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_239623.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux-系统随你玩之--用户及用户组管理

一、用户基本介绍 Linux 系统是一个多用户多任务的操作系统&#xff0c;任何一个要使用系统资源的用户&#xff0c;都必须首先向系统 管理员申请一个账号&#xff0c;然后才可以以这个用户登陆系统。 二、Linux中用户和组 2.1、用户和组介绍 用户&#xff1a; 每一个用户都…

如何不改一行代码,让Hippy启动速度提升50%?

导读&#xff5c;Hippy使用JS引擎进行异步渲染&#xff0c;在用户从点击到打开首屏可交互过程中会有一定的耗时&#xff0c;影响用户体验。如何优化这段耗时&#xff1f;腾讯客户端开发工程师李鹏&#xff0c;将介绍QQ浏览器通过切换JS引擎来优化耗时的探索过程和效果收益。在分…

微导纳米科创板上市:市值125亿 无锡首富王燕清再敲钟

雷递网 雷建平 12月23日江苏微导纳米科技股份有限公司&#xff08;简称&#xff1a;“微导纳米”&#xff0c;股票代码为&#xff1a;“688147”&#xff09;今日在科创板上市。微导纳米此次发行4544.55万股&#xff0c;发行价为24.21元&#xff0c;募资总额为11亿元。微导纳米…

对Python的学习【如何查看路径和安装包】

1&#xff1a;怎么查看本地电脑的Python版本号及安装路径&#xff1a; 对于Windows平台&#xff0c;打开cmd 使用命令py -0p 【其中0是零】 显示已安装的 python 版本且带路径的列表&#xff0c;参见下图&#xff1a; 其中带星号*的为默认版本。 2:怎么查看python pip…

认识 Fuchsia OS

认识 Fuchsia OS 1 说明背景 1.1 基本信息 开发者: Google编程语言: C、C、Rust、Go、Python、Dart内核: Zircon运作状态: 当前源码模式: 开放源代码初始版本: 2016年8月15日支持的语言: 英语支持平台: ARM64、X86-64内核类别: 微内核 基于能力 实时操作系统许可证: BSD 3 c…

腾讯焦虑了,一向温文尔雅的马化腾也发脾气了

大家好&#xff0c;我是校长。昨天小马哥内部讲话在互联网上疯传&#xff0c;这应该是&#xff0c;腾讯这家公司创办以来&#xff0c;马化腾最焦虑也最外露的一次讲话了&#xff0c;重点大概涉及 3 大方面&#xff0c;8 大项内容&#xff1a;1、所有业务线 ROI 化&#xff0c;再…

该怎么选择副业,三条建议形成自己的副业思维

受经济环境的影响&#xff0c;许多年轻人觉得原来稳定的工作不那么稳定&#xff0c;看着周围的朋友因为企业破产和失业&#xff0c;生活变得没有信心&#xff0c;也想找到自己的副业&#xff0c;在紧急情况下赚更多的钱。所以&#xff0c;年轻人在选择副业时也面临着很多困惑&a…

LeetCode HOT 100 —— 581. 最短无序连续子数组

题目 给你一个整数数组 nums &#xff0c;你需要找出一个 连续子数组 &#xff0c;如果对这个子数组进行升序排序&#xff0c;那么整个数组都会变为升序排序。 请你找出符合题意的 最短 子数组&#xff0c;并输出它的长度。 思路 方法一&#xff1a;双指针 排序 最终目的是让…

2023春季招聘面试集锦:MYSQL数据库高频面试题

mysql索引的数据结构&#xff0c;各自优劣 索引的数据结构和具体存储引擎的实现有关&#xff0c;在MySQL中使用较多的索引有Hash索引&#xff0c;B树索引等&#xff0c; InnoDB存储引擎的默认索引实现为&#xff1a;B树索引。对于哈希索引来说&#xff0c;底层的数据结构就是…

SpringBoot:模块探究之spring-boot-starters

Spring Boot Starters 是一组方便的依赖描述符&#xff0c;您可以将它们包含在您的应用程序中。您可以获得所需的所有 Spring 和相关技术的一站式服务&#xff0c;而无需搜索示例代码和复制粘贴大量依赖项描述符。 例如&#xff0c;如果想使用 Spring 和 JPA 进行数据库访问&am…

前端小知识:文本分句、词、字(Intl.Segmenter)

5. 文本分字、词、句 参考文章&#xff1a; https://mp.weixin.qq.com/s/MLmi-Yoi9sez8-5DPtcBVw   官方文档&#xff08;构造参数&#xff09;&#xff1a; https://developer.mozilla.org/zh-CN/docs/Web/JavaScript/Reference/Global_Objects/Intl/Segmenter/Segmenter   …

win环境mysql版本升级到5.7过程

win环境mysql版本升级到5.7过程&#xff0c;我win电脑里mysql版本是5.0&#xff0c;版本太老了&#xff0c;也不支持和nacos集成&#xff08;nacos至少需要5.6版本的mysql&#xff09;&#xff0c;思来想去还是要升级一下自己电脑的mysql版本&#xff0c;保守点升级到5.7吧&…

项目实战之旅游网(三)后台用户管理(下)

目录 一.查询用户角色 二.修改用户角色 三.修改用户状态 一.查询用户角色 一个用户可以有多个角色&#xff0c;我们也可以给某个用户分配某些角色&#xff0c;所以我们还需要新建一个实体类&#xff08;这个实体类需要放到bean下&#xff0c;因为这个实体类和数据据库不是对…

SpringCloud 网关组件 Zuul-1.0 原理深度解析

为什么要使用网关&#xff1f; 在当下流行的微服务架构中&#xff0c;面对多端应用时我们往往会做前后端分离&#xff1a;如前端分成 APP 端、网页端、小程序端等&#xff0c;使用 Vue 等流行的前端框架交给前端团队负责实现&#xff1b;后端拆分成若干微服务&#xff0c;分别…

独立开发变现周刊(第85期):一个会员服务的SaaS,月收入2万美金

分享独立开发、产品变现相关内容&#xff0c;每周五发布。目录1、Obsidian Canvas&#xff1a;一个无限的空间来构建你的想法2、message-pusher: 搭建专属于你的消息推送服务3、Careerflow LinkedIn: 40倍提升你的工作机会4、vue-pure-admin: 一款开源后台管理系统5、一个提供会…

【HarmonyOS】调测助手安装失败10内部错误

关于鸿蒙开发通过应用调测助手向watch gt 3 手表安装hap时报错。 问题背景&#xff1a; 鸿蒙开发&#xff0c;使用新建工程的helloworld 没有其他修改&#xff0c;生成hap包。然后通过应用调测助手向watch gt 3 手表安装hap时提示 安装失败:10.内部错误。 Sdk&#xff1a; a…

基于VUE学生选课管理系统

开发工具(eclipse/idea/vscode等)&#xff1a;idea 数据库(sqlite/mysql/sqlserver等)&#xff1a;mysql 功能模块(请用文字描述&#xff0c;至少200字)&#xff1a; 一、登录注册模块: 1.学生&#xff0c;教师&#xff0c;管理员三个角色&#xff08;同一时刻&#xff0c;账户…

WSL2的安装、应用

WSL2的安装、应用WSL安装、升级常用命令WSL导入导出其他 - 图形界面、虚拟化WSL安装、升级 win10系统上开启WSL参考如下&#xff0c;我先是安装了WSL1&#xff0c;之后又升级到WSL2的。关键是一些Win10上电配置&#xff0c;之后在windows应用商店下载ubuntu即可。 win10上lin…

Python基础(十八):学员管理系统应用

文章目录 学员管理系统应用 一、系统简介 二、步骤分析 三、需求实现 1、显示功能界面 2、用户输入序号&#xff0c;选择功能 3、根据用户选择&#xff0c;执行不同的功能 4、定义不同功能的函数 学员管理系统应用 一、系统简介 需求&#xff1a;进入系统显示系统功能…

跨域问题以及解决跨域问题的vue-cli解决方案

跨域问题 写项目前要问后端,接口支持跨域吗? 支持就不会出现问题,不支持就需要解决跨域问题 1.如何判断一个浏览器的请求是否跨域&#xff1f; 在A地址&#xff08;发起请求的页面地址&#xff09;向B地址&#xff08;要请求的目标页面地址&#xff09;发起请求时&#xff…