机器学习周记(第三十一周:文献阅读-GGNN)2024.3.18~2024.3.24

news/2024/4/27 19:05:38/文章来源:https://blog.csdn.net/DominaterWE/article/details/136813562

目录

摘要

ABSTRACT

1 论文信息

1.1 论文标题

1.2 论文模型

1.2.1 数据处理

1.2.2 门控图神经网络

1.2.3 掩码操作

2 相关知识

2.1 图神经网络(GNN)

2.2 图卷积神经网络(GCN)

3 相关代码


摘要

  本周阅读了一篇利用图神经网络(GNN)与门控循环单元(GRU)进行配水网络(WDN)水质预测的论文。论文模型(GGNN)实现了扩展图邻接矩阵在有向图中加入双向信息流,从而增强了模型的双向学习能力。同时模型还利用掩码操作模拟了站点故障导致数据缺失的情况,根据正常站点数据也能对故障站点进行预测,并且还能解决模型过拟合或者欠拟合的问题。

ABSTRACT

  This week, We read a paper on water quality prediction in water distribution networks (WDNs) using Graph Neural Networks (GNN) and Gated Recurrent Units (GRU). The paper introduces a model called GGNN, which extends the graph adjacency matrix to incorporate bidirectional information flow in directed graphs, thus enhancing the model's bidirectional learning capability. Additionally, the model utilizes masking operations to simulate data missing due to station failures, enabling the prediction of faulty stations based on normal station data. Moreover, it addresses the issues of model overfitting or underfitting.

1 论文信息

1.1 论文标题

Real-time water quality prediction in water distribution networks using graph neural networks with sparse monitoring data

1.2 论文模型

  论文模型(GGNN)旨在利用门控图神经网络(GGNN)处理网络拓扑结构、流向以及水质监测站的历史氯浓度测量数据来预测配水网络(WDN)中的实时水质。该模型由两个主要部分组成:(1)对供水网络信息进行数据处理,输入到图神经网络中;(2)利用收集到的数据构建模型。

Fig.1 基于GGNN的实时水质预测方法示意图

1.2.1 数据处理

  GGNN模型需要两类数据:传感器监测站的WDN拓扑结构和历史水质监测数据。假设一个WDNn个节点和m条管道组成,配备N_{s}个传感器站监测水质。网络拓扑由图G=(V,E)表示,其中V表示由水库、储罐和连接点组成的节点集,E表示由管道、阀门和泵组成的边集。网络的流向信息和空间拓扑细节通常可以从EPANET等水力模型中获得。利用这些数据构建有向图的邻接矩阵A \in \mathbb{R}^{n\times n},其中每个元素A_{ij}表示水是否从节点i流向节点j (A_{ij}=1)或不流向 (A_{ij}=0)。论文仅在边的权重相等时考虑水流方向。更进一步还可以同时考虑流量的动态变化和加权边。

  通过在WDN中实现的监控和数据采集(SCADA)系统,可以获得各监测站的历史水质数据。该数据采集过程包括在指定的时间窗口内采集水质测量数据,记为T_{c},也表示采集历史数据的周期时间。然后将采集到的数据作为数据集中被监测节点的节点属性,对于未被监测节点,将空值替换为0,得到节点属性X\in \mathbb{R}^{n \times N_{c}}N_{c}表示数据采集周期T_{c}内获得的水质测量次数,对应于指定时间窗口内的时间步数。它是预测下一时刻水质所需数据大小的指标。

1.2.2 门控图神经网络

  为了解决WDN的非欧氏图域带来的挑战,将GGNN架构用于水质预测。GGNN是一种图神经网络,用于处理复杂的图结构数据,如WDN拓扑。它扩展了通常定义在欧氏域上的传统神经网络,使其能够直接处理非欧氏图数据。GGNN模型根据相邻节点和边之间传递的消息为每个节点v\in V计算状态向量h_{v}。状态向量h_{v}表示节点学习到的特征表示,编码了关于图的局部和全局信息。它可以被认为是节点的隐藏状态,从其邻域和整个图中捕获相关信息。最终,状态向量可用于水质预测。GGNN的整体工作流程如Fig.2所示。

Fig.2 GGNN总体架构示意图

  首先,通过扩展邻接矩阵A \in \mathbb{R}^{n \times n},在有向图中加入双向信息流来作为输入。主要通过将邻接矩阵A与其转置连接起来,形成一个扩展的邻接矩阵\widehat{A}=\left [ A,A^{T} \right ]来实现的,这样可以同时考虑输入边和输出边。\widehat{A} \in \mathbb{R}^{n \times 2n}捕获了节点之间的复杂关系和消息传播方向,从而增强了GGNN的双向学习能力。

  然后,通过标准线性组合修正线性单元(rectified linear unit, ReLU)激活函数将节点v的节点属性x_{v}从原始空间\mathbb{R}^{N_{c}}映射到新空间\mathbb{R}^{M}的原始隐藏状态h_{v}^{(0)}。这种映射过程有效地扩大了节点属性的大小,使GGNN能够捕获节点属性之间潜在的重要非线性关系。隐藏状态的大小用M表示,是一个决定模型容量的超参数。然而,至关重要的是要与M取得平衡,以防止过拟合并控制训练期间的计算复杂性。

  GGNN以扩展的邻接矩阵\widehat{A}=\left [ A,A^{T} \right ]和映射的节点属性h^{(0)}为输入,在固定的k步上递归计算节点状态以产生最终的状态矩阵h^{(K)}\in \mathbb{R}^{n \times M}。在聚合阶段,利用扩展邻接矩阵\widehat{A}计算聚合向量a_{v}a_{v}表示节点v和相邻节点状态的聚合,聚合向量的计算公式如下:

a_{v}^{(k)}=\widehat{A}^{T}_{v:}\left [ h_{1}^{(k-1)^{T}},...,h_{n}^{(k-1)^{T}} \right ]^{T}+b                                                                              (1)

其中,上标k表示时间步长,\widehat{A}_{v:}\in \mathbb{R}^{n \times 2}是块\widehat{A}中对应节点v的两列,b是偏移向量。在聚合阶段之后,传播阶段采用门控循环单元(gated recurrent units, GRU)机制更新节点状态。GRU传播方程描述如下:

r_{v}^{(k)}=\sigma (W_{r} \cdot a_{v}^{(k)}+U_{r}\cdot h_{v}^{(k-1)})                                                                                       (2)

z_{v}^{(k)}=\sigma (W_{z} \cdot a_{v}^{(k)}+U_{z}\cdot h_{v}^{(k-1)})                                                                                       (3)

\widetilde{h}_{v}^{(k)}=\tanh (W \cdot a_{v}^{(k)}+U\cdot (r_{v}^{(k)}\bigodot h_{v}^{(k-1)}))                                                                   (4)

h_{v}^{(k)}=(1-z_{v}^{(k)})\bigodot h_{v}^{(k-1)}+z_{v}^{(k)}\bigodot \widetilde{h}_{v}^{(k)}                                                                         (5)

其中rz是重置门和更新门;W_{r},W_{z},WU_{r},U_{z},U是每层的权重和偏差;\sigma (\cdot)sigmoid激活函数;\bigodot是元素点积运算。

  GGNN中的聚合和传播步骤允许模型迭代更新和细化节点状态,合并来自节点先前的特征及其邻近节点的特征信息。这个迭代过程捕获了图结构内的动态和交互规则,使GGNN能够学习和表示节点之间的复杂关系和依赖关系。传播步长K(也即GNN层数)决定了GGNN中信息传播的深度。当K=1时,每个节点只能从其近邻节点学习。随着K的增加,GGNN可以从距离K步的节点捕获信息,包括它们的间接连接。K的选择影响模型的学习能力和效率。较高的K值会导致训练较慢以及增加内存需求,而较低的K值会限制每个节点可以学习的依赖关系的数量。因此,K的选择应该在模型性能和计算效率之间取得平衡。

  在使用GRU模块更新节点状态后,使用线性层将更新后的状态h^{(K)}转换为表示每个节点预测状态的\widehat{Y}\in \mathbb{R}^{n}。在本研究中,节点属性为历史水质浓度数据,其预测状态表示模型对每个节点下一时间步水质浓度的预测。这种转换允许模型根据其更新的表示和从邻近节点传播的信息在每个节点生成对水质的预测。

1.2.3 掩码操作

  虽然之前的研究主要采用掩码操作(Maskng Operation来模拟传感器故障,特别是在不利条件下测试模型的鲁棒性,但本文方法在训练阶段利用掩码操作来增强模型对未监测节点的预测能力。在训练过程中,结合掩码操作对解决两个重大挑战至关重要。首先,现有研究通常假设传感器节点的输入,并根据模拟的网络中所有节点的值来计算损失,这在现实世界中是不切实际的,因为获取非传感器节点的测量数据很困难。论文使用模拟模型的合成数据,这样数据虽然完整,但作者并没有使用所有网络节点的所有数据进行训练。相反,只使用了一小部分节点数据。其次,如果模型仅基于传感器节点的输入进行训练,并基于这些节点计算损失,可能会导致过拟合,阻碍模型预测未监测节点的水质的能力。为了克服这些挑战,在训练过程中引入了掩码操作。随机选择指定比例(例如20%)的传感器节点,并通过在每个训练批次中将其输入替换为零进行掩盖。这个屏蔽操作有两个目的。首先,在训练过程中模拟非传感器节点数据的不可用性,使模型能够在观测到的传感器数据之外进行泛化,并学习预测无监测节点的值;其次,它作为正则化技术,防止模型仅依赖有限的传感器输入。通过鼓励模型捕捉传感器节点和非监测节点之间的关系,提高模型的泛化能力,降低过拟合的可能性。需要研究掩码节点的比例,因为它可以平衡模型性能和过拟合。更高的比率会减少可用的信息,增加欠拟合的风险。较低的速率可以提供更多的信息,但可能会导致过拟合。因此,掩码率也是一个十分重要的超参数。

2 相关知识

2.1 图神经网络(GNN)

2.2 图卷积神经网络(GCN)

  需要注意的是,常规任务情境下不会需要节点的信息传播太远。经过6~7个hops,基本上就可以使节点的信息传播到整个网络,这也使得聚合不那么有意义。实验结果也表明,2~3层的网络应该是比较好的,当GCN达到7层时,效果已经变得较差,但是通过在隐藏层间加上残差连接(Residual Connections)可以使效果变好。

3 相关代码

GCN模型定义与图结构数据定义:

import torch
import torch.nn as nn
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import KarateClub
from torch_geometric.utils import to_networkxclass GCN(nn.Module):def __init__(self):super().__init__()torch.manual_seed(1234)self.conv1 = GCNConv(dataset.num_features, 4)self.conv2 = GCNConv(4, 4)self.conv3 = GCNConv(4, 2)self.classifier = nn.Linear(2, dataset.num_classes)def forward(self, x, edge_index):h = self.conv1(x, edge_index)  # 输入特征与邻接矩阵h = h.tanh()h = self.conv2(h, edge_index)h = h.tanh()h = self.conv3(h, edge_index)h = h.tanh()out = self.classifier(h)return out, hdef visualize_graph(G, color):plt.figure(figsize=(7, 7))plt.xticks([])plt.yticks([])nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False, node_color=color, cmap="Set2")plt.show()def visualize_embedding(h, color, epoch=None, loss=None):plt.figure(figsize=(7, 7))plt.xticks([])plt.yticks([])h = h.detach().cpu().numpy()plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2")if epoch is not None and loss is not None:plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16)plt.show()dataset = KarateClub()
print(f'Dataset: {dataset}')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')data = dataset[0]
# x:[34, 34](M*F,M:样本数,F:特征维度)
# edge_index:[2, 156](两个数组,第一个为source,第二个为target,156条边)
# y:[34](标签)
# train_mask:[34](指定节点是否有标签,通过此数组可以选择哪些节点计算损失,元素类型为bool)
print(data)
print(dataset.edge_index)G = to_networkx(data, to_undirected=True)
visualize_graph(G, color=data.y)

数据集KarateClub的图结构:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_1026576.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

IDEA2023版本整合SpringBoot热部署

✅作者简介:大家好,我是Leo,热爱Java后端开发者,一个想要与大家共同进步的男人😉😉 🍎个人主页:Leo的博客 💞当前专栏: 开发环境篇 ✨特色专栏: M…

波奇学Linux:自定义协议和序列和反序列化

TCP是面向字节流的如何保证,读取上来的数据是一个"完整"的报文 tcp传输控制协议:什么时候发,发多少,出错怎么办 read和write都是从用户到内核空间的拷贝,数据不一定传输到另一个台主机的缓冲区,…

精品凉拌菜系列热卤系列课程

这一系列课程涵盖精美凉拌菜和美味热卤菜的制作技巧。学员将学习如何选材、调味和烹饪,打造口感丰富、色香俱佳的菜肴。通过实践训练,掌握独特的烹饪技能,为家庭聚餐或职业厨艺提升增添亮点。 课程大小:6.6G 课程下载&#xff1…

IDEA 远程调试

1.什么是远程调试 Java提供了一个远程调试功能,支持设置断点及线程级的调试同时,不同的JVM通过接口的协议联系,本地的Java文件在远程JVM建立联系和通信。 2.服务端开启远程调试 开启远程调试功能,需要修改tomcat 的catalina.sh…

链表队列LinkQueue

入队:往尾巴上放 1.先定义一个新节点,指针置空 2. 连接 3. 移动尾指针 出队:从头部出队 1. 定义一个temp指针 2. head指针指向下一个 3. 通过free 释放temp指针所指 4. 若指完后,head所指为NULL,则把尾指…

Windows前后端部署(达梦,东方通)

打开虚拟机,打开远程路径服务 将素材复制到虚拟机并解压(jkd,达梦,东方通,nginx) 双击安装jdk(一直下一步) 安装达梦 双击 直接下载完成 东方通下载(双击程序) 改端口号8080 把许可文件放到东方…

设计模式——观察者模式Observer

Q:观察者模式属于哪一类设计模式 A:观察者模式属于行为学模式 Q:什么是观察者模式 A:当一个对象的状态发生改变时,所有依赖它的对象都得到通知,并自动更新 观察者模式解析:报纸类维护了一个…

文心一言 VS 讯飞星火 VS chatgpt (224)-- 算法导论16.3 6题

六、假定我们有字母表 C{0,1,…,n-1} 上的一个最优前缀码,我们希望用最少的二进制位传输此编码。说明如何仅用 2n-1n⌈lgn⌉ 位表示 C 上的任意最优前缀码。(提示:通过对树的遍历,用 2n-1 位说明编码树的结…

第十四届蓝桥杯省赛C++ A组所有题目以及题解(C++)【编程题均通过100%测试数据】

第一题《幸运数》【模拟】 【问题描述】 小蓝认为如果一个数含有偶数个数位,并且前面一半的数位之和等于后面一半的数位之和,则这个数是他的幸运数字。例如 2314是一个幸运数字,因为它有4个数位,并且2314。现在请你帮他计算从1至100000000之间共有多少…

Flink RPC初探

1.RPC概述 RPC( Remote Procedure Call ) 的主要功能目标是让构建分布式计算(应用)更容易,在提供强大的远程调用能力时不损失本地调用的语义简洁性。 为实现该目标,RPC 框架需提供一种透明调用机制让使用者不必显式的区分本地调用和远程调用。 总而言之&…

jenkins权限分配

1.安装权限插件 Role-Based Strategy 2.创建用户 3.修改全局安全配置中的授权策略为Role-Based Strategy 4.进入Manage and Assign Roles创建Global roles和Item roles 4.进入Assign Roles给用户分配role

UI风格汇:材料设计(Material Design),是对扁平风格的延展。

Hello,我是大千UI工场,设计风格是我们新开辟的栏目,主要讲解各类UI风格特征、辨识方法、应用场景、运用方法等,本次带来的材料风格风格的解读,有设计需求,我们也可以接单。 一、什么是材料设计(…

apisix创建https

总结了下apisix 使用https 的问题和方法 1、apisix 默认https 端口是9443 2、apisix 需要上传证书后才可以使用https 否二curl测试会报错 SSL routines:CONNECT_CR_SRVR_HELLO 3、apisix 上传证书方法 我是使用的自签名证书,注意自签名证书的Common Name 要写你…

静态路由表学习实验

实验要求:各个pc设备可以通信,并且可以访问外网,假设R1已连接外网 拓扑结构 思路:配置pc机ip地址,子网掩码,和网关(网关地址是上层路由接口的地址),配置路由各个接口地址…

【Qt】使用Qt实现Web服务器(七):动态模板引擎

1、示例 2、源码 2.1 模板配置参数 配置文件中关于模板配置参数如下 path为存放模板的目录suffix为模板文件后缀[templates] path=templates suffix=.tpl encoding=UTF-8 cacheSize=1000000

OpenHarmony开发知识点记录之ABI

OpenHarmony系统支持丰富的设备形态,支持多种架构指令集,支持多种操作系统内核;为了应用在各种OpenHarmony设备上的兼容性,本文定义了"OHOS" ABI(Application Binary Interface)的基础标准&#…

缓冲区溢出漏洞相关知识点汇总

1.缓冲区基础知识相关定义 缓冲区定义:缓冲区一块连续的内存区域,用于存放程序运行时,加载到内存的运行代码和数据。 缓冲区溢出:缓冲区溢出是指程序运行时,向固定大小的缓冲区写入超过其容量的数据。多余的数据会越…

Java代码基础算法练习-求一个三位数的各位数字之和-2024.03.27

任务描述&#xff1a; 输入一个正整数n&#xff08;取值范围&#xff1a;100<n<1000&#xff09;&#xff0c;然后输出每位数字之和 任务要求&#xff1a; 代码示例&#xff1a; package M0317_0331;import java.util.Scanner;public class m240327 {public static voi…

Abaqus周期性边界代表体单元Random Sphere RVE 3D (Mesh)插件

插件介绍 Random Sphere RVE 3D (Mesh) - AbyssFish 插件可在Abaqus生成三维具备周期性边界条件(Periodic Boundary Conditions, PBC)的随机球体骨料及骨料-水泥界面过渡区(Interfacial Transition Zone, ITZ)模型。即采用周期性代表性体积单元法(Periodic Representative Vol…

信号量,sem_init/wait/post/destroy函数的使用

sem_init&#xff08;&#xff09;&#xff1b;--------------------------------------------------------------------------------------- 信号量的初始化函数定义在线程创建之前&#xff0c;资源变量定义为全局变量 一开始只有一个写资源&#xff0c;没有读资源 sem_wait(…