Spark 3.0 - 8.ML Pipeline 之决策树原理与实战

news/2024/3/29 21:23:30/文章来源:https://blog.csdn.net/BIT_666/article/details/128045638

目录

一.引言

二.决策树基础-信息熵

三.决策树的算法基础 - ID3 算法

四.ML 中决策树的构建

1.信息增益计算

2.连续属性划分

五.ML 决策树实战

1.Libsvm 数据与加载

2.StringIndexer

3.VectorIndexer

4.构建决策树与 Pipeline

5.测试与评估

6.获取决策树

六.总结


一.引言

决策树是在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评估项目风险,判断其可行性的决策分析方法,是直观运用概率分析的一种图解法。由于其决策分支画成图像很像一颗树的枝干,故称之为决策树。

二.决策树基础-信息熵

信息熵指的是对事件中不确定的信息度量。一个事件或属性中,其信息熵越大代表其不确定因素越大,对数据分析的计算更有益。因为熵其实用来描述一个物体或事件内部的混乱程度。在一个事件中,需要计算各个属性的不同信息熵。如果事件中包含 n 个属性,且个属性事件彼此独立、无相关性,此时可以将信息熵定义为单个属性的对数平均值:

 举个栗子 🌰:

上述描述了 14 天中不同的天气属性以及是否出门打网球:

E(tennis) = -\sum p_ilog_p_i = -(\frac{5}{14}log_2\frac{5}{14})- -(\frac{9}{14}log_2\frac{9}{14})\approx 0.94

即是否出门打网球这个属性的信息熵 Entropy 为 0.918,同理以 Humidity 湿度为属性计算 Tennis 的信息熵:

其中 High 情况下 Tennis Yes 3 次,No 4 次,Normal 情况下 Tennis Yes 6 次,No 1 次:

E (Hmd) = \frac{7}{14} \cdot (-\frac{3}{7}*1og\frac{3}{7}-\frac{4}{7}log\frac{4}{7}) + \frac{7}{14} \cdot (-\frac{1}{7}*1og\frac{1}{7}-\frac{6}{7}log\frac{6}{7}) \approx 0.788

通过计算对数平均值可以获得条件概率下不同属性的信息熵。使用下述方法可以轻松计算一个属性的信息熵:

  def calcEntropy(pArr: Array[Double]): Double = {var sum = 0DpArr.foreach(p => {sum += -1.0 * p * log2(p)})sum}def log2(p: Double): Double = {Math.log(p) / Math.log(2) // Math.log的底为e}

三.决策树的算法基础 - ID3 算法

上面介绍了信息熵,下面基于这个概念介绍如何尽可能的建立一颗最短的、最小的且最有效的决策树。ID3 算法是基于信息熵的一种经典决策树构建方法,其以信息熵的下降速度为选取测试属性的标准,即在每个节点选取还尚未用来划分的、具有最高信息增益的属性作为划分标准并不断重复这个过程,直到生成的决策树可以完美分类训练样例。其核心为信息增益:

信息增益,指的是一个时间前后发生的不同信息之间的差值,即在决策树生成过程前后不同的信息熵差值,公式可以表达为:

Gain(P_1,P_2) = E(P_1) - E(P_2)

以上面的 Tessins 与 Humidity 为例:

然后,重复上表中每个属性的信息增益计算,并选择信息增益最高的属性作为决策树中的第一个分割点。在这种情况下,outlook 产生了最高的信息增益。然后,对每个子树重复该过程。 

四.ML 中决策树的构建

1.信息增益计算

Spark ML 实现了支持使用连续和离散特征的二元和多类分类以及回归的决策树。该实现按行对数据进行分区,允许使用数百万甚至数十亿个实例进行分布式训练。决策树构建采用递归二分法方式。不断从根节点进行生成,直到决策树的需求信息增益满足一定条件为止:

Gain(P_1,P_2) = E(P_1) - \frac{P_{left}}{P_1}E(P_{left}) - \frac{P_{right}}{P_1}E(P_{right})

Left、right 为待计算属性,每增加一个分类节点,待计算属性便减少一个。

2.连续属性划分

上面介绍的 Tennis,其属性均为离散属性,实际应用中会有大量的连续性特征,解决办法就是在计算时根据需要将数据划分为若干个部分进行处理。这些被划分的若干部分在 ML 中称为 bin,即分箱的意思。每个作为分割点的节点被称为 split。决策树是一种贪婪算法,其通过从每一组可能的分割中选择最佳分割来贪婪的选择每个分割。例如一个连续特征为 {1,2,3,4,5},实际中采用二分法,此时 split 为3,划分得到 {1,2,3},{4,5} 两个 bin。

五.ML 决策树实战

1.Libsvm 数据与加载

实战前首先熟悉一种数据格式-Libsvm:

数据的第一列为标签,以上面 Tennis 为例,1 代表出去,0 代表不出去,后面的 key 代表属性的序号,value 代表该属性的具体值。下面读取实战的数据:

    val spark = SparkSession.builder                //创建spark会话.master("local")        //设置本地模式.appName("DecisionTreeClassificationExample")   //设置名称.getOrCreate()          //创建会话变量spark.sparkContext.setLogLevel("error")// 读取文件,装载数据到spark dataframe 格式中val data = spark.read.format("libsvm").load("./sample_libsvm_data.txt")

通过 .format 指定 livsvm 格式即可读取对应格式文件,解析后获得一个两列的 DataFrame,一列为 label,另一列为 features。

Tips:

细心的同学可能会发现,原始的数据为 128:x、129:y、130:z,为什么 features 里变成了 127、128、129 所有的索引的减了1,这是因为 Libsvm 数据格式从 1 开始,而我们的 features 的 Vector 内索引从 0 开始,所以需要将 libsvm 数据中的 key 都减去1,而 value 则不动。

2.StringIndexer

    val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data)

字符索引器,通过遍历标签添加元数据到标签列,实现【标签 -> 序号】的映射,如果是数值型 Label,会先将 Label 转化为字符串,然后再进行索引化,如果 label 已经是字符串,则直接进行索引化。其中标签索引的顺序按照标签出现的频率来排序,出现最多的 Label 索引即为0,依次逆序排列。

上面说的可能比较绕,简单解释下该函数的意义就是将不规则的 label 处理为有序的数字序号,例如原始标签有 A、C、E 三种类型,通过 StringIndexer 会变成 0、1、2,而对应的映射关系取决于 A、C、E 的出现次数,次数最多的索引为 0。E 出现最多,所以 E 的索引为0,以此类推。

    +---+--------+-------------+| id|category|categoryIndex|+---+--------+-------------+| 0| A| 1.0|| 1| C| 2.0|| 2| E| 0.0|| 3| E| 0.0|| 4| A| 1.0|| 5| E| 0.0|+---+--------+-------------+

上面的实战数据我们转换看一下:

 看到 label 0.0 变为 1.0, 1.0 变为 0.0,我们再用 Spark Sql 看下标签数据的分布:

     val labelIndexDF = labelIndexer.transform(data)labelIndexDF.show(5)labelIndexDF.createOrReplaceTempView("LabelIndex")spark.sql("select label,count(*) from LabelIndex group by label").collect().foreach(println)

没毛病,label=1.0 的标签多,所以 1.0 被映射为 0.0。 

[0.0,43]
[1.0,57]

3.VectorIndexer

    // 自动识别分类特征,并对其进行索引val featureIndexer = new VectorIndexer().setInputCol("features") // 设置输入输出参数.setOutputCol("indexedFeatures").setMaxCategories(5) // 具有多于5个不同值的特性被视为连续特征.fit(data)

该方法主要用于自动识别离散与分类特征,提高决策树 ML 方法的分类效果。其中 MaxCategories 参数设置一个数值,如果某个特征的取值类型多于该参数,则该参数会被认定为连续特征,不作处理,反之会被认定为离散特征,并被重新编号为 0-K (K < MaxCategories)。


+-------------------------+-------------------------+
|features                 |indexedFeatures          |
+-------------------------+-------------------------+
|(3,[0,1,2],[2.0,5.0,7.0])|(3,[0,1,2],[2.0,1.0,1.0])|
|(3,[0,1,2],[3.0,5.0,9.0])|(3,[0,1,2],[3.0,1.0,2.0])|
|(3,[0,1,2],[4.0,7.0,9.0])|(3,[0,1,2],[4.0,3.0,2.0])|
|(3,[0,1,2],[2.0,4.0,9.0])|(3,[0,1,2],[2.0,0.0,2.0])|
|(3,[0,1,2],[9.0,5.0,7.0])|(3,[0,1,2],[9.0,1.0,1.0])|
|(3,[0,1,2],[2.0,5.0,9.0])|(3,[0,1,2],[2.0,1.0,2.0])|
|(3,[0,1,2],[3.0,4.0,9.0])|(3,[0,1,2],[3.0,0.0,2.0])|
|(3,[0,1,2],[8.0,4.0,9.0])|(3,[0,1,2],[8.0,0.0,2.0])|
|(3,[0,1,2],[3.0,6.0,2.0])|(3,[0,1,2],[3.0,2.0,0.0])|
|(3,[0,1,2],[5.0,9.0,2.0])|(3,[0,1,2],[5.0,4.0,0.0])|
+-------------------------+-------------------------+

上面示例中共有三个特征:

0 - [2,3,4,5,8,9] - 类别数为6,大于 MaxCategories,不执行划分

1 - [4,5,6,7,9],小于 MaxCategories,执行划分  [4,5,6,7,9] -> [0,1,2,3,4]

2 - [2,7,9],小于 MaxCategories,执行划分 [2,7,9] -> [0,1,2]

该方法主要用于离散特征与连续特征的区分,对于连续型特征,决策树的划分点划分多为 Feat >= threshold 的形式,而离散型的特征则多为 Feat in {value}。

4.构建决策树与 Pipeline

    // 按照7:3的比例进行拆分数据,70%作为训练集,30%作为测试集。val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))// 建立一个决策树分类器val dt = new DecisionTreeClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setMaxDepth(2)// 将索引标签转换回原始标签val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labelsArray(0))// 把索引和决策树链接(组合)到一个管道(工作流)之中val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))// 载入训练集数据正式训练模型val model = pipeline.fit(trainingData)

- DecisionTreeClassifier 主要属性有

Impuriry (String) : 计算信息增益的方式

maxDepth(Int) :  树的深度

maxBins(Int) : 能够分裂的数据集合数量

可以通过 dt.extractParamMap 方法获取当前模型的自定义参数与默认参数。

- IndexToString

该方法主要用于将转换后的 label 再映射回去,例如前面将 1->0 0->1 再重新反向映射回去:

    // 按顺序来,相当于映射 0->1 1->0println("=========Label Index=========")println(labelIndexer.labelsArray(0).mkString(","))

 labelsArray 为如下形式,将其 zipWithIndex 再反转即可实现 label 的反向映射。

=========Label Index=========
1.0,0.0

- pipeline

构建 Pipeline 之后即可实现 PipelineModel.fit 训练训练数据。

5.测试与评估

    // 使用测试集作预测val predictions = model.transform(testData)// 选择一些样例进行显示predictions.select("predictedLabel", "label", "features").show(5)// 计算测试误差val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy")val accuracy = evaluator.evaluate(predictions)println(s"Test Error = ${(1.0 - accuracy)}")

 使用 PipelineModel.transform 进行测试,使用 label 与 predict 进行 accuracy 的指标评估:

 

6.获取决策树

    val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel]println(s"Learned classification tree model:\n ${treeModel.toDebugString}")

使用 AsInstanceOf 将 Stage(2) 转化为 Dt 并调用 toDebugString 获取树的结构:

通过两层 If-Else 嵌套形式展示了一棵树,这样再来一个样本,我们可以轻易地判断其所属类别。

 

六.总结

本文根据样例数据进行了 Spark ML 决策树的 Demo 讲解,其中涉及到很多特征处理与转化的组件,可以通过样例进行熟悉,后续也会基于真实数据进行随机森林与梯度提升树的案例,加深对树的理解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_39034.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于PHP+MySQL企业网络推广平台系统的设计与实现

企业网络推广平台系统具有很强的信息指导性特征,采用PHP开发企业网络推广平台系统 给web带来了全新的动态效果,具有更加灵活和方便的交互性。在Internet中实现数据检索越来越容易,可以及时、全面地收集、存储大量的企业资源信息以及进行发布、浏览、搜索相关的信息。让企业、个…

C++ Reference: Standard C++ Library reference: Containers: list: list: cend

C官网参考链接&#xff1a;https://cplusplus.com/reference/list/list/cend/ 公有成员函数 <list> std::list::cend const_iterator cend() const noexcept; 返回结束的常量迭代器 返回一个指向容器结束后元素的const_iterator。 const_iterator是指向const内容的迭代…

Spring Boot FailureAnalyzer 应用场景

Spring Boot 自定义FailureAnalyzer 今天在学习Spring Boot 源码的过程中&#xff0c;在spring.factories 文件中无意中发现了FailureAnalyzer 这个接口。由于之前没有接触过&#xff0c;今天来学习一下 FailureAnalyzer 接口的作用。 在学习FailureAnalyzer之前, 我们先看以…

TMA三均线股票期货高频交易策略的R语言实现

趋势交易策略是至今应用最广泛以及最重要的投资策略之一&#xff0c;它的研究手段种类繁多&#xff0c;所运用的分析工具也纷繁复杂&#xff0c;其特长在于捕捉市场运动的大方向。股指期货市场瞬息万变&#xff0c;结合趋势分析方法&#xff0c;量化投资策略能够得到更有效的应…

Discourse 的左侧边栏可以修改吗

在默认的 Discourse 配置中&#xff0c;我们左侧的边栏可以根据自己的要求进行修改吗&#xff1f; 解决办法 针对自己登录的用户&#xff0c;你是可以自己调整左侧边栏的配置。 单击右上角你的个人头像&#xff0c;然后选择属性。 在切换的界面中&#xff0c;选择属性。 在出…

校园论坛(Java)——环境配置篇

校园论坛&#xff08;Java&#xff09;——环境配置篇 文章目录校园论坛&#xff08;Java&#xff09;——环境配置篇1、写在前面2、新建Maven项目2.1 引入相关依赖2.2 配置Tomcat环境3、项目发布测试4、项目代码5、参考资料1、写在前面 Windows版本&#xff1a;Windows10JDK版…

Python数据库编程之关系数据库API规范

Python关系数据库API规范 对于关系数据库的访问&#xff0c;Python社区已经制定出一个标准&#xff0c;称为Python Database API Specification。Mysql&#xff0c;Oracal等特定数据库模块遵从这一规范&#xff0c;而且可以添加更多特性。 高级数据库API定义了一组用于连接数…

YOLO V3 详解

YOLO V3 论文链接&#xff1a;YOLOv3: An Incremental Improvement 主要改进 Anchor: 9个大小的anchor&#xff0c;每个尺度分配3个anchor。Backbone改为Darknet-53, 引入了残差模块。引入了FPN&#xff0c;可以进行多个尺度的训练&#xff0c;同时对于小目标的检测有了一定…

R语言生存分析可视化分析

生存分析指的是一系列用来探究所感兴趣的事件的发生的时间的统计方法。 生存分析被用于各种领域&#xff0c;例如&#xff1a; 癌症研究为患者生存时间分析&#xff0c; “事件历史分析”的社会学 在工程的“故障时间分析”。 在癌症研究中&#xff0c;典型的研究问题如下…

Linux redict 输入输出重定向 详细使用方法 文件描述符

Linux redict 重定向 Linux 重定向 在 Linux 系统中&#xff0c;我们需要输入和输出让系统与外部进行交互&#xff0c;比如在我们使用鼠标、键盘等输入设备时其实就是通过输入的方式让数据进行系统中。而系统输出一般就会打印在显示器上、刻录光盘等等。而我们要讲的重定向也…

(二)DepthAI-python相关接口:OAK Pipeline

消息快播&#xff1a;OpenCV众筹了一款ROS2机器人rae&#xff0c;开源、功能强、上手简单。来瞅瞅~ 编辑&#xff1a;OAK中国 首发&#xff1a;oakchina.cn 喜欢的话&#xff0c;请多多&#x1f44d;⭐️✍ 内容可能会不定期更新&#xff0c;官网内容都是最新的&#xff0c;请查…

Meta-learning

基本理解 meta learning翻译为元学习&#xff0c;也可以被认为为learn to learn 元学习与传统机器学习的不同在哪里&#xff1f; 元学习与传统机器学习&#xff0c; 这里举个通俗的例子&#xff0c;拿来给大家分享&#xff1f; 把训练算法类比成学生在学校学习&#xff0c;传…

【华为上机真题 2022】字符串分隔

&#x1f388; 作者&#xff1a;Linux猿 &#x1f388; 简介&#xff1a;CSDN博客专家&#x1f3c6;&#xff0c;华为云享专家&#x1f3c6;&#xff0c;Linux、C/C、云计算、物联网、面试、刷题、算法尽管咨询我&#xff0c;关注我&#xff0c;有问题私聊&#xff01; &…

[附源码]计算机毕业设计springboot环境保护宣传网站

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

后端存储实战课总结(上)

创建和更新订单 表设计 最少应该有以下几张表&#xff1a; 订单主表&#xff1a;保存订单基本信息订单商品表&#xff1a;保存订单中的商品信息订单支付表&#xff1a;保存订单支付和退款信息订单优惠表&#xff1a;保存订单的优惠信息 订单主表和字表是一对多关系&#xf…

1.1 统计学习方法的定义与分类

统计学习方法的定义与分类统计学习的概念统计学习的定义统计学习运用到的领域统计学习的步骤统计学习的分类统计学习的概念 统计学习的定义 统计学习 (Statistical Machine Learning) 是关于计算机基于数据构建概率统计模型并运用模型对数据进行预测与分析的一门学科。 以计…

第五站:操作符(终幕)(一些经典的题目)

目录 一、分析下面的代码 二、统计二进制中1的个数 解一&#xff1a;&#xff08;求出每一个二进制位&#xff0c;来统计1的个数&#xff09; 解二&#xff1a;&#xff08;利用左我们移或右移操作符和按位与&#xff09; 解三&#xff1a;&#xff08;效率最高的解法&…

【iOS】—— GET和POST以及AFNetworking框架

GET和POST以及AFNetworking框架 文章目录GET和POST以及AFNetworking框架GET和POSTGET和POST区别GETGET请求步骤GET请求代码POSTPOST请求步骤POST请求代码AFNetworking简介添加头文件GETGET方法GET方法参数GET方法代码样例POSTPOST方法第一种&#xff1a;第二种&#xff1a;先来…

C++:STL之Vector实现

vector各函数 #include<iostream> #include<vector> using namespace std;namespace lz {//模拟实现vectortemplate<class T>class vector{public:typedef T* iterator;typedef const T* const_iterator;//默认成员函数vector(); …