PyG MessagePassing机制源码分析

news/2024/4/29 21:15:34/文章来源:https://blog.csdn.net/weixin_42486623/article/details/126816684

PyG MessagePassing机制源码分析


Google在2017发表的论文Neural Message Passing for Quantum Chemistry中提到的Message Passing Neural Networks机制成为了后来图机器学习计算的标准范式实现。

而PyG提供了信息传递(邻居聚合) 操作的框架模型。

其中,
□\square表示 可微、排列不变 的函数,比如说summeanmax
γ\gammaγϕ\phiϕ 表示 可微 的函数,比如说 MLP

在propagate中,依次会调用messageaggregateupdate函数。
其中,
message为公式中 ϕ\phiϕ 部分,表示特征传递
aggregate为公式中 □\square 部分,表示特征聚合
update为公式中 γ\gammaγ 部分,表示特征更新

MessagePassing类

PyG使用MessagePassing类作为实现 信息传递 机制的基类。我们只需要继承其即可。
下面,我们以GCN为例子
GCN信息传递公式如下:

源码分析

一般的图卷积层是通过的forward函数进行调用的,通常的调用顺序如下,那么是如何将自定义的参数kwargs与后续的函数的入参进行对应的呢?(图来源:https://blog.csdn.net/minemine999/article/details/119514944)

MessagePassing初始化构建了Inspector类, 其主要的作用是对子类中自定义的message,aggregate,message_and_aggregate,以及update函数的参数的提取。

class MessagePassing(torch.nn.Module):special_args: Set[str] = {'edge_index', 'adj_t', 'edge_index_i', 'edge_index_j', 'size','size_i', 'size_j', 'ptr', 'index', 'dim_size'}def __init__(self, aggr: Optional[str] = "add",flow: str = "source_to_target", node_dim: int = -2,decomposed_layers: int = 1):super().__init__()self.aggr = aggrassert self.aggr in ['add', 'sum', 'mean', 'min', 'max', 'mul', None]self.flow = flowassert self.flow in ['source_to_target', 'target_to_source']self.node_dim = node_dimself.decomposed_layers = decomposed_layersself.inspector = Inspector(self)self.inspector.inspect(self.message)self.inspector.inspect(self.aggregate, pop_first=True)self.inspector.inspect(self.message_and_aggregate, pop_first=True)self.inspector.inspect(self.update, pop_first=True)self.inspector.inspect(self.edge_update)self.__user_args__ = self.inspector.keys(['message', 'aggregate', 'update']).difference(self.special_args)self.__fused_user_args__ = self.inspector.keys(['message_and_aggregate', 'update']).difference(self.special_args)self.__edge_user_args__ = self.inspector.keys(['edge_update']).difference(self.special_args)

inspect函数中,inspect.signature(func).parameters, 获取了子类的函数入参,比如当func="message"时,params = inspect.signature(‘message’).parameters就会获得子类自定义message函数的参数,

class Inspector(object):def __init__(self, base_class: Any):self.base_class: Any = base_classself.params: Dict[str, Dict[str, Any]] = {}def inspect(self, func: Callable,pop_first: bool = False) -> Dict[str, Any]:## 注册func函数的入参,并建立func与入参之间的对应关系params = inspect.signature(func).parametersparams = OrderedDict(params)if pop_first:

参数的传递过程:
从上图可知,参数是从forward传递进来的,而propagate将参数传递后面到对应的函数中,这部分的参数对应关系主要由MessagePassing类的__collect__函数进行参数收集和数据赋值。

__collect__函数中的args主要对应子类中相关函数(message,aggregate,update等)的自定义参数self.__user_args__kwargs为子类的forward函数中调用propagate传递进来的参数。

self.__user_args___i_j后缀是非常重要的参数,其中i表示与target节点相关的参数,j表示source节点相关的参数,其图上的指向为j->i for j 属于N(i),后缀不包含_i_j的参数直接被透传。(默认:self.flow==source_to_target)

def __collect__(self, args, edge_index, size, kwargs):i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)out = {}for arg in args:# 遍历自定义函数中的参数if arg[-2:] not in ['_i', '_j']: # 不包含_i和_j的自定义参数直接透传out[arg] = kwargs.get(arg, Parameter.empty) # 从用户传递进来的kwargs参数中获取值else:dim = 0 if arg[-2:] == '_j' else 1 # 注意这里的取值维度data = kwargs.get(arg[:-2], Parameter.empty) # 取用户传递进来的kwargs前缀arg[:-2]的数据if isinstance(data, (tuple, list)):assert len(data) == 2if isinstance(data[1 - dim], Tensor):self.__set_size__(size, 1 - dim, data[1 - dim])data = data[dim]if isinstance(data, Tensor):self.__set_size__(size, dim, data)data = self.__lift__(data, edge_index,j if arg[-2:] == '_j' else i)out[arg] = dataif isinstance(edge_index, Tensor):out['adj_t'] = Noneout['edge_index'] = edge_indexout['edge_index_i'] = edge_index[i]out['edge_index_j'] = edge_index[j]out['ptr'] = Noneelif isinstance(edge_index, SparseTensor):out['adj_t'] = edge_indexout['edge_index'] = Noneout['edge_index_i'] = edge_index.storage.row()out['edge_index_j'] = edge_index.storage.col()out['ptr'] = edge_index.storage.rowptr()out['edge_weight'] = edge_index.storage.value()out['edge_attr'] = edge_index.storage.value()out['edge_type'] = edge_index.storage.value()out['index'] = out['edge_index_i']out['size'] = sizeout['size_i'] = size[1] or size[0]out['size_j'] = size[0] or size[1]out['dim_size'] = out['size_i']return out

propagate中依次从coll_dict中获取与messageaggregateupdate函数的参数进行调用。注意这里获取的参数是通过上述的self.inspector.distribute函数进行获取的。

def propagate(self,..):##...##...msg_kwargs = self.inspector.distribute('message', coll_dict)out = self.message(**msg_kwargs)##...##...aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)out = self.aggregate(out, **aggr_kwargs)update_kwargs = self.inspector.distribute('update', coll_dict)return self.update(out, **update_kwargs)

自定义 message , aggregate , update

   def message(self, x_i, x_j, norm):# x_j ::= x[edge_index[0]] shape = [E, out_channels]# x_i ::= x[edge_index[1]] shape = [E, out_channels]print("x_j", x_j.shape, x_j)print("x_i: ", x_i.shape, x_i)# norm.view(-1, 1).shape = [E, 1]# Step 4: Normalize node features.return norm.view(-1, 1) * x_jdef aggregate(self, inputs: Tensor, index: Tensor,ptr: Optional[Tensor] = None,dim_size: Optional[int] = None) -> Tensor:# 第一个参数不能变化# index ::= edge_index[1]# dim_size ::= [number of node]print("agg_index: ",index)print("agg_dim_size: ",dim_size)# Step 5: Aggregate the messages.# out.shape = [number of node, out_channels]out = scatter(inputs, index, dim=self.node_dim, dim_size=dim_size)print("agg_out:",out.shape,out)return outdef update(self, inputs: Tensor, x_i, x_j) -> Tensor:# 第一个参数不能变化# inputs ::= aggregate.out# Step 6: Return new node embeddings.print("update_x_i: ",x_i.shape,x_i)print("update_x_j: ",x_j.shape,x_j)print("update_inputs: ",inputs.shape, inputs)return inputs

GCN Demo

from typing import Optional
from torch_scatter import scatter
import torch
import numpy as np
import random
import os
from torch import Tensor
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degreeclass GCNConv(MessagePassing):def __init__(self, in_channels, out_channels):super().__init__(aggr='add')  # "Add" aggregation (Step 5).self.lin = torch.nn.Linear(in_channels, out_channels)def forward(self, x, edge_index):# x has shape [N, in_channels]# edge_index has shape [2, E]# Step 1: Add self-loops to the adjacency matrix.edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))# Step 2: Linearly transform node feature matrix.x = self.lin(x) # x = lin(x)# Step 3: Compute normalization.row, col = edge_index # row, col is the [out index] and [in index]deg = degree(col, x.size(0), dtype=x.dtype) # [in_degree] of each node, deg.shape = [N]deg_inv_sqrt = deg.pow(-0.5)deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] # deg_inv_sqrt.shape = [E]# Step 4-6: Start propagating messages.return self.propagate(edge_index, x=x, norm=norm)def message(self, x_i, x_j, norm):# x_j ::= x[edge_index[0]] shape = [E, out_channels]# x_i ::= x[edge_index[1]] shape = [E, out_channels]print("x_j", x_j.shape, x_j)print("x_i: ", x_i.shape, x_i)# norm.view(-1, 1).shape = [E, 1]# Step 4: Normalize node features.return norm.view(-1, 1) * x_jdef aggregate(self, inputs: Tensor, index: Tensor,ptr: Optional[Tensor] = None,dim_size: Optional[int] = None) -> Tensor:# 第一个参数不能变化# index ::= edge_index[1]# dim_size ::= [number of node]print("agg_index: ",index)print("agg_dim_size: ",dim_size)# Step 5: Aggregate the messages.# out.shape = [number of node, out_channels]out = scatter(inputs, index, dim=self.node_dim, dim_size=dim_size)print("agg_out:",out.shape,out)return outdef update(self, inputs: Tensor, x_i, x_j) -> Tensor:# 第一个参数不能变化# inputs ::= aggregate.out# Step 6: Return new node embeddings.print("update_x_i: ",x_i.shape,x_i)print("update_x_j: ",x_j.shape,x_j)print("update_inputs: ",inputs.shape, inputs)return inputsdef set_seed(seed=1029):random.seed(seed)os.environ['PYTHONHASHSEED'] = str(seed) # 为了禁止hash随机化,使得实验可复现np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.torch.backends.cudnn.benchmark = Falsetorch.backends.cudnn.deterministic = Trueif __name__ == '__main__':set_seed(0)# x.shape = [5, 2]x = torch.tensor([[1,2], [3,4], [3,5], [4,5], [2,6]], dtype=torch.float)# edge_index.shape = [2, 6]edge_index = torch.tensor([[0,1,2,3,1,4], [1,0,3,2,4,1]])print("num_node: ",x.shape[0])print("num_edge: ",edge_index.shape[1])in_channels = x.shape[1]out_channels = 3gcn = GCNConv(in_channels, out_channels)out = gcn(x, edge_index)print(out)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_6569.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

低成本实现webhook接收端[python]

1. Webhook是个啥 Webook本质上也是API,只不过是反向调用,即前端不主动发送请求,完全由后端推送。 简单来说,Webhook是一个接受HTTP POST或是GET,PUT,DELETE的URL,一个实现了Webhook的API提供商就是在当事件发生的时…

基于 Quartz 的调度中心

需求 服务使用集群部署(多Pod)基础服务提供调度任务注册,删除,查看的功能尽可能减少客户端的使用成本开发工作量尽可能少,成本尽可能小 基于以上的需求,设计如下,调度中心非独立部署,集成在base服务中。客…

最优化 | 一维搜索与方程求根 | C++实现

文章目录参考资料前言1. 二分法求根1.1 [a,b]区间二分法求根1.1.1 原理1.1.2 C实现1.2 区间右侧无穷的二分法求根1.3 求含根区间2. 牛顿法求根2.1 原理2.2 c实现3. 梯度下降法求根3.1 c实现4. 一维搜索的区间4.1 一般一维搜索方法4.2 黄金分割法(0.618)4…

K8s部署SpringBoot项目简单例子

目录 前言 前提条件 正文 1. 获取镜像 2. 空运行测试生成部署yaml文件 3. 修改yaml文件,增加镜像拉取策略 4. 以yaml文件的方式部署springboot项目 5. 查看部署pod的状态 6. 暴露服务端口 7.通过浏览器访问服务 前言 本文通过将一个构建好的springboot的…

Linux服务器上通过miniconda安装R(2022)

安装miniconda 下载最新版miniconda wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh安装 bash Miniconda3-latest-Linux-x86_64.sh这一步骤里我们输入完命令后会有个License要读,一行一行读的话按Enter,不想读就直接输…

【uiautomation】获取微信好友名单,可指定标签 全部

前言 接到了一个需求:现微信有8000好友,需要给所有好友发送一则一样的消息。网上搜索一番后,发现uiautomation 可以解决该需求,遂有此文。这是第一篇,获取全部好友 代码在文章末尾,自取~ 微信群发消息链接 …

一个画廊的GIF动画动作英雄从80年代和90年代

你还记得那些80年代和90年代初的动作英雄吗?比如查克诺里斯、史蒂文西格尔、西尔维斯特史泰龙、让克劳德范达姆,当然还有阿诺德施瓦辛格?意大利天才设计师DavideMazuchin&;郭美雄创建了一个图文并茂的GIF画廊,名为“过去的动作英雄”,以纪念那些年轻时的经典英雄。以…

不同vlan之间实现通信

目录: 1、单臂路由实现不同vlan间通信的原理 2、单臂路由的缺陷 3、单臂路由的配置 4、三层交换 不同vlan之间实现通信 单臂路由链路类型:交换机连接主机的端口位为access链路交换机连接路由器的的端口为trunk链路子接口:路由器的物理接口可以被划分成多个逻辑接口每个子接口…

【云原生】Kubernetes CRD 详解(Custom Resource Definition)

文章目录一、概述二、定制资源1)定制资源 和 定制控制器2)定制控制器3)Operator 介绍1、Operator Framework2、Operator 安装3、安装 Operator SDK4、Operator 简单使用4)Kubernetes API 聚合层5)声明式 APIs6&#xf…

HTML 快速入门

HTML代码是“标签化”的代码,把一个HTML文件视为一个文档,文档中有很多的标签,每一个标签也可以称为一个元素,同时每一个元素也对应一个对象,对象中有属性和方法。HTML的标签除了部分标签外,其他的都是成对…

易网防伪防窜货溯源管理系统源码

防伪防窜货和溯源系统更好用更易用,系统由PHPmysql开发,安全稳定。系统以防伪码(溯源码)为中心,可非常方便的为防伪码赋值产品信息,溯源信息。是建立防伪防窜货和溯源追踪系统的不二选择。 系统功能介绍: 一、防伪码管…

【RuoYi-Vue-Plus】学习笔记 40 - Validator(一)校验器对 Model 属性校验调用流程分析

文章目录前言参考目录框架集成1、Maven2、校验框架配置类 ValidatorConfig3、测试方法4、接口测试4.1、校验失败(参数为 null)4.2、校验成功(参数不为 null)执行流程分析InvocableHandlerMethod#invokeForRequestInvocableHandler…

来自邦卡的神奇扁平超级英雄插图

平面设计趋势正在相当大程度上动摇平面设计行业的各个方面。我们正在进入一个简单和最低限度的沟通模式的新时代,在这个时代中,平面设计似乎以最好的方式提供。 受平面设计形式的启发,法国平面设计师邦卡采用了相同的方法,创作了一系列简约、平面的超级英雄插图。这些插图涵…

自制操作系统日志——第二十二天

自制操作系统日志——第二十二天 今天,我们将继续再完善一下保护操作系统的内容,以及进一步的利用c语言显示字符串! 文章目录自制操作系统日志——第二十二天一、保护操作系统3手动强制关闭应用程序二、用c语言显示字符串API 显示窗口总结一…

vivado使用方法(初级)

文章目录1 创建新工程1.1 工程创建1.2 新建Verilog文件1.3 仿真参考1 创建新工程 1.1 工程创建 1、首先打开Vavido软件,点击Creat Project或者在File——>Project——>New里面进行新工程的创建 2、然后在弹出的界面上点击Next进入下一个界面进行项目的命名…

全站最简单 “数据滚动可视化大屏” 【JS基础拿来即用】

源码获取方式: 数据滚动大屏源码,原生js实现超级简单-Javascript文档类资源-CSDN下载原生js实现的数据滚动大屏案例,实现应该是全网最简单的,拿来直接使用即可,没有会员的小伙伴去我文章主更多下载资源、学习资料请访问…

基于Python实现的遗传算法求TSP问题

遗传算法求TSP问题 目录 人工智能第四次实验报告 1 遗传算法求TSP问题 1 一 、问题背景 1 1.1 遗传算法简介 1 1.2 遗传算法基本要素 2 1.3 遗传算法一般步骤 2 二 、程序说明 3 2.3 选择初始群体 4 2.4 适应度函数 4 2.5 遗传操作 4 2.6 迭代过程 4 三 、程序测试 5 3.1 求解…

Vue3+elementplus搭建通用管理系统实例七:通用表格实现上

一、本章内容 使用配置的方式实现表格的界面的自动生成、自动解析实体配置信息,并生成表格列、筛选项等功能,完整课程地址 二、效果预览 三、开发视频

动手实现深度学习(12): 卷积层的实现

9.1 卷积层的运算 传送门: https://www.cnblogs.com/greentomlee/p/12314064.html github: Leezhen2014: https://github.com/Leezhen2014/python_deep_learning 卷积的forward 卷积的计算过程网上的资料已经做够好了,没必要自己再写一遍。只把资料搬运到这里: http://deepl…

【进击的JavaScript|高薪面试必看】JS基础-作用域和闭包

六年代码两茫茫,不思量,自难忘 6年资深前端主管一枚,只分享技术干货,项目实战经验,面试指导 关注博主不迷路~ 本系列文章是博主精心整理的面试热点问题,吸收了大量的技术博客与面试文章,总结多年…