机器学习笔记 - sklearn决策树(kaggle 实战 Titanic 入门)

news/2024/5/8 11:57:50/文章来源:https://blog.csdn.net/haobowen/article/details/127271741

Kaggle - Titanic

前言

这是 Kaggle 上非常典型的一道入门题,可以用很多机器学习或者深度学习甚至是一些“奇淫技巧”的方法来解决。因为我是一个初学者,所以我希望在尽可能提高正确率的情况下,用更简单的方法。如果这也能帮助到你,那将是我莫大的荣幸。

决策树(分类)

简介

在机器学习中,决策树是一个预测模型,它用一种树状结构表示对象属性和对象类别之间的一种映射。决策树中的非叶子节点表示对象属性的判断条件,其分支表示符合节点条件的所有对象,决策树的叶子节点表示对象所属的类别。

决策树可以转化为一系列规则(Rule),从而构成一个规律集(Rule Set),这样的规则很容易被人们理解和运用。

这里不证明数学相关的知识,只能简单说明原理和应用。如果需要找严格证明就请找其他更好的教程,对此我很抱歉。

一、决策树的构造过程

决策树的创建从根节点开始,需要确定一个属性,根据不同记录在该属性上的取值,对所有记录进行划分。接下来,对每个分支重复这个过程,即对每个分支,选择另外一个未参与树的创建的属性,继续对样本进行划分,一直到某个分支上的样本都属于同一类(或者隶属该路径的样本大部分属于同一类)。

属性的选择也称为特征选择。特征选择的目的,是使用分类后的数据集比较纯,即数据(子)集里主要是某个类别的样本。因为决策树的目标就是要把数据集按对应的类别标签进行分类。理想的情况是,通过特征的选择,能够把不同类别的数据集贴上对应的类别标签。为了衡量一个数据集的纯度,需要引入数据纯度函数。

其中一个应用广泛的度量函数,是信息增益。信息熵表示的是不确定性。数据非均匀分布时,不确定程度最大。当选择某个特征对数据集进行分类时,分类后的数据集的信息熵会比分类前的小,其差值表示为信息增益(信息的减少值)。信息增益可以衡量某个特征对分类结果的影响大小。

那我们应该如何选择参数呢?

通常就使用基尼系数,数据维度很大,噪音很大时使用基尼系数。维度低,数据比较清晰的时候,信息熵和基尼系数没区别,当决策树的拟合程度不够的时候,使用信息熵。两个都试试,不好就换另外一个

二、决策树的剪枝

在决策树建立的过程中,很容易出现过拟合的现象。模型是在训练样本上训练出来的。过拟合,是指模型非常逼近训练样本,在训练样本上预测的正确率很高,但是对测试样本的预测正确率不高,效果并不好,也就是模型的泛化能力差。当把模型应用到新数据上的时候,其预测效果不好。

决策树的过拟合现象,可以通过剪枝进行一定的修复。决策树算法一般都需要经过两个阶段来进行构造,即树的生长阶段和剪枝阶段。剪枝分为预先剪枝和后剪枝两种情况。预先剪枝,指的是在决策树构造过程中,使用一定条件加以限制,在产生完全拟合的决策树之前就停止生长。预先剪枝的判断方法有很多,比如信息增益小于一定阈值的时候,通过剪枝使决策树停止生长。

后剪枝是在决策树构造完成之后,也就是所有的训练样本都可以用决策树划分到不同子类后,按照自底向上的方向修建决策树。后剪枝有两种方式:一种是用新的叶子节点替换子树,该节点的预测类由子树数据集中的多数类决定;另一种是用子树中最常使用的分支代替子树。后剪枝一般能够产生更好的效果,因为预先剪枝可能过早终止决策树构造过程。但是需要注意的是,后剪枝在子树被剪掉后,原来构造决策树的一部分计算就浪费了。

至此,我们已经了解什么是决策树了,可以进一步学习怎么去应用了。

sklearn建模的基本流程

1. 实例化数据,建立评估模型

clf = tree.DecisionTreeClassifier()

2. 通过模型接口训练模型

clf = clf.fit(X_train,y_train)

3.通过模型接口获得训练结果

result = clf.score(X_test,y_test)  

那么具体到实际操作中,应该怎么做呢?

1. 导入需要的算法库和模块

2. 分析和处理数据

3. 分训练集和测试集

4. 分析决策树的结果,并优化模型或导出数据

决策树的重要参数

1. Criterion

Criterion这个参数正是用来决定不纯度的计算方法的。sklearn提供了两种选择:

1) 输入”entropy“,使用信息熵

2) 输入”gini“,使用基尼系数

2. random_state

random_state用来设置分枝中的随机模式的参数,默认None,在高维度时随机性会表现更明显,低维度的数据(比如鸢尾花数据集),随机性几乎不会显现。输入任意整数,会一直长出同一棵树,让模型稳定下来。最优的节点不一定可以得到最优树,于是sklearn选择建不同的树,然后从中取最好的,在每次分枝时不使用全部特征,随机选取一部分特征,从中选取不纯度相关指标最优的作为分枝用的节点,这样每次生成的树也就不同了。

3. splitter

splitter也是用来控制决策树中的随机选项的,有两种输入值,输入”best",决策树在分枝时虽然随机,但是还是会优先选择更重要的特征进行分枝(重要性可以通过属性feature_importances_查看),输入“random",决策树在分枝时会更加随机,树会因为含有更多的不必要信息而更深更大,并因这些不必要信息而降低对训练集的拟合。这也是防止过拟合的一种方式。当你预测到你的模型会过拟合,用这两个参数来帮助你降低树建成之后过拟合的可能性。当然,树一旦建成,我们依然是使用剪枝参数来防止过拟合。

4. max_depth

限制树的最大深度,超过设定深度的树枝全部剪掉。

这是用得最广泛的剪枝参数,在高维度低样本量时非常有效。决策树多生长一层,对样本量的需求会增加一倍,所以限制树深度能够有效地限制过拟合。在集成算法中也非常实用。实际使用时,建议从=3开始尝试,看看拟合的效果再决定是否增加设定深度。

5. min_samples_leaf

min_samples_leaf限定,一个节点在分枝后的每个子节点都必须包含至少min_samples_leaf个训练样本,否则分枝就不会发生,或者,分枝会朝着满足每个子节点都包含min_samples_leaf个样本的方向去发生。

一般搭配max_depth使用,在回归树中有神奇的效果,可以让模型变得更加平滑。这个参数的数量设置得太小会引起过拟合,设置得太大就会阻止模型学习数据。一般来说,建议从=5开始使用。如果叶节点中含有的样本量变化很大,建议输入浮点数作为样本量的百分比来使用。同时,这个参数可以保证每个叶子的最小尺寸,可以在回归问题中避免低方差,过拟合的叶子节点出现。对于类别不多的分类问题,=1通常就是最佳选择。

6. min_samples_split

min_samples_split限定,一个节点必须要包含至少min_samples_split个训练样本,这个节点才允许被分枝,否则分枝就不会发生。

7. max_features

max_features限制分枝时考虑的特征个数,超过限制个数的特征都会被舍弃。和max_depth异曲同工,max_features是用来限制高维度数据的过拟合的剪枝参数,但其方法比较暴力,是直接限制可以使用的特征数量而强行使决策树停下的参数,在不知道决策树中的各个特征的重要性的情况下,强行设定这个参数可能会导致模型学习不足。如果希望通过降维的方式防止过拟合,建议使用PCA,ICA或者特征选择模块中的降维算法。

8. min_impurity_decrease

min_impurity_decrease限制信息增益的大小,信息增益小于设定数值的分枝不会发生。

9. class_weight

完成样本标签平衡的参数。样本不平衡是指在一组数据集中,标签的一类天生占有很大的比例。比如说,在银行要判断“一个办了信用卡的人是否会违约”,就是是vs否(1%:99%)的比例。这种分类状况下,即便模型什么也不做,全把结果预测成“否”,正确率也能有99%。因此我们要使用class_weight参数对样本标签进行一定的均衡,给少量的标签更多的权重,让模型更偏向少数类,向捕获少数类的方向建模。该参数默认None,此模式表示自动给与数据集中的所有标签相同的权重。

10. min_weight_fraction_leaf

有了权重之后,样本量就不再是单纯地记录数目,而是受输入的权重影响了,因此这时候剪枝,就需要搭配min_ weight_fraction_leaf这个基于权重的剪枝参数来使用。另请注意,基于权重的剪枝参数(例如min_weight_ fraction_leaf)将比不知道样本权重的标准(比如min_samples_leaf)更少偏向主导类。如果样本是加权的,则使用基于权重的预修剪标准来更容易优化树结构,这确保叶节点至少包含样本权重的总和的一小部分。

总结一下

七个参数:Criterion,两个随机性相关的参数(random_state,splitter),四个剪枝参数(max_depth, ,min_sample_leaf,max_feature,min_impurity_decrease)
一个属性:feature_importances_
四个接口:fit,score,apply,predict

本文引用文献(排名不分先后):

  • 菜菜的sklearn
  • 数据科学导论(中国人民大学出版社出版)

Kaggle Titanic

Kaggle Titanic

题目要求

通过给我们乘客的信息,判断乘客的存活情况

代码示例和解释

from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
#导入所需要的库
train = pd.read_csv("../input/titanic/train.csv")
test = pd.read_csv("../input/titanic/test.csv")
train.info()
#读入训练集和测试集,并观察训练集的数据

 此时应该能观察到有11个特征

"PassengerId","Survived","Pclass","Sex","Age","Sibsp","Parch","Ticket","Fare","Cabin","Embarked"

train.drop(["Name", "Cabin", "Ticket"], inplace = True, axis = 1)
train.head()
#删除没用的数据("Name","Ticket"),由于"Cabin"缺失的数据太多无法补全且不重要,所以直接删去
#此前我们应该发现了数据集中的年龄是不完全的,所以需要我们自己补全(平均值即可)
train["Age"] = train["Age"].fillna(train["Age"].mean())
train = train.dropna() #并且删除多余的数据
train.info()
#再观测一次数据集,看是否完整
train.loc[:,"Sex"] = (train["Sex"] == "male").astype("int")
#我们需要把 Sex 特征从字符串类型变成整数类型,方便之后的调用
corrmat = train.corr()
plt.subplots(figsize = (12, 9))
sns.heatmap(corrmat, vmax = 0.9, square = True)
#我们用热力图观测各个特征之间的相关性

 我们不难发现,在热力图中,Survived 属性和 Fare、Sex、Age 的相关性很高,所以我们将他们作为决策树选择的特征。

features = ["Sex", "Age", "Fare", "PassengerId"]
Xtrain, Xtest, Ytrain, Ytest = train_test_split(train[features], train["Survived"], test_size = 0.3)
Xtrain.shape
#我们分类出 X 的训练和测试集, Y 的测试和训练集。并观察 X 训练集的格式
clf = DecisionTreeClassifier(criterion = "entropy",random_state = 25,max_depth = 8)
clf = clf.fit(Xtrain, Ytrain)
score = clf.score(Xtest, Ytest)
score
#建立决策树模型,并拟合训练数据,得到测试结果
test.drop(["Name", "Cabin", "Ticket", "Embarked", "Pclass", "Parch", "SibSp"], inplace = True, axis = 1)
test["Age"] = test["Age"].fillna(test["Age"].mean())
test["Fare"] = test["Fare"].fillna(test["Fare"].mean())
test.loc[:,"Sex"] = (test["Sex"] == "male").astype("int")
test
#测试集进行相同的处理
pre = clf.predict(test)
output = pd.DataFrame({"PassengerId": test.PassengerId, "Survived": pre})
output.to_csv("my_submission.csv", index = False)
#得到测试集的结果,并导入 csv 文件中,方便之后提交答案

然后你就可以得到 0.64354 的得分

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_24072.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据库作业一

MySQL数据库 MySQL官方提供了两个不同的版本: 1、社区版 (MySQL Commimity Server)免费,MySQL不提供任何技术支持(本文操作选用社区版) 2、商业版(MySQL Enterprise Edition)收费&a…

[Microsoft] 通过Microsoft Spotlight 中国站云技能挑战获取微软免费考试券

这是一篇关于微软Spotlight 推出学习活动的同时,如何获得免费考试券的方法,如果该文章在未来时间已经失效,那么建议你关注一下这个博客,有Azure China Cloud最新的消息会进行更新通知。 文章目录1. 所需准备注册账号2. 参加 Micro…

二十八、Hive集成HBase分析搜索引擎用户行为数据

我们已经知道,HBase数据库没有类SQL的查询方式,因此在实际的业务中操作和计算数据非常不方便。而Hive支持标准的SQL语法(HiveQL),若将Hive与HBase集成,则可以通过HiveQL直接对HBase的表进行读写操作,让HBase支持JOIN、GROUP等SQL查询语法,完成复杂的数据分析。甚至可以…

【电源设计】13开关电源仿真与应用

0.前言 本章主要是大概了解一下开关电源仿真与应用,开关电源仿真设计全过程:包括需求分析/控制/PWM。因为本人并不是专门做电源的,此部分内容仅作了解,并不专门去学习。 文章目录0.前言1.项目需求2.方案介绍2.1DCDC级&#xff08…

互联网重提内容为王?学Netflix(奈飞)做好内容营销

Netflix 成立于1997年,不久便一跃成为最受瞩目的流媒体服务网站。它为什么能在短短时间内获得如此巨大的成功呢?答案就在于它使用的超凡 内容营销策略 和方法 —— 数据驱动 、优化内容、以流量转化为目标。 内容为王众人皆知,内容营销是品牌…

【计算机毕业设计】java ssm高校计算机网络考试系统(源码+论文)

提供了一些今年最新计算机毕业设计源代码、文档及帮助指导,公众号:一点毕设,领取更多毕设资料。 随着计算机以及网络在教学领域的高速发展,为了加快数字化校园的进程,更好的实现现代化的教育改革,针对于当下…

手动制作满足SARscape要求的_dem数据

手动制作满足SARscape要求的_dem数据问题描述1 下载研究区的原始DEM数据,在envi中镶嵌裁剪,得到.dat格式的数据,然后用envi中的Original ENVI工具把.dat转成_dem1.1 下载研究区的原始DEM数据1.2 将.tif数据转成envi格式的.dat2. 能不能直接将…

WordPress开发中常用代码(必备)

很多人在WordPress开发中常用代码,WordPress 相比其它网站程序,最突出的优势:主题模板多,插件多,相关技术文章多,只要你想到的功能,都可以通过插件或者代码实现。现在分享下WordPress常用代码&a…

组合关系比依赖关系耦合性更强

首先说明,在这里我把“关联”、“组合”、“聚合”关系都统一当做“组合”关系来说的,但实际上聚合(has-a)是关联的一种,组合(cntains-a)也是关联的一种。如果想要知道三者之间的区别&#xff0…

实验二.常用网络命令

常用网络命令一、实验目的与要求学习常用网络命令的使用方法熟悉主机的基本网络配置 二、预习与准备网络常用命令及基本用法。主机的基本网络配置信息。 三、 实验内容 1.Ping命令 2.ipconfig命令 3.arp命令(地址转换协议) 4.traceroute命令 5.route命令…

花咲の姫君(異時層ツキハ) / 花咲(异时层妖刀)

目录基本资料面板值(无天冥加成)天冥奖励战斗宣言(VC)被动效果Another Sense技能珠子回到人物索引 基本资料 NS(5★)卡池 (Ver 2.13.50)ミヤビノカミの典録 天冥属性武器防具属性耐性异常耐性NS天火枪护腕风30%10%个性枪、东方、…

目标检测SSD学习笔记

目标检测SSD学习笔记 SSD: Single Shot MultiBox Detector Abstract. 我们提出了一种使用单一深度神经网络来检测图像中的对象的方法。我们的方法,命名为SSD,将边界框的输出空间离散化为一组默认框,每个特征地图位置具有不同的纵横比和比例…

BasicSR入门教程

BasicSR入门教程 1.安装环境 由于安装好的其他环境已经有了pytorch,那么新建环境时直接拷贝该环境就好 //复制环境 conda create --name my-basicsr --clone mmediting克隆项目 git clone https://github.com/XPixelGroup/BasicSR.git安装依赖包 cd BasicSR pi…

MyBatis--缓存

MyBatis的缓存 MyBatis的一级缓存 一级缓存是SqlSession级别的,通过 同一个SqlSession 查询的数据会被缓存,下次查询相同的数据,就会从缓存中直接获取,不会从数据库重新访问 import com.bijing.mybatis.mapper.CacheMapper; im…

二手商品交易网站

摘 要 本论文主要论述了如何使用JAVA语言开发一个二手商品交易网站,本系统将严格按照软件开发流程进行各个阶段的工作,采用B/S架构,面向对象编程思想进行项目开发。在引言中,作者将论述二手商品交易网站的当前背景以及系统开发的目…

大话西游服务端开服架设服务器搭建教程

大话西游服务端开服架设服务器搭建教程 大话西游一款回合制角色扮演手游,游戏内包含人族、仙族、魔族、鬼族四大种族,每个种族各有4个角色可供玩家选择。相信很多玩这款游戏的玩家也有不少想知道自己怎么可以开一个sf,自己当服主&#xff0c…

教学设计题-教学目标

(1)知识与技能目标 基础知识与基本技能 了解/理解(概念,性质) 掌握(方法,过程) 运用/会(----)剞劂问题 (2)过程与方法目标 通过(观察…

Linux篇【2】:shell命令初步认识,Linux权限(上)

目录 1、shell命令以及运行原理 2、Linux权限的概念 3、Linux权限管理 3.1、文件访问者的分类(人) 3.2、文件类型和文件权限属性(事物属性) 1、shell命令以及运行原理 Linux严格意义上说的是一个操作系统,我们称之为" 核心(kernel) " ,但…

Jenkins配置用户权限

前几篇讲了一下有关Jenkins的一系列的操作: 在linux上搭建jenkins,并进行所需的配置 Jenkins安装插件一直失败,报错SunCertPathBuilderException的解决方案 jenkins配置拉取git远程仓库的代码并进行自动化构建部署 怎么修改Jenkins的默认…

力扣周赛314-矩阵中和能被 K 整除的路径(动态规划)

解题思路:方案数问题动态规划问题。由于只能往下或右走,递归思考,每一点a[i][j]的方案数必由其上方a[i-1][j]或左侧a[i][j-1]得到。问题关键点在于统计的是能被K整除的路径数目,看一下示例1,如果走到(3,3&a…