深度学习推荐系统(七)NFM模型及其在Criteo数据集上的应用

news/2024/5/21 0:35:21/文章来源:https://blog.csdn.net/qq_44665283/article/details/132747687

深度学习推荐系统(七)NFM模型及其在Criteo数据集上的应用

1 NFM模型原理及其实现

1.1 NFM模型原理

无论是 FM,还是其改进模型FFM,归根结底是⼀个⼆阶特征交叉的模型。受组合爆炸问题的困扰,FM 几乎不可能扩展到三阶以上,这就不可避免地限制了FM模型的表达能力。

新加坡国立大学学者利用神经网络的非线性和强表达能力来改进一下FM模型,得到一个增强版的FM模型,即NFM模型。

如下图,在数学形式上,NFM 模型的主要思路是用⼀个表达能力更强的函数替代原FM中二阶隐向量内积的部分。

在这里插入图片描述

这个表达能力更强的函数就是神经网络,因为神经网络理论上可以拟合任何复杂能力的函数, 所以作者把这个f(x)换成了一个神经网络,当然不是一个简单的DNN, 而是依然底层考虑了交叉,然后高层使用的DNN网络, 这个也就是NFM网络。

1.1.1 NFM的深度网络部分模型结构图

  • NFM 网络架构的特点非常明显,就是在 Embedding 层和多层神经网络之间加入特征交叉池化层(Bi-Interaction Pooling Layer)

  • 所示的 NFM架构图省略了其⼀阶部分。如果把 NFM的⼀阶部分视为⼀个线性模型,那么NFM的架构也可以视为Wide&Deep模型的进化。相比原始的 Wide&Deep 模型,NFM 模型对其 Deep 部分加入了特征交叉池化层,加强了特征交叉。

在这里插入图片描述

1.1.2 特征交叉池化层

在这里插入图片描述

  • 在进行两两Embedding向量的元素积操作后,对交叉特征向量取和,得到池化层的输出向量。

  • 再把该向量输入上层的多层全连接神经网络(DNN),进行进⼀步的交叉。

1.2 NFM模型的实现

NFM模型的实现在于特征交叉池化层,对原始的池化层公式进行化简:

在这里插入图片描述

import torch.nn as nn
import torch.nn.functional as F
import torchclass Dnn(nn.Module):"""Dnn 网络"""def __init__(self, hidden_units, dropout=0.):"""hidden_units: 列表, 每个元素表示每一层的神经单元个数, 、比如[256, 128, 64], 两层网络, 第一层神经单元128, 第二层64, 第一个维度是输入维度dropout: 失活率"""super(Dnn, self).__init__()self.dnn_network = nn.ModuleList([nn.Linear(layer[0], layer[1]) for layer in list(zip(hidden_units[:-1], hidden_units[1:]))])self.dropout = nn.Dropout(p=dropout)def forward(self, x):for linear in self.dnn_network:x = linear(x)x = F.relu(x)x = self.dropout(x)return xclass NFM(nn.Module):def __init__(self, feature_info, hidden_units, embed_dim=8):"""DeepCrossing:feature_info: 特征信息(数值特征, 类别特征, 类别特征embedding映射)hidden_units: 列表, 隐藏单元dropout: Dropout层的失活比例embed_dim: embedding维度"""super(NFM, self).__init__()self.dense_features, self.sparse_features, self.sparse_features_map = feature_info# embedding层, 这里需要一个列表的形式, 因为每个类别特征都需要embeddingself.embed_layers = nn.ModuleDict({'embed_' + str(key): nn.Embedding(num_embeddings=val, embedding_dim=embed_dim)for key, val in self.sparse_features_map.items()})# 注意 这里的总维度  = 数值型特征的维度 + 离散型变量每个特征要embedding的维度dim_sum = len(self.dense_features) + embed_dimhidden_units.insert(0, dim_sum)# bnself.bn = nn.BatchNorm1d(dim_sum)# dnn网络self.dnn_network = Dnn(hidden_units)# dnn的线性层self.dnn_final_linear = nn.Linear(hidden_units[-1], 1)def forward(self, x):# 1、先把输入向量x分成两部分处理、因为数值型和类别型的处理方式不一样dense_input, sparse_inputs = x[:, :len(self.dense_features)], x[:, len(self.dense_features):]# 2、转换为long形sparse_inputs = sparse_inputs.long()# 2、不同的类别特征分别embedding  [(batch_size, embed_dim)]sparse_embeds = [self.embed_layers['embed_' + key](sparse_inputs[:, i]) for key, i inzip(self.sparse_features_map.keys(), range(sparse_inputs.shape[1]))]# 3、embedding进行堆叠sparse_embeds = torch.stack(sparse_embeds) # (离散特征数, batch_size, embed_dim)sparse_embeds = sparse_embeds.permute((1,0,2))  # (batch_size, 离散特征数, embed_dim)# 这里得到embedding向量 sparse_embeds的shape为(batch_size, 离散特征数, embed_dim)# 然后就进行特征交叉层,按照特征交叉池化层化简后的公式  其代码如下# 注意:# 公式中的x_i乘以v_i就是 embedding后的sparse_embeds# 通过设置dim=1,把dim=1压缩(行的相同位置相加、去掉dim=1),即进行了特征交叉embed_cross = 1 / 2 * (torch.pow(torch.sum(sparse_embeds, dim=1), 2) - torch.sum(torch.pow(sparse_embeds, 2), dim=1))  # (batch_size, embed_dim)# 4、数值型和类别型特征进行拼接  (batch_size, embed_dim + dense_input维度 )x = torch.cat([embed_cross, dense_input], dim=-1)x = self.bn(x)# Dnn部分,使用全部特征dnn_out = self.dnn_final_linear(self.dnn_network(x))# outoutputs = torch.sigmoid(dnn_out)return outputsif __name__ == '__main__':x = torch.rand(size=(2, 5), dtype=torch.float32)feature_info = [['I1', 'I2'],  # 连续性特征['C1', 'C2', 'C3'],  # 离散型特征{'C1': 20,'C2': 20,'C3': 20}]# 建立模型hidden_units = [128, 64, 32]net = NFM(feature_info, hidden_units)print(net)print(net(x))
NFM((embed_layers): ModuleDict((embed_C1): Embedding(20, 8)(embed_C2): Embedding(20, 8)(embed_C3): Embedding(20, 8))(bn): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(dnn_network): Dnn((dnn_network): ModuleList((0): Linear(in_features=10, out_features=128, bias=True)(1): Linear(in_features=128, out_features=64, bias=True)(2): Linear(in_features=64, out_features=32, bias=True))(dropout): Dropout(p=0.0, inplace=False))(dnn_final_linear): Linear(in_features=32, out_features=1, bias=True)
)
tensor([[0.4627],[0.4660]], grad_fn=<SigmoidBackward0>)

2 NFM模型在Criteo数据集上的应用

数据的预处理可以参考

深度学习推荐系统(二)Deep Crossing及其在Criteo数据集上的应用

2.1 准备训练数据

import pandas as pdimport torch
from torch.utils.data import TensorDataset, Dataset, DataLoaderimport torch.nn as nn
from sklearn.metrics import auc, roc_auc_score, roc_curveimport warnings
warnings.filterwarnings('ignore')
# 封装为函数
def prepared_data(file_path):# 读入训练集,验证集和测试集train_set = pd.read_csv(file_path + 'train_set.csv')val_set = pd.read_csv(file_path + 'val_set.csv')test_set = pd.read_csv(file_path + 'test.csv')# 这里需要把特征分成数值型和离散型# 因为后面的模型里面离散型的特征需要embedding, 而数值型的特征直接进入了stacking层, 处理方式会不一样data_df = pd.concat((train_set, val_set, test_set))# 数值型特征直接放入stacking层dense_features = ['I' + str(i) for i in range(1, 14)]# 离散型特征需要需要进行embedding处理sparse_features = ['C' + str(i) for i in range(1, 27)]# 定义一个稀疏特征的embedding映射, 字典{key: value},# key表示每个稀疏特征, value表示数据集data_df对应列的不同取值个数, 作为embedding输入维度sparse_feas_map = {}for key in sparse_features:sparse_feas_map[key] = data_df[key].nunique()feature_info = [dense_features, sparse_features, sparse_feas_map]  # 这里把特征信息进行封装, 建立模型的时候作为参数传入# 把数据构建成数据管道dl_train_dataset = TensorDataset(# 特征信息torch.tensor(train_set.drop(columns='Label').values).float(),# 标签信息torch.tensor(train_set['Label'].values).float())dl_val_dataset = TensorDataset(# 特征信息torch.tensor(val_set.drop(columns='Label').values).float(),# 标签信息torch.tensor(val_set['Label'].values).float())dl_train = DataLoader(dl_train_dataset, shuffle=True, batch_size=16)dl_vaild = DataLoader(dl_val_dataset, shuffle=True, batch_size=16)return feature_info,dl_train,dl_vaild,test_set
file_path = './preprocessed_data/'feature_info,dl_train,dl_vaild,test_set = prepared_data(file_path)

2.2 建立NFM模型

from _01_nfm import NFMhidden_units = [128, 64, 32]
net = NFM(feature_info, hidden_units)
# 测试一下模型
for feature, label in iter(dl_train):out = net(feature)print(feature.shape)print(out.shape)print(out)break

3.3 模型的训练

from AnimatorClass import Animator
from TimerClass import Timer# 模型的相关设置
def metric_func(y_pred, y_true):pred = y_pred.datay = y_true.datareturn roc_auc_score(y, pred)def try_gpu(i=0):if torch.cuda.device_count() >= i + 1:return torch.device(f'cuda:{i}')return torch.device('cpu')def train_ch(net, dl_train, dl_vaild, num_epochs, lr, device):"""⽤GPU训练模型"""print('training on', device)net.to(device)# 二值交叉熵损失loss_func = nn.BCELoss()optimizer = torch.optim.Adam(params=net.parameters(), lr=lr)animator = Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train auc', 'val loss', 'val auc'],figsize=(8.0, 6.0))timer, num_batches = Timer(), len(dl_train)log_step_freq = 10for epoch in range(1, num_epochs + 1):# 训练阶段net.train()loss_sum = 0.0metric_sum = 0.0for step, (features, labels) in enumerate(dl_train, 1):timer.start()# 梯度清零optimizer.zero_grad()# 正向传播predictions = net(features)loss = loss_func(predictions, labels.unsqueeze(1) )try:          # 这里就是如果当前批次里面的y只有一个类别, 跳过去metric = metric_func(predictions, labels)except ValueError:pass# 反向传播求梯度loss.backward()optimizer.step()timer.stop()# 打印batch级别日志loss_sum += loss.item()metric_sum += metric.item()if step % log_step_freq == 0:animator.add(epoch + step / num_batches,(loss_sum/step, metric_sum/step, None, None))# 验证阶段net.eval()val_loss_sum = 0.0val_metric_sum = 0.0for val_step, (features, labels) in enumerate(dl_vaild, 1):with torch.no_grad():predictions = net(features)val_loss = loss_func(predictions, labels.unsqueeze(1))try:val_metric = metric_func(predictions, labels)except ValueError:passval_loss_sum += val_loss.item()val_metric_sum += val_metric.item()if val_step % log_step_freq == 0:animator.add(epoch + val_step / num_batches, (None,None,val_loss_sum / val_step , val_metric_sum / val_step))print(f'final: loss {loss_sum/len(dl_train):.3f}, auc {metric_sum/len(dl_train):.3f},'f' val loss {val_loss_sum/len(dl_vaild):.3f}, val auc {val_metric_sum/len(dl_vaild):.3f}')print(f'{num_batches * num_epochs / timer.sum():.1f} examples/sec on {str(device)}')
lr, num_epochs = 0.001, 10
train_ch(net, dl_train, dl_vaild, num_epochs, lr, try_gpu())

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_168206.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

服务器给前端实时推送数据轻量化解决方案eventSource+Springboot

一、前端代码 body代码 <div id"result"></div>js代码 $(function(){if(typeof(EventSource) ! "undefined"){var source new EventSource("/demo/getTime");source.onmessage function(event) {console.log(event.data);$(&qu…

Kafka3.0.0版本——消费者(消费者组案例)

目录 一、消费者组案例1.1、案例需求1.2、案例代码1.2.1、消费者1代码1.2.2、消费者2代码1.2.3、消费者3代码1.2.4、生产者代码 1.3、测试 一、消费者组案例 1.1、案例需求 测试同一个主题的分区数据&#xff0c;只能由一个消费者组中的一个消费。如下图所示&#xff1a; 1…

ubuntu上ffmpeg使用framebuffer显示video

这个主题是想验证使用fbdev(Linux framebuffer device&#xff09;&#xff0c;将video直接显示到Linux framebuffer上&#xff0c;在FFmpeg中对应的FFOutputFormat 就是ff_fbdev_muxer。 const FFOutputFormat ff_fbdev_muxer {.p.name "fbdev",.p.long_…

OmniGraffle Pro for Mac 中文正式版(附注册码) 苹果电脑 思维导图软件

OmniGraffle Pro是OmniGraffle的高级版本&#xff0c;它提供了更多的功能和工具&#xff0c;可以帮助用户创建更为复杂和高级的图表和流程图。OmniGraffle Pro支持自定义形状、图形、线条和箭头等&#xff0c;可以让用户创建出更加精细的图表。此外&#xff0c;OmniGraffle Pro…

Uniapp中使用uQRCode二维码跳转小程序页面

下载插件 uQRCode官网地址 引入插件 文件如下 //--------------------------------------------------------------------- // github https://github.com/Sansnn/uQRCode //---------------------------------------------------------------------let uQRCode = {};(functio…

springCloud-LoadBalancer负载均衡

接上个博客springcloud-Eureka。 Eureka主要是如何通过eureka服务器进行服务注册与发现&#xff0c;也有简单的负载均衡&#xff0c;实际上它其中的负载均衡就是靠LoadBalancer实现的。 2020年前SpringCloud是采用Ribbon作为负载均衡实现&#xff0c;但是在2020后采用了LoadBal…

【2023高教社杯数学建模国赛】ABCD题 问题分析、模型建立、参考文献及实现代码

【2023高教社杯数学建模国赛】ABCD题 问题分析、模型建立、参考文献及实现代码 1 比赛时间 北京时间&#xff1a;2023年9月7日 18:00-2023年9月10日20:00 2 思路内容 可以参考我提供的历史竞赛信息内容&#xff0c;最新更新我会发布在博客和知乎上&#xff0c;请关注我获得最…

大数据技术之Hadoop:HDFS存储原理篇(五)

目录 一、原理介绍 1.1 Block块 1.2 副本机制 二、fsck命令 2.1 设置默认副本数量 2.2 临时设置文件副本大小 2.3 fsck命令检查文件的副本数 2.4 block块大小的配置 三、NameNode元数据 3.1 NameNode作用 3.2 edits文件 3.3 FSImage文件 3.4 元素据合并控制参数 …

学习笔记|回顾(1-12节课)|应用模块化的编程|添加函数头|静态变量static|STC32G单片机视频开发教程(冲哥)|阶段小结:应用模块化的编程(上)

文章目录 1.回顾(1-12节课)2.应用模块化的编程(.c .h)Tips:添加函数头创建程序文件三步引脚定义都在.h文件函数定义三步bdata位寻址变量的使用 3.工程文件编写静态变量static的使用完整程序为&#xff1a;demo.c&#xff1a;seg_led.c:seg_led.h: 1.回顾(1-12节课) 一、认识单…

PaddleX:一站式、全流程、高效率的飞桨AI套件

随着ChatGPT引领的AI破圈&#xff0c;各行各业掀起了AI落地的潮流&#xff0c;从智能客服、智能写作、智能监控&#xff0c;到智能医疗、智能家居、智能金融、智能农业&#xff0c;谁能快速将AI与传统业务相结合&#xff0c;谁就将成为企业数字化和智能化变革的优胜者。然而&am…

C高级第2天

写一个1.sh脚本&#xff0c;将以下内容放到脚本中&#xff1a; 在家目录下创建目录文件&#xff0c;dir 在dir下创建dir1和dir2 把当前目录下的所有文件拷贝到dir1中&#xff0c; 把当前目录下的所有脚本文件拷贝到dir2中 把dir2打包并压缩为dir2.tar.xz 再把dir2.tar.xz…

Vue2+Vue3基础入门到实战项目(七)——智慧商城项目

Vue 核心技术与实战 智慧商城 接口文档&#xff1a;https://apifox.com/apidoc/shared-12ab6b18-adc2-444c-ad11-0e60f5693f66/doc-2221080 演示地址&#xff1a;http://cba.itlike.com/public/mweb/#/ 01. 项目功能演示 1.明确功能模块 启动准备好的代码&#xff0c;演示…

【2023最新版】DataGrip使用MySQL教程

目录 一、安装MySQL 二、安装DataGrip 三、DataGrip使用MySQL 1. 新建项目 2. DataGrip连接MySQL 下载驱动文件 填写root密码 测试 成功 3. DataGrip操作MySQL 四、MySQL常用命令 1. 登录 2. 帮助 3. 查询所有数据库 一、安装MySQL MySQL是一种开源的关系型数据库…

leetcode872. 叶子相似的树(java)

叶子相似的树 题目描述递归 题目描述 难度 - 简单 leetcode - 872. 叶子相似的树 请考虑一棵二叉树上所有的叶子&#xff0c;这些叶子的值按从左到右的顺序排列形成一个 叶值序列 。 举个例子&#xff0c;如上图所示&#xff0c;给定一棵叶值序列为 (6, 7, 4, 9, 8) 的树。 如果…

leetcode 143. 重排链表

2023.9.5 先将链表中的节点存储到数组中&#xff0c;再利用双指针重新构造符合条件的链表。代码如下&#xff1a; /*** Definition for singly-linked list.* struct ListNode {* int val;* ListNode *next;* ListNode() : val(0), next(nullptr) {}* ListNod…

echarts条形图实现颜色渐变

eCharts——柱状图中的柱体颜色渐变_echarts 柱状图渐变_小美同学的博客-CSDN博客 【Echarts】柱状图渐变两种实现方式_echarts柱状图渐变_芳草萋萋鹦鹉洲哦的博客-CSDN博客

lvm + raid(逻辑磁盘+阵列)创建删除恢复 for linux

本教程适用于linux lvm为逻辑磁盘&#xff0c;raid为阵列&#xff0c;两种技术可以单独使用也可以搭配使用 2023.9.3更新 前三节是操作命令和基础知识&#xff0c;后面是实操。 一、存储硬件查看相关命令 硬盘分区相关操作在后面用的到&#xff0c;可以先略过&#xff0c;有需…

【C++二叉树】进阶OJ题

【C二叉树】进阶OJ题 目录 【C二叉树】进阶OJ题1.二叉树的层序遍历II示例代码解题思路 2.二叉搜索树与双向链表示例代码解题思路 3.从前序与中序遍历序列构造二叉树示例代码解题思路 4.从中序与后序遍历序列构造二叉树示例代码解题思路 5.二叉树的前序遍历&#xff08;非递归迭…

【实战】React17+React Hook+TS4 最佳实践,仿 Jira 企业级项目(总结展望篇)

文章目录 一、项目起航&#xff1a;项目初始化与配置二、React 与 Hook 应用&#xff1a;实现项目列表三、TS 应用&#xff1a;JS神助攻 - 强类型四、JWT、用户认证与异步请求五、CSS 其实很简单 - 用 CSS-in-JS 添加样式六、用户体验优化 - 加载中和错误状态处理七、Hook&…

[keil] uv编译分析

假设Keil安装路径: C:\Keil_v5\ 假设工程在 d:\HELLO , 工程Targets名:Simulator [在Manage Project Items中可修改] 如下指令为:Build(F7) C:\Keil_v5\UV4\UV4.exe -b d:\HELLO\Hello.uvproj -j0 -t Simulator -o d:\HELLO\uv4.log 如下指令为:Rebuild(CtrlAltF7) C:\Kei…