机器学习实践(1.1)XGBoost分类任务

news/2024/5/20 1:55:11/文章来源:https://blog.csdn.net/LMTX069/article/details/131199963

前言

XGBoost属于Boosting集成学习模型,由华盛顿大学陈天齐博士提出,因在机器学习挑战赛中大放异彩而被业界所熟知。相比越来越流行的深度神经网络,XGBoost能更好的处理表格数据,并具有更强的可解释性,还具有易于调参、输入数据不变性等优势。本文只做XGBoost分类任务的脚本实现,更多XGBoost内容请查看文末 附加——深入学习XGBoost

❤️ 本文完整脚本点此链接 百度网盘链接 获取 ❤️

结论先行
当使用 from xgboost import XGBClassifier 的模型进行训练时,使用的是sklearn中的XGBClassifier类,该方法中无需特意指定分类类别,方法自带类别数量n_classes_的计算,并根据数量指定了objective参数,简而言之:该方法会自动判别是多分类还是二分类任务,无需特殊说明。

之所以特殊强调from xgboost import XGBClassifier 是要区别于import xgboost as xgb 的调用,本人更建议使用前者。

下方代码取自 sklearn.py 的 class XGBClassifier(XGBModel, XGBClassifierBase)…

# 计算类别数量
import cupy as cp  # pylint: disable=E0401self.classes_ = cp.unique(y.values)
self.n_classes_ = len(self.classes_)......# objective参数的选择
if callable(self.objective):obj: Optional[Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]] = _objective_decorator(self.objective)# Use default value. Is it really not used ?params["objective"] = "binary:logistic"
else:obj = Noneif self.n_classes_ > 2:# Switch to using a multiclass objective in the underlying XGB instanceif params.get("objective", None) != "multi:softmax":params["objective"] = "multi:softprob"params["num_class"] = self.n_classes_

一.轻松实现多分类

1.1导入第三方库、数据集

# 导入第三方库,包括分类模型、数据集、数据集分割方法、评估方法
from xgboost import XGBClassifier  # 分类模型
from sklearn import datasets  # 数据集
from sklearn.model_selection import train_test_split  # 数据集分割方法
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score, \classification_report  # 评估方法
import xgboost as xgb# 导入sklearn的鸢尾花卉数据集,作为模型的训练和验证数据
data = datasets.load_iris()# 数据划分,按照7 3分切割数据集为训练集和验证集,其中最终4个结果依次为训练数据、验证数据、训练数据的标签、验证数据的标签
X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.3,random_state=123)

sklearn的鸢尾花卉数据集共150个数据样本,73切分后,训练集105个数据样本,验证集45个数据样本。数据集中包括 样本特征data(4个特征)、样本标签target(3类标签)、标签名称target_names([‘setosa’, ‘versicolor’, ‘virginica’])、特征名称feature_names([‘sepal length (cm)’, ‘sepal width (cm)’, ‘petal length (cm)’, ‘petal width (cm)’])、以及数据集位置filename(~~~\anaconda\lib\site-packages\sklearn\datasets\data\iris.csv)

数据集的 部分特征数据 如下:
在这里插入图片描述
数据集的 标签数据 及 标签名称 如下:在这里插入图片描述
数据集文件所在本地地址,其他波士顿房价等数据也在此文件夹下
在这里插入图片描述

1.2模型训练、验证

# 默认参数的模型
model = XGBClassifier()# 调参见 附加1 的文章内容
# model = XGBClassifier(booster='gbtree',
#                       n_estimators=20,  # 迭代次数
#                       learning_rate=0.1,  # 步长
#                       max_depth=5,  # 树的最大深度
#                       min_child_weight=1,  # 决定最小叶子节点样本权重和
#                       subsample=0.8,  # 每个决策树所用的子样本占总样本的比例(作用于样本)
#                       colsample_bytree=0.8,  # 建立树时对特征随机采样的比例(作用于特征)典型值:0.5-1
#                       nthread=4,
#                       seed=27,  # 指定随机种子,为了复现结果
#                       # num_class=4,  # 标签类别数
#                       # objective='multi:softmax',  # 多分类
#                       )# 模型训练
model.fit(X_train, y_train, verbose=True)# 模型对验证数据做预测 y_pred 预测结果,y_proba 预测各类别概率,y_pred 是softmax(y_proba) 的结果 
y_pred = model.predict(X_test)
y_proba = model.predict_proba(X_test)# 为了便于观察验证数据的预测结果,写个循环传个参
for m, n, p in zip(y_proba, y_pred, y_test):if n == p:q = '预测正确'else:q = '预测错误'print('预测概率为{0}, 预测概率为{1}, 真是结果为{2}, {3}'.format(m, n, p, q))# 准确率
accuracy = accuracy_score(y_test, y_pred)
print('Accuracy:%.2f%%' % (accuracy * 100))

打印运行结果如下
在这里插入图片描述

二.实现二分类同样简单

2.1导入数据集、训练、验证

导入数据集、训练、验证 与多分类完全一样,唯一需要改变的是将数据标签由n_classes>2转为n_classes=2

# 数据划分
X_train, X_test, y_train, y_test = ......  
# 接# 训练集和验证集的标签都改成0和1
y_train = [1 if y > 0 else 0 for y in y_train]y_test = [1 if y > 0 else 0 for y in y_test]# 计算训练数据类别
n_classes = len(set(y_train))# 接
# 默认参数的模型
model = XGBClassifier()
......

2.2再做模型训练和验证

内容与1.2完全一致,不再赘述,直接看结果
二分类任务预测概率有两个,预测概率通过sigmoid做出预测结果。多分类预测概率有多个,预测概率通过softmax做出预测结果。
在这里插入图片描述

2.3重要结论映证

"""模型参数打印"""
bst = xgb.Booster(model_file='xgb_classifier_model.model')
# print(bst.attributes())print('模型参数值-开始'.center(20, '='))
for attr_name, attr_value in bst.attributes().items():# scikit_learn 的参数逐一解析if attr_name == 'scikit_learn':import jsondict_attr = json.loads(attr_value)# 打印 模型 scikit_learn 参数for sl_name, sl_value in dict_attr.items():if sl_value is not None:print(f"{sl_name}:{sl_value}")else:print(f"{attr_name}:{attr_value}")
print('模型参数值-结束'.center(20, '='))

下图展示model = XGBClassifier()未指定参数情况下,模型的参数情况打印,其中objectiveclasses_参数映证了最前面的结论。
在这里插入图片描述

三.模型评估

评估方法更详细解释,见附加3的文章内容

def metrics_sklearn(class_num, y_valid, y_pred_, y_prob):"""模型效果评估"""# 准确率# 准确度 accuracy_score:分类正确率分数,函数返回一个分数,这个分数或是正确的比例,或是正确的个数,不考虑正例负例的问题,区别于 precision_score# 一般不直接使用准确率,主要是因为类别不平衡问题,如果大部分是negative的 而且大部分模型都很容易判别出来,那准确率都很高, 没有区分度,也没有实际意义(因为negative不是我们感兴趣的)accuracy = accuracy_score(y_valid, y_pred_)print('Accuracy:%.2f%%' % (accuracy * 100))# 精准率if class_num == 2:precision = precision_score(y_valid, y_pred_)else:precision = precision_score(y_valid, y_pred_, average='macro')print('Precision:%.2f%%' % (precision * 100))# 召回率# 召回率/查全率R recall_score:预测正确的正样本占预测正样本的比例, TPR 真正率# 在二分类任务中,召回率表示被分为正例的个数占所有正例个数的比例;如果是多分类的话,就是每一类的平均召回率if class_num == 2:recall = recall_score(y_valid, y_pred_)else:recall = recall_score(y_valid, y_pred_, average='macro')print('Recall:%.2f%%' % (recall * 100))# F1值if class_num == 2:f1 = f1_score(y_valid, y_pred_)else:f1 = f1_score(y_valid, y_pred_, average='macro')print('F1:%.2f%%' % (f1 * 100))# auc曲线下面积# 曲线下面积 roc_auc_score 计算AUC的值,即输出的AUC(二分类任务中 auc 和 roc_auc_score 数值相等)# 计算auc,auc就是曲线roc下面积,这个数值越高,则分类器越优秀。这个曲线roc所在坐标轴的横轴是FPR,纵轴是TPR。# 真正率:TPR = TP/P = TP/(TP+FN)、假正率:FPR = FP/N = FP/(FP+TN)# auc:不支持多分类任务 计算ROC曲线下的面积# 二分类问题直接用预测值与标签值计算:auc = roc_auc_score(Y_test, Y_pred)# 多分类问题概率分数 y_prob:auc = roc_auc_score(Y_test, Y_pred_prob, multi_class='ovo') 其中multi_class必选if class_num == 2:auc = roc_auc_score(y_valid, y_pred_)else:auc = roc_auc_score(y_valid, y_prob, multi_class='ovo')# auc = roc_auc_score(y_valid, y_prob, multi_class='ovr')print('AUC:%.2f%%' % (auc * 100))# 评估效果报告print(classification_report(y_test, y_pred, target_names=['0:setosa', '1:versicolor', '2:virginica']))"""模型效果评估"""
n_classes = len(set(y_train))
metrics_sklearn(n_classes, y_test, y_pred, y_proba)

多分类模型效果评估结果入下图所示
在这里插入图片描述
本文完整脚本可通过百度网盘链接 获取

附加——深入学习XGBoost

附加1.模型调参、训练、保存、评估和预测

见《XGBoost模型调参、训练、评估、保存和预测》 ,包含模型脚本文件

附加2.算法原理

见《XGBoost算法原理及基础知识》 ,包括集成学习方法,XGBoost模型、目标函数、算法,公式推导等

附加3.分类任务的评估指标值详解

见《分类任务评估1——推导sklearn分类任务评估指标》,其中包含了详细的推理过程;
见《分类任务评估2——推导ROC曲线、P-R曲线和K-S曲线》,其中包含ROC曲线、P-R曲线和K-S曲线的推导与绘制;

附加4.模型中树的绘制和模型理解

见《Graphviz绘制模型树1——软件配置与XGBoost树的绘制》,包含Graphviz软件的安装和配置,以及to_graphviz()和plot_trees()两个画图函数的部分使用细节;
见《Graphviz绘制模型树2——XGBoost模型的可解释性》,从模型中的树着手解释XGBoost模型,并用EXCEL构建出模型。

❤️ 机器学习内容持续更新中… ❤️


声明:本文所载信息不保证准确性和完整性。文中所述内容和意见仅供参考,不构成实际商业建议,可收藏可转发但请勿转载,如有雷同纯属巧合。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_506526.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

hard fault on thread: mqtt0解决办法

rt thread版本4.1.0 使用paho mqtt软件包 运行一段时间后出现 psr: 0x21000000 r00: 0x5036fc8f r01: 0x5036fc88 r02: 0x00000000 r03: 0x5036fc8f r04: 0x00000007 r05: 0x00000063 r06: 0x00005f70 r07: 0x2001f1d8 r08: 0xdeadbeef r09: 0xdeadbeef r10: 0xdeadbeef r11…

关于Java SSM框架的面试题

一、Spring面试题 1、Spring 在ssm中起什么作用? Spring:轻量级框架作用:Bean工厂,用来管理Bean的生命周期和框架集成。两大核心:1、IOC/DI(控制反转/依赖注入) :把dao依赖注入到service层,se…

28.vite

目录 1 一些概念 1.1 单页面应用程序SPA 1.2 vite 2 初始化vite项目 3 项目中的文件 1 一些概念 1.1 单页面应用程序SPA 单页面应用程序是只有一个页面的前端,切换页面通过前端路由来切换 特点如下 实现了前后端分离,后端仅出接口&#…

动态规划III (买股票-121、122、123、188、309)

CP121 买股票的最佳时机 题目描述: 给定一个数组 prices ,它的第 i 个元素 prices[i] 表示一支给定股票第 i 天的价格。你只能选择 某一天 买入这只股票,并选择在 未来的某一个不同的日子 卖出该股票。设计一个算法来计算你所能获取的最大利…

YOLOv5-7.0添加解耦头

Decoupled Head Decoupled Head是由YOLOX提出的用来替代YOLO Head,可以用来提升目标检测的精度。那么为什么解耦头可以提升检测效果呢? 在阅读YOLOX论文时,找到了两篇引用的论文,并加以阅读。 第一篇文献是Song等人在CVPR2020发表…

【59天|503.下一个更大元素II ● 42. 接雨水】

503.下一个更大元素II class Solution { public:vector<int> nextGreaterElements(vector<int>& nums) {stack<int> st;int n nums.size();vector<int> res (n, -1);for(int i0; i<2*n;i){while(!st.empty()&&nums[i%n]>nums[st.t…

随机的乐趣和游戏

1、猜数字游戏 #GuessingGame.py import random the_number random.randint(1, 10) print("计算机已经在1到10之间随机生成了一个数字&#xff0c;") guess int(input("请你猜猜是哪一个数字: ")) while guess ! the_number:if guess > the_number:p…

PHP设计模式21-工厂模式的讲解及应用

文章目录 前言基础知识简单工厂模式工厂方法模式抽象工厂模式 详解工厂模式普通的实现更加优雅的实现 总结 前言 本文已收录于PHP全栈系列专栏&#xff1a;PHP快速入门与实战 学会好设计模式&#xff0c;能够对我们的技术水平得到非常大的提升。同时也会让我们的代码写的非常…

淘宝详情页分发推荐算法总结:用户即时兴趣强化

转子&#xff1a;https://juejin.cn/post/6992169847207493639 商品详情页是手淘内流量最大的模块之一&#xff0c;它加载了数十亿级商品的详细信息&#xff0c;是用户整个决策过程必不可少的一环。这个区块不仅要承接用户对当前商品充分感知的诉求&#xff0c;同时也要能肩负起…

Spark大数据处理学习笔记1.5 掌握Scala内建控制结构

文章目录 一、学习目标二、条件表达式&#xff08;一&#xff09;语法格式&#xff08;二&#xff09;执行情况&#xff08;三&#xff09;案例演示任务1、根据输入值的不同进行判断任务2、编写Scala程序&#xff0c;判断奇偶性 三、块表达式&#xff08;一&#xff09;语法格式…

Redis入门 - Lua脚本

原文首更地址&#xff0c;阅读效果更佳&#xff01; Redis入门 - Lua脚本 | CoderMast编程桅杆https://www.codermast.com/database/redis/redis-scription.html Redis 脚本使用 Lua 解释器来执行脚本。 Redis 2.6 版本通过内嵌支持 Lua 环境。执行脚本的常用命令为 EVAL。 …

不要把异常当做业务逻辑,这性能可能你无法承受

一&#xff1a;背景 1. 讲故事 在项目中摸爬滚打几年&#xff0c;应该或多或少的见过有人把异常当做业务逻辑处理的情况(┬&#xff3f;┬)&#xff0c;比如说判断一个数字是否为整数,就想当然的用try catch包起来&#xff0c;再进行 int.Parse&#xff0c;如果抛异常就说明不…

Unity入门8——音效系统

一、音频文件参数面板 Force To Mono&#xff1a;多声道转单声道 Normalize&#xff1a;强制为单声道时&#xff0c;混合过程中被标准化 Load In Background&#xff1a;后台加载&#xff0c;不阻塞主线程&#xff0c;适合大音效 Ambisonic&#xff1a;立体混响声 非常适合 36…

JUC并发编程初学

什么是JUC进程和线程回顾Lock锁生产者和消费者8锁的线程集合类不安全CallableCountDownLatch、CyclicBarrier、Semaphore读写锁阻塞队列线程池四大函数式接口Stream流式计算分支合并异步回调JMMvolatile深入单例模式深入理解CAS原子引用可重入锁、公平锁非公平锁、自旋锁、死锁…

使用单元测试框架unittest进行有效测试

一、介绍 在软件开发中&#xff0c;单元测试是一种测试方法&#xff0c;它用于检查单个软件组件&#xff08;例如函数或方法&#xff09;的正确性。Python 提供了一个内置的单元测试库&#xff0c;名为 unittest&#xff0c;可以用来编写测试代码&#xff0c;然后运行测试&…

MyCat总结

目录 什么是mycat 核心概念 逻辑库 逻辑表 分片节点 数据库主机 用户 mycat原理 目录结构 配置文件 读写分离 搭建读写分离 搭建主从复制&#xff1a; 搭建读写分离&#xff1a; 分片技术 垂直拆分 实现分库&#xff1a; 水平拆分 实现分库&#xff1a; ER表 全局表 分…

大数据之路书摘:走近大数据——从阿里巴巴学习大数据系统体系架构

文章目录 1.数据采集层2.数据计算层3.数据服务层4.数据应用层 在大数据时代&#xff0c;人们比以往任何时候更能收集到更丰富的数据。但是如果不能对这些数据进行有序、有结构地分类组织和存储&#xff0c;如果不能有效利用并发掘它&#xff0c;继而产生价值&#xff0c;那么它…

shell索引数组变量-定义获取拼接删除

目录 介绍数组的定义演示 数组的获取数组的拼接演示&#xff1a; 数组的删除 介绍 Shell 支持数组&#xff08;Array&#xff09;, 数组是若干数据的集合&#xff0c;其中的每一份数据都称为数组的元素。 &#xff08; 注意Bash Shell 只支持一维数组&#xff0c;不支持多维数组…

基于Dlib的疲劳检测系统

需要源码的朋友可以私信我 基于Dlib的疲劳检测系统 1、设计背景及要求2、系统分析3、系统设计3.1功能结构图3.2基于EAR、MAR和HPE算法的疲劳检测3.2.1基于EAR算法的眨眼检测3.2.2基于MAR算法的哈欠检测3.3.3基于HPE算法的点头检测 4、系统实现与调试4.1初步实现4.2具体实现过程…

第五节 利用Ogre 2.3实现雨,雪,爆炸,飞机喷气尾焰等粒子效果

本节主要学习如何使用Ogre2.3加载粒子效果。为了学习方便&#xff0c;直接将官方粒子模块Sample_ParticleFX单独拿出来编译&#xff0c;学习如何实现粒子效果。 一. 前提须知 如果参考官方示例建议用最新版的Ogre 2.3.1。否则找不到有粒子效果的示例。不要用官网Ogre2.3 scri…