金融数据_Scikit-Learn决策树(DecisionTreeClassifier)实例

news/2024/5/1 13:26:45/文章来源:https://blog.csdn.net/goufeng93/article/details/137427682

金融数据_Scikit-Learn决策树(DecisionTreeClassifier)实例

逻辑回归: 逻辑回归常被用于二分类问题, 比如涨跌预测。你可以将涨跌标记为类别, 然后使用逻辑回归进行训练。

决策树和随机森林: 决策树和随机森林是用于分类问题的强大模型。它们能够处理非线性关系, 并且对于特征的重要性有较好的解释。

实例数据

本实例截取了 “湖北宜化(000422)” 2015年08月06日 - 2015年12月31日的数据。

HBYH_000422_20150806_20151231.csv

Date,Code,Open,High,Low,Close,Pre_Close,Change,Turnover_Rate,Volume,MA5,MA10
2015-12-31,'000422,7.93,7.95,7.76,7.77,7.93,-0.020177,0.015498,13915200,7.86,7.85
2015-12-30,'000422,7.86,7.93,7.75,7.93,7.84,0.011480,0.018662,16755900,7.90,7.85
2015-12-29,'000422,7.72,7.85,7.69,7.84,7.71,0.016861,0.015886,14263800,7.90,7.81
2015-12-28,'000422,8.03,8.08,7.70,7.71,8.03,-0.039851,0.030821,27672800,7.91,7.78
2015-12-25,'000422,8.03,8.05,7.93,8.03,7.99,0.005006,0.021132,18974000,7.93,7.78
2015-12-24,'000422,7.93,8.16,7.87,7.99,7.92,0.008838,0.026487,23781900,7.85,7.72
2015-12-23,'000422,7.97,8.11,7.88,7.92,7.89,0.003802,0.042360,38033600,7.80,7.69
2015-12-22,'000422,7.86,7.93,7.76,7.89,7.83,0.007663,0.026929,24178700,7.73,7.68
2015-12-21,'000422,7.59,7.89,7.56,7.83,7.63,0.026212,0.030777,27633600,7.66,7.67
2015-12-18,'000422,7.71,7.74,7.57,7.63,7.74,-0.014212,0.024764,22234900,7.62,7.71
2015-12-17,'000422,7.58,7.75,7.57,7.74,7.55,0.025166,0.028054,25188400,7.59,7.77
2015-12-16,'000422,7.57,7.62,7.53,7.55,7.55,0.000000,0.020718,18601600,7.58,7.79
2015-12-15,'000422,7.63,7.66,7.52,7.55,7.62,-0.009186,0.025902,23256600,7.64,7.78
2015-12-14,'000422,7.40,7.64,7.36,7.62,7.51,0.014647,0.021005,18860100,7.68,7.76
2015-12-11,'000422,7.65,7.70,7.41,7.51,7.67,-0.020860,0.020477,18385900,7.80,7.73
2015-12-10,'000422,7.78,7.87,7.65,7.67,7.83,-0.020434,0.019972,17931900,7.95,7.69
2015-12-09,'000422,7.76,8.00,7.75,7.83,7.77,0.007722,0.025137,22569700,8.00,7.68
2015-12-08,'000422,8.08,8.18,7.76,7.77,8.24,-0.057039,0.036696,32948200,7.92,7.66
2015-12-07,'000422,8.12,8.39,7.94,8.24,8.23,0.001215,0.064590,57993100,7.84,7.64
2015-12-04,'000422,7.85,8.48,7.80,8.23,7.92,0.039141,0.100106,89881900,7.65,7.58
2015-12-03,'000422,7.42,8.09,7.38,7.92,7.43,0.065949,0.045416,40777500,7.43,7.52
2015-12-02,'000422,7.35,7.48,7.20,7.43,7.36,0.009511,0.015968,14337600,7.37,7.49
2015-12-01,'000422,7.28,7.39,7.23,7.36,7.33,0.004093,0.012308,11050700,7.41,7.48
2015-11-30,'000422,7.18,7.36,6.95,7.33,7.11,0.030942,0.020323,18247500,7.45,7.50
2015-11-27,'000422,7.59,7.59,6.95,7.11,7.60,-0.064474,0.027673,24846700,7.51,7.52
2015-11-26,'000422,7.63,7.73,7.58,7.60,7.63,-0.003932,0.024836,22299800,7.61,7.54
2015-11-25,'000422,7.56,7.64,7.51,7.63,7.59,0.005270,0.020919,18782900,7.61,7.54
2015-11-24,'000422,7.60,7.63,7.48,7.59,7.62,-0.003937,0.014867,13348200,7.56,7.53
2015-11-23,'000422,7.59,7.72,7.55,7.62,7.61,0.001314,0.028406,25505000,7.54,7.53
2015-11-20,'000422,7.59,7.71,7.53,7.61,7.59,0.002635,0.028277,25389100,7.52,7.53
2015-11-19,'000422,7.45,7.62,7.41,7.59,7.39,0.027064,0.038638,34691700,7.47,7.52
2015-11-18,'000422,7.53,7.54,7.38,7.39,7.51,-0.015979,0.014173,12725000,7.46,7.50
2015-11-17,'000422,7.53,7.63,7.44,7.51,7.50,0.001333,0.028640,25714500,7.51,7.50
2015-11-16,'000422,7.27,7.52,7.24,7.50,7.38,0.016260,0.016230,14572000,7.52,7.46
2015-11-13,'000422,7.49,7.55,7.36,7.38,7.54,-0.021220,0.029196,26214400,7.53,7.41
2015-11-12,'000422,7.65,7.68,7.49,7.54,7.61,-0.009198,0.026501,23794800,7.56,7.40
2015-11-11,'000422,7.57,7.64,7.52,7.61,7.57,0.005284,0.026113,23445900,7.54,7.37
2015-11-10,'000422,7.51,7.61,7.45,7.57,7.55,0.002649,0.024979,22427700,7.49,7.32
2015-11-09,'000422,7.51,7.62,7.45,7.55,7.53,0.002656,0.033367,29959500,7.39,7.31
2015-11-06,'000422,7.47,7.53,7.37,7.53,7.45,0.010738,0.037058,33273100,7.29,7.27
2015-11-05,'000422,7.34,7.54,7.32,7.45,7.37,0.010855,0.040463,36330200,7.24,7.24
2015-11-04,'000422,7.10,7.38,7.07,7.37,7.05,0.045390,0.034817,31260800,7.20,7.17
2015-11-03,'000422,7.08,7.13,7.02,7.05,7.06,-0.001416,0.014938,13412400,7.15,7.10
2015-11-02,'000422,7.11,7.26,7.05,7.06,7.26,-0.027548,0.016865,15142100,7.23,7.10
2015-10-30,'000422,7.22,7.38,7.10,7.26,7.24,0.002762,0.022821,20490200,7.25,7.10
2015-10-29,'000422,7.27,7.33,7.16,7.24,7.16,0.011173,0.025726,23098500,7.23,7.08
2015-10-28,'000422,7.32,7.40,7.09,7.16,7.42,-0.035040,0.035572,31938500,7.15,7.05
2015-10-27,'000422,7.21,7.48,7.08,7.42,7.18,0.033426,0.057658,51769300,7.04,7.01
2015-10-26,'000422,7.20,7.25,7.01,7.18,7.17,0.001395,0.036840,33077800,6.98,6.96
2015-10-23,'000422,6.84,7.22,6.81,7.17,6.80,0.054412,0.047169,42351500,6.95,6.93
2015-10-22,'000422,6.68,6.81,6.64,6.80,6.65,0.022556,0.020609,18503800,6.93,6.87
2015-10-21,'000422,7.08,7.11,6.61,6.65,7.09,-0.062059,0.039388,35365300,6.96,6.85
2015-10-20,'000422,7.00,7.09,6.94,7.09,7.03,0.008535,0.024472,21972900,6.98,6.81
2015-10-19,'000422,7.09,7.13,6.92,7.03,7.08,-0.007062,0.031262,28068800,6.94,6.72
2015-10-16,'000422,6.97,7.08,6.91,7.08,6.93,0.021645,0.039632,35584700,6.91,6.66
2015-10-15,'000422,6.77,6.94,6.75,6.93,6.77,0.023634,0.031645,28412700,6.82,6.59
2015-10-14,'000422,6.87,6.94,6.74,6.77,6.89,-0.017417,0.027226,24445500,6.74,6.55
2015-10-13,'000422,6.86,6.96,6.80,6.89,6.88,0.001453,0.028704,25771900,6.64,6.51
2015-10-12,'000422,6.62,6.91,6.58,6.88,6.61,0.040847,0.037037,33254300,6.50,6.49
2015-10-09,'000422,6.54,6.65,6.45,6.61,6.54,0.010703,0.018528,16635900,6.41,6.46
2015-10-08,'000422,6.45,6.70,6.37,6.54,6.26,0.044728,0.018857,16931000,6.35,6.44
2015-09-30,'000422,6.25,6.30,6.22,6.26,6.23,0.004815,0.007327,6579090,6.35,6.43
2015-09-29,'000422,6.30,6.32,6.18,6.23,6.40,-0.026562,0.008991,8072900,6.39,6.48
2015-09-28,'000422,6.35,6.42,6.25,6.40,6.34,0.009464,0.008824,7922890,6.48,6.47
2015-09-25,'000422,6.51,6.56,6.25,6.34,6.53,-0.029096,0.012584,11298800,6.51,6.45
2015-09-24,'000422,6.48,6.56,6.45,6.53,6.45,0.012403,0.011339,10180900,6.53,6.51
2015-09-23,'000422,6.51,6.60,6.41,6.45,6.67,-0.032984,0.015920,14294100,6.52,6.54
2015-09-22,'000422,6.58,6.73,6.54,6.67,6.58,0.013678,0.023356,20970200,6.56,6.60
2015-09-21,'000422,6.34,6.61,6.29,6.58,6.44,0.021739,0.017036,15295900,6.46,6.62
2015-09-18,'000422,6.52,6.58,6.30,6.44,6.44,0.000000,0.016622,14924700,6.39,6.62
2015-09-17,'000422,6.59,6.76,6.43,6.44,6.68,-0.035928,0.019517,17523900,6.48,6.62
2015-09-16,'000422,6.21,6.76,6.17,6.68,6.15,0.086179,0.019671,17662300,6.56,6.65
2015-09-15,'000422,6.24,6.38,6.05,6.15,6.26,-0.017572,0.015338,13771200,6.64,6.66
2015-09-14,'000422,6.89,6.95,6.18,6.26,6.87,-0.088792,0.021233,18559600,6.78,6.75
2015-09-11,'000422,6.87,6.96,6.77,6.87,6.84,0.004386,0.010853,9486290,6.85,6.79
2015-09-10,'000422,6.95,7.01,6.76,6.84,7.06,-0.031161,0.017423,15229100,6.76,6.74
2015-09-09,'000422,6.90,7.09,6.86,7.06,6.88,0.026163,0.028974,25325600,6.74,6.68
2015-09-08,'000422,6.65,6.91,6.55,6.88,6.62,0.039275,0.017858,15609100,6.69,6.67
2015-09-07,'000422,6.50,6.81,6.50,6.62,6.38,0.037618,0.017850,15602600,6.72,6.75
2015-09-02,'000422,6.45,6.88,6.30,6.38,6.74,-0.053412,0.022286,19480100,6.73,6.91
2015-09-01,'000422,6.88,6.99,6.67,6.74,6.81,-0.010279,0.025829,22576700,6.72,7.12
2015-08-31,'000422,6.90,6.97,6.71,6.81,7.07,-0.036775,0.018385,16069600,6.62,7.24
2015-08-28,'000422,6.75,7.08,6.71,7.07,6.67,0.059970,0.026692,23330800,6.65,7.44
2015-08-27,'000422,6.53,6.67,6.34,6.67,6.32,0.055380,0.022455,19627900,6.78,7.59
2015-08-26,'000422,6.31,6.77,6.09,6.32,6.25,0.011200,0.029963,26190200,7.08,7.76
2015-08-25,'000422,6.40,6.77,6.25,6.25,6.94,-0.099424,0.029492,25778600,7.52,7.96
2015-08-24,'000422,7.49,7.49,6.94,6.94,7.71,-0.099870,0.036552,31949900,7.86,8.18
2015-08-21,'000422,8.00,8.11,7.60,7.71,8.17,-0.056304,0.032199,28144800,8.23,8.33
2015-08-20,'000422,8.38,8.56,8.14,8.17,8.53,-0.042204,0.031764,27764200,8.40,8.38
2015-08-19,'000422,7.73,8.57,7.72,8.53,7.96,0.071608,0.052192,45619900,8.45,8.37
2015-08-18,'000422,8.81,8.86,7.92,7.96,8.80,-0.095455,0.056179,49105500,8.39,8.32
2015-08-17,'000422,8.49,8.83,8.42,8.80,8.52,0.032864,0.048161,42096900,8.50,8.35
2015-08-14,'000422,8.48,8.65,8.43,8.52,8.44,0.009479,0.041169,35985000,8.43,8.24
2015-08-13,'000422,8.20,8.45,8.15,8.44,8.24,0.024272,0.029768,26019600,8.37,8.16
2015-08-12,'000422,8.38,8.48,8.21,8.24,8.48,-0.028302,0.035421,30960700,8.30,8.08
2015-08-11,'000422,8.41,8.68,8.32,8.48,8.49,-0.001178,0.048444,42343900,8.26,8.03
2015-08-10,'000422,8.28,8.58,8.18,8.49,8.21,0.034105,0.041268,36071600,8.20,7.92
2015-08-07,'000422,8.15,8.28,8.08,8.21,8.07,0.017348,0.025855,22599800,8.05,7.81
2015-08-06,'000422,7.88,8.21,7.80,8.07,8.03,0.004981,0.020074,17546700,7.95,7.80

探索思路

这里只是简单示例, 目的在于熟悉 Scikit-Learn 中的决策树分类器使用方法, 无任何投资引导。

目标:

通过当日数值情况, 预测当日收盘涨跌, 如果 “涨跌幅(Change) >= 0”, 则用 1 表示, 如果 “涨跌幅(Change) < 0”, 则用 0 表示 (二分类标签)。

变量:

  1. 当日最高价

  2. 当日最低价

  3. 当日换手率

  4. 当日成交量

  5. 当日星期几 (星期对价格的影响)

  6. 当日 “短期均线(MA5)” 与 “长期均线(MA10)” 的关系, 如果 “MA5 > MA10”, 则用 1 表示, 如果 “MA5 = MA10”, 则用 0 表示, 如果 “MA5 < MA10”, 则用 -1 表示。

  7. 节日 (节日对 A 股的影响, 中国节日 A 股休市, 所以只能探索国外节日对 A 股的影响, 这里仅用 “圣诞节(Christmas)” 和 “平安夜(Christmas Eve)” 做示例)。

导入 Pandas 相关模块

Pandas 是基于 NumPy 的一种工具, 该工具是为解决数据分析任务而创建的。Pandas 纳入了大量库和一些标准的数据模型, 提供了高效地操作大型数据集所需的工具。

Pandas 提供了大量能使我们快速便捷地处理数据的函数和方法。你很快就会发现, 它是使 Python 成为强大而高效的数据分析环境的重要因素之一。

import pandas as pd

导入 Scikit-Learn 相关模块

Scikit-Learn (以前称为 scikits.learn, 也称为 sklearn) 是针对 Python 编程语言的免费软件机器学习库。

它具有各种分类, 回归和聚类算法, 包括支持向量机, 随机森林, 梯度提升, K均值 和 DBSCAN, 并且旨在与 Python 数值科学库 NumPy 和 SciPy 联合使用。

from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.preprocessing import StandardScaler

使用 Pandas 读取 CSV 数据

调用 Pandas 的 .read_csv 方法读取 CSV 数据:

其中 header 参数指定 CSV 文件的表头行, 这里的 header=0 表示表头行在 1 行, 如果 header=None 则表示数据没有列索引, Pandas 则会自动加上索引。

其中 sep 参数指定 CSV 文件的分隔符, 默认情况下都是以 “,” 作为分隔符, 这里的 sep=“,” 表示指定 CSV 文件的分隔符为 “,”。

还有 dtype 参数指定 CSV 某些特定列以特定的数据类型进行读取, 例如 dtype={“Close”:float, “Volume”:int} 表示 “Close” 列以 浮点(float) 类型读取, “Volume” 列以 整数(integer) 类型读取。

PDF = pd.read_csv("D:\\HBYH_000422_20150806_20151231.csv", header=0, sep=",")

输出 DataFrame 数据框:

print("[Message] Readed CSV File: D:\\HBYH_000422_20150806_20151231.csv")
print(PDF)

输出:

[Message] Readed CSV File: D:\\HBYH_000422_20150806_20151231.csvDate     Code  Open  High   Low  Close  Pre_Close    Change  Turnover_Rate    Volume   MA5  MA10
0   2015-12-31  '000422  7.93  7.95  7.76   7.77       7.93 -0.020177       0.015498  13915200  7.86  7.85
1   2015-12-30  '000422  7.86  7.93  7.75   7.93       7.84  0.011480       0.018662  16755900  7.90  7.85
2   2015-12-29  '000422  7.72  7.85  7.69   7.84       7.71  0.016861       0.015886  14263800  7.90  7.81
3   2015-12-28  '000422  8.03  8.08  7.70   7.71       8.03 -0.039851       0.030821  27672800  7.91  7.78
4   2015-12-25  '000422  8.03  8.05  7.93   8.03       7.99  0.005006       0.021132  18974000  7.93  7.78
..         ...      ...   ...   ...   ...    ...        ...       ...            ...       ...   ...   ...
94  2015-08-12  '000422  8.38  8.48  8.21   8.24       8.48 -0.028302       0.035421  30960700  8.30  8.08
95  2015-08-11  '000422  8.41  8.68  8.32   8.48       8.49 -0.001178       0.048444  42343900  8.26  8.03
96  2015-08-10  '000422  8.28  8.58  8.18   8.49       8.21  0.034105       0.041268  36071600  8.20  7.92
97  2015-08-07  '000422  8.15  8.28  8.08   8.21       8.07  0.017348       0.025855  22599800  8.05  7.81
98  2015-08-06  '000422  7.88  8.21  7.80   8.07       8.03  0.004981       0.020074  17546700  7.95  7.80[99 rows x 12 columns]

转换 Pandas 中 DateFrame 各列数据类型

通常情况下, 为了避免计算出现数据类型的错误, 都需要重新转换一下数据类型。

# 转换 Pandas 中 DateFrame 数据类型。
PDF["Date"] =          PDF["Date"].astype("datetime64[ns]")
PDF["Open"] =          PDF["Open"].astype("float64")
PDF["High"] =          PDF["High"].astype("float64")
PDF["Low"] =           PDF["Low"].astype("float64")
PDF["Close"] =         PDF["Close"].astype("float64")
PDF["Pre_Close"] =     PDF["Pre_Close"].astype("float64")
PDF["Change"] =        PDF["Change"].astype("float64")
PDF["Turnover_Rate"] = PDF["Turnover_Rate"].astype("float64")
PDF["Volume"] =        PDF["Volume"].astype("int64")
PDF["MA5"] =           PDF["MA5"].astype("float64")
PDF["MA10"] =          PDF["MA10"].astype("float64")# 输出 Pandas 中 DataFrame 字段和数据类型。
print("[Message] Changed Pandas DataFrame Data Type:")
print(PDF.dtypes)

输出:

[Message] Changed Pandas DataFrame Data Type:
Date             datetime64[ns]
Code                     object
Open                    float64
High                    float64
Low                     float64
Close                   float64
Pre_Close               float64
Change                  float64
Turnover_Rate           float64
Volume                    int64
MA5                     float64
MA10                    float64
dtype: object

在 Pandas 的 DataFrame 中计算数据

编写 “判断股票短期均线和长期均线关系” 函数:

def MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA:float, Long_MA:float) -> int:if (Short_MA >= Long_MA): return  1if (Short_MA == Long_MA): return  0if (Short_MA <= Long_MA): return -1# ==============================================# End of Function.

在 Pandas 的 DataFrame 中直接计算或调用自定义函数:

# 计算数据: 提取星期的索引, 从 0 到 6 (0 代表周一, 6 代表周日)。
PDF["Weekday(Idx)"] =    PDF["Date"].apply(lambda X: X.weekday())
# ..................................................
# 计算数据: 计算节日 (节日对 A 股的影响, 中国节日 A 股休市, 所以只能探索国外节日对 A 股的影响, 这里仅用 "圣诞节(Christmas)" 和 "平安夜(Christmas Eve)" 做示例)。
PDF["Festival"] = None
for Idx in PDF.index:if PDF.loc[Idx, "Date"] == datetime.datetime(2015,12,24): PDF.loc[Idx, "Festival"] = "Christmas_Eve" # -> 平安夜。if PDF.loc[Idx, "Date"] == datetime.datetime(2015,12,25): PDF.loc[Idx, "Festival"] = "Christmas"     # -> 圣诞节。
# ..................................................
# 计算数据: 判断股票涨跌。
PDF["Rise_Fall"] =       PDF["Change"].apply(lambda X: int(1) if X >= 0 else int(0))
# ..................................................
# 计算数据: 调用函数, 判断股票短期均线和长期均线关系。
PDF["MA_Relationship"] = PDF.apply(lambda X: MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA=X["MA5"], Long_MA=X["MA10"]), axis=1)# 输出计算好的 DataFrame 数据框。
print("[Message] Calculated DataFrame:")
print(PDF)

输出:

[Message] Calculated DataFrame:Date     Code  Open  High   Low  Close  Pre_Close    Change  Turnover_Rate    Volume   MA5  MA10  Weekday(Idx)   Festival  Rise_Fall  MA_Relationship
0  2015-12-31  '000422  7.93  7.95  7.76   7.77       7.93 -0.020177       0.015498  13915200  7.86  7.85             3       None          0                1
1  2015-12-30  '000422  7.86  7.93  7.75   7.93       7.84  0.011480       0.018662  16755900  7.90  7.85             2       None          1                1
2  2015-12-29  '000422  7.72  7.85  7.69   7.84       7.71  0.016861       0.015886  14263800  7.90  7.81             1       None          1                1
3  2015-12-28  '000422  8.03  8.08  7.70   7.71       8.03 -0.039851       0.030821  27672800  7.91  7.78             0       None          0                1
4  2015-12-25  '000422  8.03  8.05  7.93   8.03       7.99  0.005006       0.021132  18974000  7.93  7.78             4  Christmas          1                1
..        ...      ...   ...   ...   ...    ...        ...       ...            ...       ...   ...   ...           ...        ...        ...              ...
94 2015-08-12  '000422  8.38  8.48  8.21   8.24       8.48 -0.028302       0.035421  30960700  8.30  8.08             2       None          0                1
95 2015-08-11  '000422  8.41  8.68  8.32   8.48       8.49 -0.001178       0.048444  42343900  8.26  8.03             1       None          0                1
96 2015-08-10  '000422  8.28  8.58  8.18   8.49       8.21  0.034105       0.041268  36071600  8.20  7.92             0       None          1                1
97 2015-08-07  '000422  8.15  8.28  8.08   8.21       8.07  0.017348       0.025855  22599800  8.05  7.81             4       None          1                1
98 2015-08-06  '000422  7.88  8.21  7.80   8.07       8.03  0.004981       0.020074  17546700  7.95  7.80             3       None          1                1[99 rows x 16 columns]

在 Pandas 的 DataFrame 中将字符串类型的特征列转换为数值 (One-Hot Encoding)

pd.get_dummies() 是 Pandas 库中用于独热编码 (One-Hot Encoding) 的函数。它的作用是将分类 (离散) 变量的每个不同取值都拓展为一个新的二进制特征 (0 或 1), 从而方便机器学习模型处理。

# 函数签名:
pd.get_dummies(data, prefix=None, prefix_sep='_', dummy_na=False, columns=None, sparse=False, drop_first=False, dtype=None)# 参数说明:
# - data: 要进行独热编码的 DataFrame 或 Series。
# - prefix: 生成的独热编码列的前缀。
# - prefix_sep: 生成的独热编码列的前缀和原始列名之间的分隔符。
# - dummy_na: 是否为原始数据中的缺失值生成独热编码列。
# - columns: 要进行独热编码的列的名称, 如果指定, 则只对这些列进行操作。
# - drop_first: 是否删除第一个独热编码列, 以避免共线性问题。

转换 Festival 特征列为数值:

# 将字符串类型的特征列转换为数值 (独热编码)。
PDF = pd.get_dummies(PDF, columns=["Festival"], drop_first=False)# 输出转换后的 DataFrame 数据框。
print("[Message] DataFrame After One-Hot Encoding:")
print(PDF)

输出:

[Message] DataFrame After One-Hot Encoding:Date     Code  Open  High   Low  Close  Pre_Close    Change  Turnover_Rate    Volume   MA5  MA10  Weekday(Idx)  Rise_Fall  MA_Relationship  Festival_Christmas  Festival_Christmas_Eve
0  2015-12-31  '000422  7.93  7.95  7.76   7.77       7.93 -0.020177       0.015498  13915200  7.86  7.85             3          0                1                   0                       0
1  2015-12-30  '000422  7.86  7.93  7.75   7.93       7.84  0.011480       0.018662  16755900  7.90  7.85             2          1                1                   0                       0
2  2015-12-29  '000422  7.72  7.85  7.69   7.84       7.71  0.016861       0.015886  14263800  7.90  7.81             1          1                1                   0                       0
3  2015-12-28  '000422  8.03  8.08  7.70   7.71       8.03 -0.039851       0.030821  27672800  7.91  7.78             0          0                1                   0                       0
4  2015-12-25  '000422  8.03  8.05  7.93   8.03       7.99  0.005006       0.021132  18974000  7.93  7.78             4          1                1                   1                       0
..        ...      ...   ...   ...   ...    ...        ...       ...            ...       ...   ...   ...           ...        ...              ...                 ...                     ...
94 2015-08-12  '000422  8.38  8.48  8.21   8.24       8.48 -0.028302       0.035421  30960700  8.30  8.08             2          0                1                   0                       0
95 2015-08-11  '000422  8.41  8.68  8.32   8.48       8.49 -0.001178       0.048444  42343900  8.26  8.03             1          0                1                   0                       0
96 2015-08-10  '000422  8.28  8.58  8.18   8.49       8.21  0.034105       0.041268  36071600  8.20  7.92             0          1                1                   0                       0
97 2015-08-07  '000422  8.15  8.28  8.08   8.21       8.07  0.017348       0.025855  22599800  8.05  7.81             4          1                1                   0                       0
98 2015-08-06  '000422  7.88  8.21  7.80   8.07       8.03  0.004981       0.020074  17546700  7.95  7.80             3          1                1                   0                       0[99 rows x 17 columns]

提取 标签(Label)列 和 特征(Feature)列

提取 标签(Label) 列:

# 提取 标签(Label) 列。
Y = PDF["Rise_Fall"]

提取 特征(Feature) 列:

# 提取 特征(Feature) 列。
X = PDF.drop(["Date", "Code", "Open", "Close", "Pre_Close", "Change", "MA5", "MA10", "Rise_Fall"], axis=1)

划分训练集和测试集(train_test_split) 以及 特征标准化(StandardScaler)

划分训练集和测试集(train_test_split):

# 数据集划分训练集和测试集(train_test_split)。
X_Train, X_Test, Y_Train, Y_Test = train_test_split(X, Y, test_size=0.2, random_state=42)

特征标准化(StandardScaler):

在机器学习中, fit_transform 和 transform 是用于数据预处理的常见方法, 它们的作用略有不同:

fit_transform: 该方法将同时拟合和转换数据。

  • 它会根据输入的数据计算所需的转换参数 (例如均值、标准差等), 然后将数据应用这些参数进行转换。

  • 在训练阶段, 通常使用 fit_transform 来对训练集进行拟合和转换。

  • 拟合过程会根据训练集数据计算并保存所需的转换参数, 然后将训练集数据应用这些参数进行转换。

  • 这样做的目的是确保在后续对测试集或新数据进行转换时使用相同的转换参数。

transform: 该方法仅对数据进行转换, 不进行拟合过程。

  • 它根据之前使用 fit_transform 得到的转换参数, 将这些参数应用于新的数据, 使其按照相同的转换方式进行处理。

  • 在测试阶段或对新数据应用模型时, 通常使用 transform 方法对测试集或新数据进行转换。

简而言之, fit_transform 方法用于拟合转换器并将数据进行转换, 而 transform 方法仅用于将数据按照已经拟合的转换器进行转换。

在代码中的具体应用上, 通常将 fit_transform 用于训练集的拟合和转换, 将 transform 用于测试集或新数据的转换, 以保证数据的一致性和正确的预处理操作。

# 特征标准化(StandardScaler)。
Obj_Scaler = StandardScaler()
X_Train_Scaled = Obj_Scaler.fit_transform(X_Train)
X_Test_Scaled = Obj_Scaler.transform(X_Test)

训练 决策树分类器(DecisionTreeClassifier) 模型

创建 决策树分类器(DecisionTreeClassifier):

# 创建 决策树分类器(DecisionTreeClassifier)。
DTC = DecisionTreeClassifier(random_state=42)

训练 决策树分类器(DecisionTreeClassifier) 模型:

# 训练 决策树分类器(DecisionTreeClassifier) 模型。
DTC.fit(X_Train_Scaled, Y_Train)# Value of Return:
# +----------------------------------------+
# |▼        DecisionTreeClassifier         |
# +----------------------------------------+
# | DecisionTreeClassifier(random_state=42)|
# +----------------------------------------+

使用 决策树分类器(DecisionTreeClassifier) 模型预测数据

# 在测试集上进行预测。
Y_Pred = DTC.predict(X_Test_Scaled)# 合并预测结果。
Result = X_Test.copy()
Result["Actually"] = Y_Test
Result["Prediction"] = Y_Predprint("[Message] Prediction Results on The Test Data Set for DecisionTreeClassifier:")
print(Result)

输出:

[Message] Prediction Results on The Test Data Set for DecisionTreeClassifier:High   Low  Turnover_Rate    Volume  Weekday(Idx)  MA_Relationship  Festival_Christmas  Festival_Christmas_Eve  Actually  Prediction
62  6.32  6.18       0.008991   8072900             1               -1                   0                       0         0           1
40  7.54  7.32       0.040463  36330200             3                1                   0                       0         1           1
95  8.68  8.32       0.048444  42343900             1                1                   0                       0         0           1
18  8.39  7.94       0.064590  57993100             0                1                   0                       0         1           1
97  8.28  8.08       0.025855  22599800             4                1                   0                       0         1           1
84  6.77  6.09       0.029963  26190200             2               -1                   0                       0         1           0
64  6.56  6.25       0.012584  11298800             4                1                   0                       0         0           1
42  7.13  7.02       0.014938  13412400             1                1                   0                       0         0           0
10  7.75  7.57       0.028054  25188400             3               -1                   0                       0         1           0
0   7.95  7.76       0.015498  13915200             3                1                   0                       0         0           1
31  7.54  7.38       0.014173  12725000             2               -1                   0                       0         0           0
76  7.09  6.86       0.028974  25325600             2                1                   0                       0         1           1
47  7.48  7.08       0.057658  51769300             1                1                   0                       0         1           1
26  7.64  7.51       0.020919  18782900             2                1                   0                       0         1           1
44  7.38  7.10       0.022821  20490200             4                1                   0                       0         1           0
4   8.05  7.93       0.021132  18974000             4                1                   1                       0         1           1
22  7.39  7.23       0.012308  11050700             1               -1                   0                       0         1           1
12  7.66  7.52       0.025902  23256600             1               -1                   0                       0         0           1
88  8.56  8.14       0.031764  27764200             3                1                   0                       0         0           1
73  6.95  6.18       0.021233  18559600             0                1                   0                       0         0           1

使用 accuracy_score 评估模型性能

# 评估模型性能。
Accuracy = accuracy_score(Y_Test, Y_Pred)
print("Accuracy:", Accuracy)
print("\n")# 输出分类报告。
print("Classification Report:")
print(classification_report(Y_Test, Y_Pred))

输出:

Accuracy: 0.5Classification Report:precision    recall  f1-score   support0       0.40      0.22      0.29         91       0.53      0.73      0.62        11accuracy                           0.50        20macro avg       0.47      0.47      0.45        20
weighted avg       0.47      0.50      0.47        20

完整代码

#!/usr/bin/python3
# Create By GF 2024-01-04# 在这个示例中, 我们使用 DecisionTreeClassifier 构建决策树模型。
# 为了处理字符串类型的特征列, 我们使用了 pd.get_dummies 进行独热编码。
# 然后, 我们对特征进行标准化, 并使用 train_test_split 将数据集划分为训练集和测试集。
# 最后, 我们训练模型、进行预测, 并评估模型性能。
# 请注意, 这只是一个基本的示例, 实际应用中你可能需要更多的特征工程、调参和模型评估。import datetime
# --------------------------------------------------
import pandas as pd
# --------------------------------------------------
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.preprocessing import StandardScaler# 编写 "判断股票短期均线和长期均线关系" 函数。
def MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA:float, Long_MA:float) -> int:if (Short_MA >= Long_MA): return  1if (Short_MA == Long_MA): return  0if (Short_MA <= Long_MA): return -1# ==============================================# End of Function.if __name__ == "__main__":PDF = pd.read_csv("D:\\HBYH_000422_20150806_20151231.csv", header=0, sep=",")print("[Message] Readed CSV File: D:\\HBYH_000422_20150806_20151231.csv")print(PDF)# 转换 Pandas 中 DateFrame 数据类型。PDF["Date"] =          PDF["Date"].astype("datetime64[ns]")PDF["Open"] =          PDF["Open"].astype("float64")PDF["High"] =          PDF["High"].astype("float64")PDF["Low"] =           PDF["Low"].astype("float64")PDF["Close"] =         PDF["Close"].astype("float64")PDF["Pre_Close"] =     PDF["Pre_Close"].astype("float64")PDF["Change"] =        PDF["Change"].astype("float64")PDF["Turnover_Rate"] = PDF["Turnover_Rate"].astype("float64")PDF["Volume"] =        PDF["Volume"].astype("int64")PDF["MA5"] =           PDF["MA5"].astype("float64")PDF["MA10"] =          PDF["MA10"].astype("float64")# 输出 Pandas 中 DataFrame 字段和数据类型。print("[Message] Changed Pandas DataFrame Data Type:")print(PDF.dtypes)# 计算数据: 提取星期的索引, 从 0 到 6 (0 代表周一, 6 代表周日)。PDF["Weekday(Idx)"] =    PDF["Date"].apply(lambda X: X.weekday())# ..................................................# 计算数据: 计算节日 (节日对 A 股的影响, 中国节日 A 股休市, 所以只能探索国外节日对 A 股的影响, 这里仅用 "圣诞节(Christmas)" 和 "平安夜(Christmas Eve)" 做示例)。PDF["Festival"] = Nonefor Idx in PDF.index:if PDF.loc[Idx, "Date"] == datetime.datetime(2015,12,24): PDF.loc[Idx, "Festival"] = "Christmas_Eve" # -> 平安夜。if PDF.loc[Idx, "Date"] == datetime.datetime(2015,12,25): PDF.loc[Idx, "Festival"] = "Christmas"     # -> 圣诞节。# ..................................................# 计算数据: 判断股票涨跌。PDF["Rise_Fall"] =       PDF["Change"].apply(lambda X: int(1) if X >= 0 else int(0))# ..................................................# 计算数据: 调用函数, 判断股票短期均线和长期均线关系。PDF["MA_Relationship"] = PDF.apply(lambda X: MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA=X["MA5"], Long_MA=X["MA10"]), axis=1)# 输出计算好的 DataFrame 数据框。print("[Message] Calculated DataFrame:")print(PDF)# 将字符串类型的特征列转换为数值 (独热编码)。PDF = pd.get_dummies(PDF, columns=["Festival"], drop_first=False)# 输出转换后的 DataFrame 数据框。print("[Message] DataFrame After One-Hot Encoding:")print(PDF)# 提取 标签(Label) 列。Y = PDF["Rise_Fall"]# 提取 特征(Feature) 列。X = PDF.drop(["Date", "Code", "Open", "Close", "Pre_Close", "Change", "MA5", "MA10", "Rise_Fall"], axis=1)# 数据集划分训练集和测试集(train_test_split)。X_Train, X_Test, Y_Train, Y_Test = train_test_split(X, Y, test_size=0.2, random_state=42)# 特征标准化(StandardScaler)。Obj_Scaler = StandardScaler()X_Train_Scaled = Obj_Scaler.fit_transform(X_Train)X_Test_Scaled = Obj_Scaler.transform(X_Test)# 创建 决策树分类器(DecisionTreeClassifier)。DTC = DecisionTreeClassifier(random_state=42)# 训练 决策树分类器(DecisionTreeClassifier) 模型。DTC.fit(X_Train_Scaled, Y_Train)# Value of Return:# +----------------------------------------+# |▼        DecisionTreeClassifier         |# +----------------------------------------+# | DecisionTreeClassifier(random_state=42)|# +----------------------------------------+# 在测试集上进行预测。Y_Pred = DTC.predict(X_Test_Scaled)# 合并预测结果。Result = X_Test.copy()Result["Actually"] = Y_TestResult["Prediction"] = Y_Predprint("[Message] Prediction Results on The Test Data Set for DecisionTreeClassifier:")print(Result)# 评估模型性能。Accuracy = accuracy_score(Y_Test, Y_Pred)print("Accuracy:", Accuracy)print("\n")# 输出分类报告。print("Classification Report:")print(classification_report(Y_Test, Y_Pred))

其它

在这个示例中, 我们使用 DecisionTreeClassifier 构建决策树模型。

为了处理字符串类型的特征列, 我们使用了 pd.get_dummies 进行独热编码。

然后, 我们对特征进行标准化, 并使用 train_test_split 将数据集划分为训练集和测试集。

最后, 我们训练模型、进行预测, 并评估模型性能。

请注意, 这只是一个基本的示例, 实际应用中你可能需要更多的特征工程、调参和模型评估。

总结

以上就是关于 金融数据 Scikit-Learn决策树(DecisionTreeClassifier)实例 的全部内容。

更多内容可以访问我的代码仓库:

https://gitee.com/goufeng928/public

https://github.com/goufeng928/public

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_1046082.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

设计模式总结-组合模式

组合设计模式 模式动机模式定义模式结构组合模式实例与解析实例一&#xff1a;水果盘实例二&#xff1a;文件浏览 更复杂的组合总结 模式动机 对于树形结构&#xff0c;当容器对象&#xff08;如文件夹&#xff09;的某一个方法被调用时&#xff0c;将遍历整个树形结构&#x…

李沐26_网络中的网络NiN——自学笔记

全连接层太大 导致占用内存、占用计算带宽、很容易过拟合 NiN块 1.一个卷积层后跟着两个全连接层 2.步幅1&#xff0c;无填充&#xff0c;输出形状跟卷积层输出一样 3.起到全连接层的作用 NiN架构 1.无全连接层 2.交替使用NiN块和步幅为2的最大池化层&#xff0c;逐步减…

Weblogic任意文件上传漏洞(CVE-2018-2894)漏洞复现(基于vulhub)

&#x1f36c; 博主介绍&#x1f468;‍&#x1f393; 博主介绍&#xff1a;大家好&#xff0c;我是 hacker-routing &#xff0c;很高兴认识大家~ ✨主攻领域&#xff1a;【渗透领域】【应急响应】 【Java、PHP】 【VulnHub靶场复现】【面试分析】 &#x1f389;点赞➕评论➕收…

【Linux网络编程】网络编程套接字(TCP服务器)

【Linux网络编程】网络编程套接字(TCP服务器) 目录 【Linux网络编程】网络编程套接字(TCP服务器)地址转换函数关于inet_ntoa 简单的TCP网络程序TCP sockot API详解socket()bind()listen()accept();connect 完整的TCP服务器代码&#xff08;线程池版&#xff09; 作者&#xff1…

【算法】双指针算法

个人主页 &#xff1a; zxctscl 如有转载请先通知 题目 1. 283. 移动零1.1 分析1.2 代码 2. 1089. 复写零2.1 分析2.2 代码 3. 202. 快乐数3.1 分析3.2 代码 4. 11. 盛最多水的容器4.1 分析4.2 代码 5. LCR 179. 查找总价格为目标值的两个商品5.1 分析5.2 代码 6. 15. 三数之和…

antd vue table控件的使用(三)

今天就讲讲Ant Design Vue下的控件----table表格&#xff08;分页、编辑和删除功能&#xff09; 结合项目中的需求&#xff0c;看看如何配置,让table即可以展示列表&#xff0c;又可以直接编辑数据。需求&#xff1a; &#xff08;1&#xff09;展示即将检查的数据列表&#…

持续交付工具Argo CD的部署使用

Background CI/CD&#xff08;Continuous Integration/Continuous Deployment&#xff09;是一种软件开发流程&#xff0c;旨在通过自动化和持续集成的方式提高软件交付的效率和质量。它包括持续集成&#xff08;CI&#xff09;和持续部署&#xff08;CD&#xff09;两个主要阶…

Day5-Hive的结构和优化、数据文件存储格式

Hive 窗口函数 案例 需求&#xff1a;连续三天登陆的用户数据 步骤&#xff1a; -- 建表 create table logins (username string,log_date string ) row format delimited fields terminated by ; -- 加载数据 load data local inpath /opt/hive_data/login into table log…

怎么保证缓存与数据库的最终一致性?

目录 零.读数据的标准操作 一.Cache aside Patten--旁路模式 二.Read/Write Through Pattern--读写穿透 三.Write Back Pattern--写回 四.运用canal监听mysql的binlog实现缓存同步 零.读数据的标准操作 这里想说的是不管哪种模式读操作都是一样的&#xff0c;这是一种统一…

【开源社区】openEuler、openGauss、openHiTLS、MindSpore

【开源社区】openEuler、openGauss、openHiTLS、MindSpore 写在最前面开源社区参与和贡献的一般方式开源技术的需求和贡献方向 openEuler 社区&#xff1a;开源系统官方网站官方介绍贡献攻略开源技术需求 openGauss 社区&#xff1a;开源数据库官方网站官方介绍贡献攻略开源技术…

Unity之Unity面试题(五)

内容将会持续更新&#xff0c;有错误的地方欢迎指正&#xff0c;谢谢! Unity之Unity面试题&#xff08;五&#xff09; TechX 坚持将创新的科技带给世界&#xff01; 拥有更好的学习体验 —— 不断努力&#xff0c;不断进步&#xff0c;不断探索 TechX —— 心探索、心进取…

目标检测——RCNN系列学习(二)Faster RCNN

接着上一篇文章&#xff1a;目标检测——RCNN系列学习(一&#xff09;-CSDN博客 主要内容包含&#xff1a;Faster RCNN 废话不多说。 Faster RCNN [1506.01497] Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks (arxiv.org)https://arxiv.…

跨域问题一文解决

&#x1f4dd;个人主页&#xff1a;五敷有你 &#x1f525;系列专栏&#xff1a;Vue ⛺️稳中求进&#xff0c;晒太阳 一、为什么会出现跨域的问题&#xff1f; 是浏览器的同源策略&#xff0c;跨域也是因为浏览器这个机制引起的&#xff0c;这个机制的存在还是在于安全…

Netty 入门应用之Http服务WebSocket

Netty实现Http服务 主要的变化是在初始化器中引入了新的编解码器 一些创建的类作用和Netty HelloWorld的小demo一样我这里就不再次重复了 1、Http服务端代码 public class HttpServer {public static void main(String[] args) {// 创建Reactor// 用来管理channel 监听事件 …

力扣HOT100 - 238. 除自身以外数组的乘积

解题思路&#xff1a; 当前位置的结果就是它左部分的乘积再乘以它右部分的乘积。因此需要进行两次遍历&#xff0c;第一次遍历用于求左部分的乘积&#xff0c;第二次遍历在求右部分的乘积的同时&#xff0c;再将最后的计算结果一起求出来。 class Solution {public int[] prod…

Vue3学习01 Vue3核心语法

Vue3学习 1. Vue3新的特性 2. 创建Vue3工程2.1 基于 vue-cli 创建项目文件说明 2.2 基于 vite 创建具体操作项目文件说明 2.3 简单案例(vite) 3. Vue3核心语法3.1 OptionsAPI 与 CompositionAPIOptions API 弊端Composition API 优势 ⭐3.2 setup小案例setup返回值setup 与 Opt…

【vue/uniapp】使用 smooth-signature 实现 h5 的横屏电子签名

通过github链接进行下载&#xff0c;然后代码参考如下&#xff0c;功能包含了清空、判断签名内容是否为空、生成png/jpg图片等。 签名效果&#xff1a; 预览效果&#xff1a; 下载 smooth-signature 链接&#xff1a;https://github.com/linjc/smooth-signature 代码参考&a…

nexus搭建maven与docker镜像的私有仓库

引言 通过nexus搭建maven与docker镜像的私有仓库,实现jar包与镜像动态更新、共享、存储。 一、nexus部署 通过docker-compose部署nexus name: java services:#############################环境#############################env-nexus:restart: always## 3.58.1image: so…

ANSYS 2024 R1 HFSS部分更新介绍(附下载)

1. 优化Layout component工作流 • 支持多区域 - 支持参数化弯曲定义的刚柔结合的PCB • Phi 网格可用 • 支持Mesh Fusion •简化创建复杂装配体的过程 2. 提升求解器速度 • 分布式矩阵汇编的内存使用率改进 ‐减少分布式矩阵求解器的RAM消耗 • 分布式稀疏直接求解器&am…

物联网实战--驱动篇之(六)4G通讯(Air780E)

目录 一、4G模块简介 二、AIR780E驱动程序 三、AIR780使用注意事项 四、结合MQTT传输测试 一、4G模块简介 4G应该是我们日常生活最常见的一种互联网通讯方式了&#xff0c;每个智能手机都配置了&#xff0c;不过手机的4G跟我们物联网领域要用的4G有点区别。首先是物联网采用…