PyG MessagePassing Mechanism: A Source Code Analysis
The Message Passing Neural Networks (MPNN) framework, described in Google's 2017 paper Neural Message Passing for Quantum Chemistry, has since become the standard paradigm for computation in graph machine learning.
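PyG provides a framework for this message passing (neighborhood aggregation) operation:

$$\mathbf{x}_i^{(k)} = \gamma^{(k)}\left(\mathbf{x}_i^{(k-1)},\ \square_{j \in \mathcal{N}(i)}\ \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)}, \mathbf{e}_{j,i}\right)\right)$$

where
$\mathbf{x}_i^{(k-1)}$ denotes the features of node $i$ at layer $k-1$, and $\mathbf{e}_{j,i}$ denotes the (optional) features of the edge from node $j$ to node $i$;
$\square$ denotes a differentiable, permutation-invariant function such as sum, mean, or max;
$\gamma$ and $\phi$ denote differentiable functions such as MLPs.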
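Inside propagate, the message, aggregate, and update functions are called in that order, where
message is the $\phi$ part of the formula and passes features along edges;
aggregate is the $\square$ part of the formula and aggregates the incoming features;
update is the $\gamma$ part of the formula and updates the node features.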
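To make the three roles concrete, here is a minimal, dependency-free sketch of one message-passing step on a three-node graph. It is not PyG code: the function bodies, the scalar features, and the choice of sum aggregation plus "add own feature" as the update are assumptions made purely for illustration.

```python
# A toy sketch of one message-passing step using plain Python lists.
# All helpers here are illustrative and are not part of PyG.

def message(x_j):
    # phi: the message is simply the source node's feature.
    return x_j

def aggregate(messages, index, num_nodes):
    # square: sum all messages that arrive at the same target node.
    out = [0.0] * num_nodes
    for m, i in zip(messages, index):
        out[i] += m
    return out

def update(aggregated, x):
    # gamma: combine the aggregated messages with each node's own feature.
    return [a + x_i for a, x_i in zip(aggregated, x)]

def propagate(edge_index, x):
    src, dst = edge_index  # every edge points from src (j) to dst (i)
    messages = [message(x[j]) for j in src]
    aggregated = aggregate(messages, dst, num_nodes=len(x))
    return update(aggregated, x)

x = [1.0, 2.0, 3.0]
edge_index = ([0, 1, 2], [1, 2, 1])  # edges: 0->1, 1->2, 2->1
result = propagate(edge_index, x)
print(result)  # [1.0, 6.0, 5.0]
```

Node 1 receives messages from nodes 0 and 2 (1.0 + 3.0) and then adds its own feature 2.0, which gives 6.0; node 0 has no incoming edges and keeps only its own feature.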
The MessagePassing class
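PyG uses the MessagePassing class as the base class for implementing message passing; we only need to inherit from it.
Below, we take GCN as the example.
The GCN message-passing formula is:

$$\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right)$$

That is, neighboring node features are first transformed by the weight matrix $\mathbf{\Theta}$, then normalized by the node degrees, and finally summed.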
Source code analysis
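A graph convolution layer is normally invoked through its forward function, and the usual call order is forward -> propagate -> message -> aggregate -> update. How, then, are the custom kwargs passed to propagate matched to the parameters of the functions that run afterwards? (Figure source: https://blog.csdn.net/minemine999/article/details/119514944)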
When it is initialized, MessagePassing builds an Inspector, whose main job is to extract the parameters of the message, aggregate, message_and_aggregate, and update functions defined in the subclass.
class MessagePassing(torch.nn.Module):
    special_args: Set[str] = {
        'edge_index', 'adj_t', 'edge_index_i', 'edge_index_j', 'size',
        'size_i', 'size_j', 'ptr', 'index', 'dim_size'
    }

    def __init__(self, aggr: Optional[str] = "add",
                 flow: str = "source_to_target", node_dim: int = -2,
                 decomposed_layers: int = 1):
        super().__init__()

        self.aggr = aggr
        assert self.aggr in ['add', 'sum', 'mean', 'min', 'max', 'mul', None]

        self.flow = flow
        assert self.flow in ['source_to_target', 'target_to_source']

        self.node_dim = node_dim
        self.decomposed_layers = decomposed_layers

        self.inspector = Inspector(self)
        self.inspector.inspect(self.message)
        self.inspector.inspect(self.aggregate, pop_first=True)
        self.inspector.inspect(self.message_and_aggregate, pop_first=True)
        self.inspector.inspect(self.update, pop_first=True)
        self.inspector.inspect(self.edge_update)

        self.__user_args__ = self.inspector.keys(
            ['message', 'aggregate', 'update']).difference(self.special_args)
        self.__fused_user_args__ = self.inspector.keys(
            ['message_and_aggregate', 'update']).difference(self.special_args)
        self.__edge_user_args__ = self.inspector.keys(
            ['edge_update']).difference(self.special_args)
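Inside the inspect function, inspect.signature(func).parameters retrieves the parameters of the subclass's function. For example, when func is self.message, params = inspect.signature(self.message).parameters yields the parameters of the message function defined in the subclass.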
class Inspector(object):
    def __init__(self, base_class: Any):
        self.base_class: Any = base_class
        self.params: Dict[str, Dict[str, Any]] = {}

    def inspect(self, func: Callable,
                pop_first: bool = False) -> Dict[str, Any]:
        # Register the parameters of func and record the mapping
        # between func and its parameters.
        params = inspect.signature(func).parameters
        params = OrderedDict(params)
        if pop_first:
            ...
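The idea behind the Inspector can be tried out with nothing but the standard library. The sketch below is a hypothetical, simplified stand-in (MiniInspector and ToyLayer are names invented here, not PyG classes). It assumes that pop_first simply drops the first declared parameter, which matches how aggregate and update receive their first argument positionally, and that distribute keeps only the entries a function declares.

```python
import inspect
from collections import OrderedDict


class MiniInspector:
    """Hypothetical, simplified stand-in for PyG's Inspector."""

    def __init__(self):
        self.params = {}

    def inspect(self, func, pop_first=False):
        # For a bound method, signature() already omits `self`.
        params = OrderedDict(inspect.signature(func).parameters)
        if pop_first:
            # Assumption: drop the first parameter (e.g. `inputs`).
            params.popitem(last=False)
        self.params[func.__name__] = params

    def keys(self, func_names):
        # Union of all parameter names declared by the given functions.
        return {name for f in func_names for name in self.params[f]}

    def distribute(self, func_name, available):
        # Keep only the entries that `func_name` actually declares.
        return {k: available[k] for k in self.params[func_name] if k in available}


class ToyLayer:
    def message(self, x_j, norm): ...
    def update(self, inputs, x_i): ...


layer = ToyLayer()
insp = MiniInspector()
insp.inspect(layer.message)
insp.inspect(layer.update, pop_first=True)

message_params = list(insp.params['message'])  # ['x_j', 'norm']
update_params = list(insp.params['update'])    # ['x_i']
user_args = insp.keys(['message', 'update'])
picked = insp.distribute('message', {'x_j': 1, 'x_i': 2, 'norm': 3})
print(message_params, update_params, picked)
```

Because the parameter names are read from the signatures, a subclass only has to name a parameter x_j or norm for the matching value to be routed to it.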
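How parameters are passed:
Parameters enter through forward, and propagate passes them on to the corresponding functions. The mapping between the two is handled by the __collect__ function of the MessagePassing class, which collects the parameters and assigns their values.
In __collect__, args holds the custom parameters of the subclass's functions (message, aggregate, update, and so on), that is, self.__user_args__, while kwargs holds the arguments that the subclass's forward passes to propagate.
The _i and _j suffixes in self.__user_args__ are essential: _i marks a quantity that belongs to the target node and _j marks one that belongs to the source node, so an edge points j -> i for j in N(i). Parameters without an _i or _j suffix are passed through unchanged. (Default: self.flow == 'source_to_target'.)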
def __collect__(self, args, edge_index, size, kwargs):
    i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)

    out = {}
    for arg in args:  # iterate over the parameters of the custom functions
        if arg[-2:] not in ['_i', '_j']:  # no _i/_j suffix: pass through
            # Look the value up in the kwargs supplied by the user.
            out[arg] = kwargs.get(arg, Parameter.empty)
        else:
            dim = 0 if arg[-2:] == '_j' else 1  # note which dimension is used
            # Look the data up under the prefix arg[:-2], e.g. 'x' for 'x_j'.
            data = kwargs.get(arg[:-2], Parameter.empty)

            if isinstance(data, (tuple, list)):
                assert len(data) == 2
                if isinstance(data[1 - dim], Tensor):
                    self.__set_size__(size, 1 - dim, data[1 - dim])
                data = data[dim]

            if isinstance(data, Tensor):
                self.__set_size__(size, dim, data)
                data = self.__lift__(data, edge_index,
                                     j if arg[-2:] == '_j' else i)

            out[arg] = data

    if isinstance(edge_index, Tensor):
        out['adj_t'] = None
        out['edge_index'] = edge_index
        out['edge_index_i'] = edge_index[i]
        out['edge_index_j'] = edge_index[j]
        out['ptr'] = None
    elif isinstance(edge_index, SparseTensor):
        out['adj_t'] = edge_index
        out['edge_index'] = None
        out['edge_index_i'] = edge_index.storage.row()
        out['edge_index_j'] = edge_index.storage.col()
        out['ptr'] = edge_index.storage.rowptr()
        out['edge_weight'] = edge_index.storage.value()
        out['edge_attr'] = edge_index.storage.value()
        out['edge_type'] = edge_index.storage.value()

    out['index'] = out['edge_index_i']
    out['size'] = size
    out['size_i'] = size[1] or size[0]
    out['size_j'] = size[0] or size[1]
    out['dim_size'] = out['size_i']

    return out
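The suffix handling is easier to see on a tiny example. The following sketch mimics only the suffix logic of __collect__ with plain Python lists; collect and lift are simplified, hypothetical helpers (no tuple inputs, no size bookkeeping, no SparseTensor branch), and it assumes every suffixed argument has matching data in kwargs.

```python
def lift(x, edge_index, dim):
    # Gather one feature per edge, indexed by row `dim` of edge_index.
    return [x[n] for n in edge_index[dim]]


def collect(user_args, edge_index, kwargs, flow='source_to_target'):
    # Same index convention as in __collect__: i = target row, j = source row.
    i, j = (1, 0) if flow == 'source_to_target' else (0, 1)
    out = {}
    for arg in user_args:
        if arg[-2:] not in ('_i', '_j'):
            # No suffix: pass the value through unchanged.
            out[arg] = kwargs.get(arg)
        else:
            # Suffix: look up the prefix (e.g. 'x') and expand it per edge.
            data = kwargs.get(arg[:-2])
            out[arg] = lift(data, edge_index, j if arg[-2:] == '_j' else i)
    return out


x = ['a', 'b', 'c']                   # one "feature" per node
edge_index = [[0, 1, 2], [1, 2, 0]]   # edges: 0->1, 1->2, 2->0
norm = [0.5, 0.5, 1.0]                # one value per edge

coll = collect(['x_i', 'x_j', 'norm'], edge_index, {'x': x, 'norm': norm})
print(coll['x_j'])  # ['a', 'b', 'c']  (source of each edge)
print(coll['x_i'])  # ['b', 'c', 'a']  (target of each edge)

flipped = collect(['x_i', 'x_j'], edge_index, {'x': x}, flow='target_to_source')
print(flipped['x_i'])  # ['a', 'b', 'c']
```

Passing x once to propagate is therefore enough: declaring x_i or x_j in a function signature decides which per-edge view of x that function receives.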
propagate then takes the arguments for message, aggregate, and update out of coll_dict in turn and calls each function. Note that these arguments are selected by the self.inspector.distribute function.
def propagate(self, ..):
    # ...
    # ...
    msg_kwargs = self.inspector.distribute('message', coll_dict)
    out = self.message(**msg_kwargs)
    # ...
    # ...
    aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
    out = self.aggregate(out, **aggr_kwargs)

    update_kwargs = self.inspector.distribute('update', coll_dict)
    return self.update(out, **update_kwargs)
Customizing message, aggregate, and update
def message(self, x_i, x_j, norm):
    # x_j ::= x[edge_index[0]]  shape = [E, out_channels]
    # x_i ::= x[edge_index[1]]  shape = [E, out_channels]
    print("x_j", x_j.shape, x_j)
    print("x_i: ", x_i.shape, x_i)
    # norm.view(-1, 1).shape = [E, 1]
    # Step 4: Normalize node features.
    return norm.view(-1, 1) * x_j

def aggregate(self, inputs: Tensor, index: Tensor,
              ptr: Optional[Tensor] = None,
              dim_size: Optional[int] = None) -> Tensor:
    # The first parameter must be kept as it is.
    # index ::= edge_index[1]
    # dim_size ::= number of nodes
    print("agg_index: ", index)
    print("agg_dim_size: ", dim_size)
    # Step 5: Aggregate the messages.
    # out.shape = [number of nodes, out_channels]
    out = scatter(inputs, index, dim=self.node_dim, dim_size=dim_size)
    print("agg_out:", out.shape, out)
    return out

def update(self, inputs: Tensor, x_i, x_j) -> Tensor:
    # The first parameter must be kept as it is.
    # inputs ::= the output of aggregate
    # Step 6: Return new node embeddings.
    print("update_x_i: ", x_i.shape, x_i)
    print("update_x_j: ", x_j.shape, x_j)
    print("update_inputs: ", inputs.shape, inputs)
    return inputs
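As a rough illustration of what the scatter call in aggregate computes with sum reduction, here is a pure-Python sketch. scatter_sum is a hypothetical helper written for this note, not the torch_scatter API; it assumes messages are rows of equal width and that dim_size is the number of nodes.

```python
def scatter_sum(inputs, index, dim_size):
    # One output row per node; rows that receive no message stay at zero.
    width = len(inputs[0])
    out = [[0.0] * width for _ in range(dim_size)]
    for row, target in zip(inputs, index):
        for c, value in enumerate(row):
            out[target][c] += value
    return out


messages = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # shape [E, channels]
index = [1, 0, 1]                                # target node of each edge
out = scatter_sum(messages, index, dim_size=3)
print(out)  # [[3.0, 4.0], [6.0, 8.0], [0.0, 0.0]]
```

The dim_size argument is what guarantees a row for node 2 even though no edge points at it; without it, the output could only be sized from the largest index that appears.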
GCN Demo
from typing import Optional
from torch_scatter import scatter
import torch
import numpy as np
import random
import os
from torch import Tensor
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree


class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index  # row: source (out) indices, col: target (in) indices
        deg = degree(col, x.size(0), dtype=x.dtype)  # in-degree of each node, deg.shape = [N]
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]  # norm.shape = [E]

        # Step 4-6: Start propagating messages.
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_i, x_j, norm):
        # x_j ::= x[edge_index[0]]  shape = [E, out_channels]
        # x_i ::= x[edge_index[1]]  shape = [E, out_channels]
        print("x_j", x_j.shape, x_j)
        print("x_i: ", x_i.shape, x_i)
        # norm.view(-1, 1).shape = [E, 1]
        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j

    def aggregate(self, inputs: Tensor, index: Tensor,
                  ptr: Optional[Tensor] = None,
                  dim_size: Optional[int] = None) -> Tensor:
        # The first parameter must be kept as it is.
        # index ::= edge_index[1]
        # dim_size ::= number of nodes
        print("agg_index: ", index)
        print("agg_dim_size: ", dim_size)
        # Step 5: Aggregate the messages.
        # out.shape = [number of nodes, out_channels]
        out = scatter(inputs, index, dim=self.node_dim, dim_size=dim_size)
        print("agg_out:", out.shape, out)
        return out

    def update(self, inputs: Tensor, x_i, x_j) -> Tensor:
        # The first parameter must be kept as it is.
        # inputs ::= the output of aggregate
        # Step 6: Return new node embeddings.
        print("update_x_i: ", x_i.shape, x_i)
        print("update_x_j: ", x_j.shape, x_j)
        print("update_inputs: ", inputs.shape, inputs)
        return inputs


def set_seed(seed=1029):
    random.seed(seed)
    # Disable hash randomization so that experiments are reproducible.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


if __name__ == '__main__':
    set_seed(0)
    # x.shape = [5, 2]
    x = torch.tensor([[1, 2], [3, 4], [3, 5], [4, 5], [2, 6]], dtype=torch.float)
    # edge_index.shape = [2, 6]
    edge_index = torch.tensor([[0, 1, 2, 3, 1, 4], [1, 0, 3, 2, 4, 1]])
    print("num_node: ", x.shape[0])
    print("num_edge: ", edge_index.shape[1])
    in_channels = x.shape[1]
    out_channels = 3
    gcn = GCNConv(in_channels, out_channels)
    out = gcn(x, edge_index)
    print(out)
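To check the normalization step by hand, the degrees and per-edge coefficients for the demo graph can be recomputed without torch. This is only a sketch of Steps 1 and 3 using plain lists; the variable names are chosen here for illustration, and it assumes the same five-node edge list as the demo.

```python
import math

num_nodes = 5
src = [0, 1, 2, 3, 1, 4]  # edge_index[0] from the demo
dst = [1, 0, 3, 2, 4, 1]  # edge_index[1] from the demo

# Step 1: append one self-loop per node.
src = src + list(range(num_nodes))
dst = dst + list(range(num_nodes))

# Step 3: in-degree of every node, counted over the target indices.
deg = [0] * num_nodes
for i in dst:
    deg[i] += 1

# One coefficient per edge: 1 / (sqrt(deg(j)) * sqrt(deg(i))).
norm = [1.0 / math.sqrt(deg[j] * deg[i]) for j, i in zip(src, dst)]

print(deg)        # [2, 3, 2, 2, 2]
print(len(norm))  # 11 edges: 6 original + 5 self-loops
```

Node 1 ends up with degree 3 (edges from nodes 0 and 4 plus its self-loop), so the edge 0 -> 1 is weighted by 1/sqrt(2 * 3) and node 1's self-loop by 1/3.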