Spark 3.0 - 12.ML GBDT 梯度提升树理论与实战

news/2024/5/20 3:09:18/文章来源:https://blog.csdn.net/BIT_666/article/details/128334331

目录

一.引言

二.GBDT 理论

1.集成学习

2.分类 & 回归问题

3.梯度提升

4.GBDT 生成

三.GBDT 实战

1.数据准备

2.构建 GBDT Pipeline

3.预测与评估

四.总结


一.引言

关于决策树前面已经介绍了常规决策树与随机森林两种类型的知识,本文主要介绍梯度提升树 Gradient Boosting Decision Tree 即常说的 GBDT,其实一种使用决策树集成的流行分类和回归方法。梯度提升算法的思想类似于随机梯度下降。该算法中模型由若干个 F(x) 即基学习器构成,每个 F(x) 都拥有一个权重 Weight,初始化时各个权重相同,之后不断地将模型计算结果与真实结果进行比较,如果出错则增加错误样本的权重并基于新权重样本,让模型朝着损失减少最快的负梯度方法进行优化。其整体可以看做是 Bossting 方法,主要思想是每一次建立模型都是在之前建立模型损失函数的梯度下降方向,即"每次沿着当前位置最陡峭,损失下降最快的方向移动"。

二.GBDT 理论

决策树相对来说很直观形象,同学们也很好理解,但是到了梯度提升树,负梯度、最小化残差等概念的出现容易找不到方向,其次为什么0-1分类问题也有梯度等等疑问也随之而来,在使用 Spark 3.0 ML 介绍梯度提升树的使用之前,我们先熟悉一些 GBDT 的基础数学概念,做到理论实践相结合。

1.集成学习

上一文随机森林就是一种集成学习的方法,其一般都有一个基学习器 Tk,对于随机森林、梯度提升树而言,基学习器 Tk 就是我们常规的 DT 决策树。针对常见的分类与回归问题,我们的问题都可以转化为下述数学语言,构造一个函数 y = f(x),训练模型使得 f(x) 尽量与真实值相同:

y = f(x)

而实际运行中,我们的模型很难做到百发百中的精准预测,往往预测值与真实值之前存在一定偏差:

y = f(x) + ResidualError

这个 residual error 就是我们常说的残差,即 y - f(x) 的差值即真实值与预测值之前的差异。实践场景下我们一般对残差进行如下度量:

A.偏差 - 与真实值分布的偏差大小,体现模型的预测能力,越小模型预测越准

B.方差 - 与真实值分布的偏差均值方差,体现模型的预测稳定性,越小模型越稳定

集成学习中,基学习器一般为简单的算法模型、例如 LR、DT,因此其单一学习器的预测能力有限,从而通过集成学习将多个基学期的组合,针对残差进行拟合,进而降低模型的偏差与方差,提高模型整体的预测能力与稳定性,达到 "N个臭皮匠 Tk,能顶一个诸葛亮" 的思想。虽然 RF 和 GBDT 都是基于树的集成学习,但是二者亦有不同,前者 Tk 可以并行生成,是典型的 bagging,而 GBDT 的 Tk 是串行生成,类似于 Adaboost。

2.分类 & 回归问题

回归问题我们在 LR 中就遇到过,下面简单复述下,给定数据集 X:

D={(x_1,y_1), (x_2,y_2),\cdots ,(x_n,y_n)}

其中 x 为 K 维特征 (K >=1):

x_n=(x_{n1},x_{n2}, \cdots,x_{nk})

其中 y 为真实输出值,分类任务对应 0-1,回归任务对应预测值,我们的目的就是构建一个模型:

F(x_n)

去尽可能的逼近每一个真实值 y。

A.分类问题损失函数:(常见的指数损失函数)

loss(y,F(x))=-y_iF(x)+log(1+exp(F(x)))

B.回归问题损失函数:(常见的 MSE 均方误差损失函数)

loss(y,F(x))=E(y - F(x))^2

3.梯度提升

Gradient Boosting Decision Tree,简单分词可以得到两个主体,分别为 Gradient Boosting 与 Decision Tree,所以我们把这两个东西搞差不多,GBDT 我们也就搞差不多了,DT 可以参考 决策树原理与实战。此时基学习器 Tk 为 DT 决策树,假设当前:

F_{k}(x)=\sum_{i=1}^{k}T_i{x}

前 K 个基学习器的预测值为 Fk(x),可以看到 GBDT 是一种加法模型,它把所有基础模型的预测值累加起来作为最终的预测值。由于 GBDT 采用串行的生成方式生成新的基学习器,所以我们将上面的公式修改为递推形式:

F_k(x)=F_{k-1}(x) + T_k(x)

在训练第 K 个 T(x) 时,我们需要最小化如下目标函数:

J=\sum_{n=1}^{N}L(y_n,F_k(x_n)) = \sum_{n=1}^{N}L(y_n, F_{k-1}(x_n)+T_k(x))

此处我们需要使用梯度下降的方法,让目标函数的取值朝着最快的下降方向前进。以 MSE 交叉熵损失函数为例:

J=\sum_{n=1}^{N}L(y_n,F_k(x_n))=\sum_{n=1}^{N} \frac{1}{2}(y_n - F_k(x_n))^2

对 F(x) 求导可得:

\sum _{n=1}^{N}\frac{\partial L(y,F(x))}{\partial F(x)}=\sum _{n=1}^{N}\frac{\partial (\frac{1}{2}y_i-F_k(x_i))^2}{\partial F_k(x_i)}=\sum_{n=1}^{N}F_k(x_i) - y_i

后面得到的结果就是我们集成学习部分提到的负残差。由随机梯度下降更新公式可知,这里可以参考简易的 牛顿法参数更新,其中 α 为学习率:

F_k(x) = F_{k-1}(x) - \alpha \cdot \frac{\partial J}{\partial F_{k-1}(x)}

后面的求导结果为负残差,所以移项可得 (此处忽略 α):

F_{k}(x) - F_{k-1}(x) = T_k(x) = -1 \cdot \frac{\partial J}{\partial F} = \sum y_n - F_{k-1}(x_n)

所以可以看到每次新增的基学习器 Tk 都用于拟合之前所有 Ti 与当前真实值之间的残差时,导数梯度下降最快,从而模型拟合效果更好。

4.GBDT 生成

GBDT 串行生成,假设我们的第一个基学习器是:

T_1(x)

此时对应残差为,第二个学习器 T2 负责拟合 T1 与 y 之间的残差:

ResidualError_1 = T_2(x) = y - T_1(x)

根据残差拟合 T2 并串行增加到 T1 后面,得到最新的 GBDT 模型:

\hat{y} = F(x) = T_1(x) + T_2(x)

依次类推,不断在新函数的基础上求得残差,并通过残差拟合新的 Tk,直到达到我们预定的精度要求或者树要求,即代表 GBDT 模型生成完毕:

\hat{y} = F(x) = \sum_{i=1}^{K}T_k(x)

实际操作中,有时还会根据上一轮的误差修改新一轮样本 X 的权重,从而使得新增的 Tk 对于之前集成学习器分类错误的样本能够拥有更好的分类结果,从而提升整体集成学习器的预测能力。

三.GBDT 实战

1.数据准备

spark.read.format 对 LIBSVM 数据进行读取加载

LabelIndexer 对预测值进行重新编码映射

featureIndexer 对特征进行离散与连续的区分

randomSplit 将数据按比例分为训练、测试数据

    val spark = SparkSession.builder//创建spark会话.master("local").appName("GradientBoostedTreeClassifierExample")//设置名称.getOrCreate() //创建会话变量// 读取文件,装载数据到spark dataframe 格式中val data = spark.read.format("libsvm").load("/Users/xudong11/sparkV3/src/main/scala/org/example/RandomForest/sample_libsvm_data.txt")// 搜索标签,添加元数据到标签列// 对整个数据集包括索引的全部标签都要适应拟合val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data)// 自动识别分类特征,并对其进行索引// 设置MaxCategories以便大于4个不同值的特性被视为连续的。val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(4).fit(data)// 按照7:3的比例进行拆分数据,70%作为训练集,30%作为测试集val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

Tips:

样本为 libsvm 格式,特征维度 692,标签为二分类

2.构建 GBDT Pipeline

gbt 构造 GBDT 分类器

labelConverter 将上述标签转换的标签再映射回来

pipeline 将上述 Stage 拼接得到最终的 Estimator

pipeline.fit 训练模型,获取预测的 transformer

    // 建立一个决策树分类器,并设置MaxIter最大迭代次数为10val gbt = new GBTClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setMaxIter(10).setFeatureSubsetStrategy("auto")// 将索引标签转换回原始标签val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labelsArray(0))// 把索引和决策树链接(组合)到一个管道(工作流)之中val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, gbt, labelConverter))// 载入训练集数据正式训练模型val model = pipeline.fit(trainingData)

Tips:

这里 FeatureSubsetStrategy 是属性在每个节点中计算的数目,即用作在每个树节点进行分割的候选特征数量,该数字被指定为总特征数量的分数或函数。减少这个数字会加快训练速度,但是太低的话会影响性能,这里建议使用 auto 参数让 ML 内核自动决定每个节点的属性数。

3.预测与评估

model.transform 用上一步得到的 transformer 对测试集数据预测

evaluator 计算预测样本的 Accuracy

toDebugString 获取本次训练的 GBDT 树的简介

    // 使用测试集作预测val predictions = model.transform(testData)// 选择一些样例进行显示predictions.select("predictedLabel", "label", "features").show(5)// 计算测试误差val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy")val accuracy = evaluator.evaluate(predictions)println(s"Test Error = ${1.0 - accuracy}")val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel]println(s"Learned classification GBT model:\n ${gbtModel.toDebugString}")spark.stop()

Tips:

由于篇幅长度,这里我们只展示前 4 棵树,我们的模型共拥有 10 棵树,对应问题为 2 分类问题,全部特征为 692 维度。将全部预测值与权重加权求和,再经过 sigmoid 函数即可得到对应标签类型,如果是回归问题,则不需要 sigmoid 函数。这里 GBDT 处理二分类也借鉴了 LR,通过 sigmoid 函数将分类的非线性问题转化到 y = wx + b 的线性函数。

 

四.总结

GBDT 增加学习器意在让模型的损失函数持续下降,其中最好的方式就是让损失函数在梯度方向下降,此时优化速度最快。Boosting 算法是一种继承学习方法,每一轮训练样本都是固定的,改变的是每个样本的权重,根据错误率调整样本权重,错误率越大的样本权重越大。各个预测函数只能顺序生成,因为后一个模型需要用到上一个模型的结果。通过加法模型不断减小训练产生的残差,实现数据的分类与回归。在 Gradient Boosting 中,每个新基学习器的建立都是为了使之前的模型残差往梯度方向减少。啰嗦了这么多,下面我们简单总结一下:

A.训练阶段,GBDT 的基学习器只能串行生成,但是预测阶段可以通过并行计算提高效率

B.GBDT 分类问题支持 LogLoss,回归问题支持 MSE、MAE,SPARK ML 默认为 L2 MSE。

C.Iter 参数为迭代次数,每增加1都会新增一棵树,预测的准确性也随之增加

D.适量的增加树可以提高模型准确能力,但也会带来过拟合风险,可以添加 Reg 正则化参数

E.GBDT 可以自动筛选重要特征,可以与其他模型配合使用,例如最常见的 GBDT + LR

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_236293.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

微服务调用工具

微服务调用工具目录概述需求:设计思路实现思路分析1.A2.B3.C参考资料和推荐阅读Survive by day and develop by night. talk for import biz , show your perfect code,full busy,skip hardness,make a better result,wait for change,challenge Survive…

华为二面,原来是我对自动化测试的理解太肤浅了..

如何使用Python实现自动化测试 如果你入职一家新的公司,领导让你开展自动化测试,作为一个新人,你肯定会手忙脚乱,你会如何落地自动化测试呢?资深测试架构师沉醉将告诉你如何落地自动kan化测试,本次话题主要…

事业编招聘:南方科技大学附属实验学校2022年·面向应届毕业生招聘在编教师公告

南方科技大学是在中国高等教育改革发展背景下创建的一所高起点公办创新型大学,2022年2月14日,教育部等三部委公布第二轮“双一流”建设高校及建设学科名单,南方科技大学入选“双一流”建设高校名单。 南方科技大学附属实验学校,地…

大学生静态HTML网页源码 我的校园网页设计成品 学校班级网页制作模板 web课程设计 dreamweaver网页作业

🎉精彩专栏推荐 💭文末获取联系 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 💂 作者主页: 【主页——🚀获取更多优质源码】 🎓 web前端期末大作业: 【📚毕设项目精品实战案例 (10…

Jenkins持续集成项目搭建与实践——基于Python Selenium自动化测试(自由风格)

📌 博客主页: 程序员二黑 📌 专注于软件测试领域相关技术实践和思考,持续分享自动化软件测试开发干货知识! 📌 公号同名,欢迎加入我的测试交流群,我们一起交流学习! 目录…

艾美捷CpG ODN——ODN 1720 (TLRGRADE)说明书

艾美捷CpG ODN系列——ODN 1720 (TLRGRADE):具有硫代磷酸酯骨架的GpC寡脱氧核苷酸。 艾美捷CpG ODN 丨ODN 1720 (TLRGRADE)化学性质: 序列:5-tccatgagcttcctgatgct-3(小写字母表示硫代磷酸酯键)。 MW:638…

MySQL -2 指令

客户端SQL指令记录: -- 针对 数据库和针对数据表 (一)数据库 1. 查看当前所有数据库:show databases; 2. 创建数据库:create database 数据库名 DEFAULT CHARSET utf8 COLLATE utf8_general_ci; 3. 删除数据库&#…

微信公众号开发——实现用户微信网页授权流程

😊 作者: 一恍过去💖 主页: https://blog.csdn.net/zhuocailing3390🎊 社区: Java技术栈交流🎉 主题: 微信公众号开发——实现用户微信网页授权流程⏱️ 创作时间: …

哈希表及其与Java类集的关系

目录 1.哈希表的概念 2.哈希冲突 3.如何避免哈希冲突? 3.1哈希函数设计 3.2 负载因子的调节 4.解决哈希冲突 4.1闭散列 4.1.1线性探测 4.1.2二次探测 4.2开散列(哈希桶) 5.HashMap 6.HashSet 1.哈希表的概念 假设有一组数据,要让你去搜索其中的一个关键码,这种场…

嵌入式软件工程师技能树——Linux应用编程+网络编程+驱动开发+操作系统+计算机网络

文章目录Linux驱动开发1、Linux内核组成2、用户空间与内核的通讯方式有哪些?3、系统调用read/write流程4、内核态用户态的区别5、bootloader内核 根文件的关系6、BootLoader的作用7、BootLoader两个启动阶段1、汇编实现,完成依赖于CPU体系架构的设置&…

异常检测方法总结

在数据挖掘中,异常检测(英语:anomaly detection)对不匹配预期模式或数据集中其他项目的项目、事件或观测值的识别。 通常异常项目会转变成银行欺诈、结构缺陷、医疗问题、文本错误等类型的问题。异常也被称为离群值、新奇、噪声、…

[Android移动安全渗透基础教程] 易受攻击的移动应用程序

也许每个人出生的时候都以为这世界都是为他一个人而存在的,当他发现自己错的时候,他便开始长大 少走了弯路,也就错过了风景,无论如何,感谢经历 0x01 如何设置 GoatDroid (FourGoats) 1.1 简介(概述&#…

第十四届蓝桥杯集训——JavaC组第十三篇——for循环

第十四届蓝桥杯集训——JavaC组第十三篇——for循环 目录 第十四届蓝桥杯集训——JavaC组第十三篇——for循环 for循环(重点) 倒序迭代器 for循环死循环 for循环示例 暴力循环 等差数列求和公式 基础循环展开 循环控制语句 break结束 continue继续 for循环(重点) f…

MySQL主从复制太慢,怎么办?

本文分析了MySQL主从延迟的原因以及介绍了MTS方案。点击上方“后端开发技术”,选择“设为星标” ,优质资源及时送达mysql主从同步延迟原因导致备库延迟的原因主要有如下几种:通常备库所在机器的性能要比主库所在的机器性能差,执行…

nodejs092学生考勤请假管理系统vue

目 录 摘 要 I ABSTRACT II 目 录 II 第1章 绪论 1 1.1背景及意义 1 1.2 国内外研究概况 1 1.3 研究的内容 1 第2章 相关技术 3 2.1 nodejs简介 4 2.3 B/S结构 4 2.4 MySQL数据库 4 前端技术:nodejsvueelementui 前端:HTML5,CSS3、JavaScript、VUE 系统…

ChatGpt详细注册流程

ChatGpt详细注册流程ChatGpt的网址:直接点击我 点击链接后向下滑动看到TRY CHATGPT如下图所示: 点击TRY CHATGPT后会跳转如下图界面: 点击Log in(登录)如下图: 因为首次登录你肯定是没有账号的所以需要先点击红框框出的Sign up…

七个步骤覆盖 API 接口测试

接口测试作为最常用的集成测试方法的一部分,通过直接调用被测试的接口来确定系统在功能性、可靠性、安全性和性能方面是否能达到预期,有些情况是功能测试无法覆盖的,所以接口测试是非常必要的。首先需要对接口测试的基本信息做一些了解&#…

Java+Swing实现的五子棋游戏

JavaSwing实现的五子棋游戏一、系统介绍二、功能展示1.游戏展示三、系统实现1.ChessFrame .java四、其它1.其他系统实现2.获取源码一、系统介绍 五子棋游戏实现人机对战、人人对战两个模式。 二、功能展示 1.游戏展示 三、系统实现 1.ChessFrame .java package five;impor…

JDK的使用——Java开发第一步

JDK的使用——Java开发第一步 1 什么是JDK JDK是 Java 语言的软件开发工具包,是整个java开发的核心,使用Java开发第一步就是要在计算机上安装JDK。 JDK主要包含三个部分: 1 JAVA开发工具(jdk\bin) 2 基础开发库(jdk\jre\lib) 3 基础开发库…

java TCP发送数据

TCP是一种可靠的网络协议 他在通信的两端都建立了Socke对象 从而形成了两端的虚拟链路 一旦建立了虚拟链路 两端就可以通过链路通信 TCP会将通信两端分为 客户端和服务端 客户端通过Socke实现 服务端通过ServerSocke实现 那么 我们就来实现一下发送数据的方法 我们创建一个测…