【RecBole-GNN/源码】RecBole-GNN中lightGCN源码解析

news/2024/4/19 5:29:06/文章来源:https://blog.csdn.net/qq_36931982/article/details/129151295

如果觉得我的分享有一定帮助,欢迎关注我的微信公众号 “码农的科研笔记”,了解更多我的算法和代码学习总结记录。或者点击链接扫码关注【RecBole-GNN/源码】RecBole-GNN中lightGCN源码解析

【RecBole-GNN/源码】RecBole-GNN中lightGCN源码解析


原文:https://arxiv.org/pdf/2002.02126.pdf

源码:伯乐工具箱

LightGCN架构图

输入数据源(图节点仅仅使用了用户或者物品的ID进行模型搭建):

  • ml-1m.inter
  • ml-1m.item
  • ml-1m.user

GCN聚合消息需要定义节点特征以及边

1 节点

节点特征(是需要经过训练得到合适的embedding):得到所有节点特征all_embeddings(9748(6041+3707)*64)

#定义user嵌入:6041*64
self.user_embedding = torch.nn.Embedding(num_embeddings=self.n_users, embedding_dim=self.latent_dim)
#定义item嵌入:3707*64
self.item_embedding = torch.nn.Embedding(num_embeddings=self.n_items, embedding_dim=self.latent_dim)
user_embeddings = self.user_embedding.weight
item_embeddings = self.item_embedding.weight
#进行组合得到:9748(6041+3707)*64
all_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0)

2 边

得到所有边edge_index(1610886-1) 以及权重 edge_weight(1610886-1)

#根据.iter交互文件,获取user_id那一列作为row(805443*1)
row = self.inter_feat[self.uid_field]
#根据.iter交互文件,获取item_id那一列作为col(计数id需要加self.user_num)(805443*1)
col = self.inter_feat[self.iid_field] + self.user_num
edge_index1 = torch.stack([row, col])
edge_index2 = torch.stack([col, row])
#得到所有边矩阵2*1610886(805443+805443)
# row col //因为边是双向的
# col row 
edge_index = torch.cat([edge_index1, edge_index2], dim=1)
# 获得每个节点的度(节点的连边)
deg = degree(edge_index[0], self.user_num + self.item_num)
#对于每个节点,如果其度数为 $0$,则将其规范化因子设为 $1$,否则将其规范化因子设为 $1/\sqrt{\text{degree}}$。最终,得到的 #norm_deg 张量表示了每个节点的规范化因子。
norm_deg = 1. / torch.sqrt(torch.where(deg == 0, torch.ones([1]), deg))
#为每条边计算一个权重,该权重等于该边两个节点的规范化因子之积。(1610886*1)
edge_weight = norm_deg[edge_index[0]] * norm_deg[edge_index[1]]

3 GCN聚合

for layer_idx in range(self.n_layers):all_embeddings = self.gcn_conv(all_embeddings, self.edge_index, self.edge_weight)embeddings_list.append(all_embeddings)
#多轮嵌入求均值
lightgcn_all_embeddings = torch.stack(embeddings_list, dim=1)
lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1)
#获得user和item节点的最终嵌入表示
user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items])

self.propagate(edge_index, x=x, edge_weight=edge_weight) 是 PyTorch Geometric(简称 PyG)库中定义的一个函数。该函数的作用是对输入的节点特征矩阵 x 进行消息传递,更新节点特征矩阵,并返回更新后的节点特征矩阵。

其中,edge_index 是一个形状为 2×E2 \times E2×E 的张量,表示图中所有边的起始节点和结束节点的编号,EEE 表示边的数量;x 是一个形状为 N×FN \times FN×F 的节点特征矩阵,表示图中所有 NNN 个节点的特征,FFF 表示每个节点的特征向量的维度;edge_weight 是一个形状为 EEE 的张量,表示图中每条边的权重。

在该函数中,消息传递的方式是通过定义一个 message 函数和一个 update 函数来实现的。message 函数的作用是将源节点的特征和边权重作为输入,计算出每条边传递的消息;update 函数的作用是将每个节点收到的消息进行聚合,并更新节点的特征。

具体来说,该函数中的 propagate 函数会对输入的 xedge_weight 执行消息传递,按照以下步骤进行:

  1. 根据输入的 edge_indexedge_weight 构造一个稀疏权重矩阵 edge_index,形状为 N×NN \times NN×N,其中 NNN 表示节点数量,矩阵中的每个元素表示一条边的权重。
  2. 调用 message 函数,将源节点的特征和边权重作为输入,计算出每条边传递的消息。
  3. 将每个节点收到的消息进行聚合,并更新节点的特征。具体来说,对于每个节点 iii,将其所有邻居节点 jjj 的消息按照一定的方式聚合起来,得到一个新的特征向量,用于更新节点 iii 的特征。
  4. 返回更新后的节点特征矩阵。

在实际应用中,propagate 函数通常会被多次调用,用于实现多轮消息传递,并最终得到图中所有节点的特征表示。

4 推荐任务

#获得正例和负例的各自embedding
u_embeddings = user_all_embeddings[user]
pos_embeddings = item_all_embeddings[pos_item]
neg_embeddings = item_all_embeddings[neg_item]# calculate BPR Loss
pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
mf_loss = self.mf_loss(pos_scores, neg_scores)# calculate regularization Loss
u_ego_embeddings = self.user_embedding(user)
pos_ego_embeddings = self.item_embedding(pos_item)
neg_ego_embeddings = self.item_embedding(neg_item)reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings, require_pow=self.require_pow)
loss = mf_loss + self.reg_weight * reg_loss

5 实验

  • 和NGCF进行实验对比:
  • 和最优模型进行对比:NGCF、Mult-VAE、GRMF
  • 消融实验:证明了非线性激活和特征转换这些GCN的结构在推荐系统中并不适用,这很可能是因为推荐系统中每个图节点仅仅使用了用户或者物品的ID进行模型搭建和训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_72215.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Ardiuno-交通灯

LED交通灯实验实验器件:■ 红色LED灯:1 个■ 黄色LED灯:1 个■ 绿色LED灯:1 个■ 220欧电阻:3 个■ 面包板:1 个■ 多彩杜邦线:若干实验连线1.将3个发光二极管插入面包板,2.用杜邦线…

【JUC2022】第二章 多线程锁

【JUC2022】第二章 多线程锁 文章目录【JUC2022】第二章 多线程锁一、乐观锁与悲观锁1.悲观锁2.乐观锁二、八锁案例1.标准情况,有a、b两个线程,请问先打印邮件还是短信【结果:邮件】2.sendEmail方法中加入暂停3秒钟,请问先打印邮件…

华为OD机试 - 最小传递延迟(C++) | 附带编码思路 【2023】

刷算法题之前必看 参加华为od机试,一定要注意不要完全背诵代码,需要理解之后模仿写出,通过率才会高。 华为 OD 清单查看地址:https://blog.csdn.net/hihell/category_12199283.html 华为OD详细说明:https://dream.blog.csdn.net/article/details/128980730 华为OD机试题…

随机数与蒙特卡洛方法及Python实现

0 建议学时 4学时 1 引入 1.1 随机数与采样 客观世界的某些行为,结果具有随机性: 掷骰子、投硬币; 等待公交车的时间; 种子发芽的比例; … 1.2 随机数函数 1.2.1 random模块 Python的random模块中提供了若干生成…

RFID盘点软件为企业提供RFID固定资产管理方案

随着科技的发展,固定资产管理系统也经过了一些变革,从刚开始的单机版逐渐发展成SaaS版本,物联网版本等。从刚开始只支持条形码到支持二维码、RFID码。RFID固定资产管理系统上线后,通过给每个实物资产绑定一个RFID码标签后&#xf…

接口测试流程是怎样的?

接口测试流程是怎样的?总所周知,接口测试流程是怎样的?总所周知接口测试在软件测试中是一个非常重要的一部分,其主要目的是测试应用程序的接口是否能够按照规范要求与其他系统或组件进行交互,以及在不同负载条件下接口…

推荐一款新的自动化测试框架:DrissionPage

今天给大家推荐一款基于Python的网页自动化工具:DrissionPage。这款工具既能控制浏览器,也能收发数据包,甚至能把两者合而为一,简单来说:集合了WEB浏览器自动化的便利性和 requests 的高效率。 一、DrissionPage产生背…

vue3-element-admin搭建

vue3-element-admin 是基于 vue-element-admin 升级的 Vue3 Element Plus 版本的后台管理前端解决方案,是 有来技术团队 继 youlai-mall 全栈开源商城项目的又一开源力作功能清单技术栈清单技术栈 描述官网Vue3 渐进式 JavaScript 框架 https://v3.cn.vuejs.org/Ty…

经纬度坐标点和距离之间的转换

1.纬度相同,经度不同 在纬度相同的情况下: 经度每隔0.00001度,距离相差约1米; 每隔0.0001度,距离相差约10米; 每隔0.001度,距离相差约100米; 每隔0.01度,距离相差约1000米…

基于龙芯 2K1000 的嵌入式 Linux 系统移植和驱动程序设计(一)

2.1 需求分析 本课题以龙芯 2K1000 处理器为嵌入式系统的处理器,需要实现一个完成的嵌入式软件系统,系统能够正常启动并可以稳定运行嵌入式 Linux。设计网络设备驱 动,可以实现板卡与其他网络设备之间的网络连接和文件传输。设计 PCIE 设备驱…

我的 System Verilog 学习记录(1)

引言 技多不压身,准备开始学一些 System Verilog 的东西,充实一下自己,这个专栏的博客就记录学习、找资源的一个过程,希望可以给后来者一些借鉴吧,IC找工作的都加把油! 本文是准备先简单介绍一下环境搭建…

洛谷P1125 [NOIP2008 提高组] 笨小猴 C语言/C++

[NOIP2008 提高组] 笨小猴 题目描述 笨小猴的词汇量很小,所以每次做英语选择题的时候都很头疼。但是他找到了一种方法,经试验证明,用这种方法去选择选项的时候选对的几率非常大! 这种方法的具体描述如下:假设 maxn\…

JAVA集合之并发集合

从Java 5 开始,在java.util.concurrent 包下提供了大量支持高效并发访问的集合接口和实现类,如下图所示: 以CopyOnWrite开头的集合即写时复制的容器。通俗的理解是当我们往一个容器添加元素的时候,不直接往容器添加,而…

直播预告 | 嵌入式BI如何将数据分析真正融入业务流程

在信息化高速发展的今天,数据成为企业最有价值的资产之一。而数据本身很难直接传递有价值的信息,只有通过对数据进行挖掘、分析,才能让数据真正成为生产力。 商业智能(BI)应运而生,可以帮助企业更好地从数…

Julia 交互式命令窗口

执行 julia 命令可以直接进入交互式命令窗口: $ julia __ _ _(_)_ | Documentation: https://docs.julialang.org(_) | (_) (_) |_ _ _| |_ __ _ | Type "?" for help, "]?" for Pkg help.| | | | | | |/ _ | || |…

nginx的介绍及源码安装

文章目录前言一、nginx介绍二、nginx应用场合三、nginx的源码安装过程1.下载源码包2.安装依赖性-安装nginx-创建软连接-启动服务-关闭服务3.创建nginx服务启动脚本4.本实验---纯代码过程前言 高可用:高可用(High availability,缩写为 HA),是指系统无中断地执行其功…

win7下安装postgreSQL教程

系统环境:Windows 7 旗舰版 64位操作系统 安装版本:postgresql-9.1.4-1-windows-x64 安装步骤: 1、下载系统对应的软件版本; 2、双击“postgresql-9.1.4-1-windows-x64.exe”打开安装窗口; 3、Welcome页,…

图解操作系统

硬件结构 CPU是如何执行程序的? 图灵机的工作方式 图灵机的基本思想:用机器来模拟人们用纸笔进行数学运算的过程,还定义了由计算机的那些部分组成,程序又是如何执行的。 图灵机的基本组成如下: 有一条「纸带」&am…

allure简介

allure介绍allure是一个轻量级,灵活的,支持多语言的测试报告工具多平台的,奢华的report框架可以为dev/qa提供详尽的测试报告、测试步骤、log也可以为管理层提供high level统计报告java语言开发的,支持pytest,javaScript,PHP等可以…

C语言——动态内存管理

目录0. 思维导图:1. 为什么存在动态内存分配2. 动态内存函数介绍2.1 malloc和free2.2 calloc2.3 realloc3. 常见的动态内存错误3.1 对NULL指针的解引用操作3.2 对动态内存开辟的空间越界访问3.3 对非动态开辟内存使用free释放3.4 使用free释放一块动态开辟内存的一部…