深度学习基础知识-tf.keras实例: 加州房价预测

news/2024/5/20 2:06:31/文章来源:https://blog.csdn.net/pxy7896/article/details/131061194

参考书籍:《Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow, 2nd Edition (Aurelien Geron [Géron, Aurélien])》

代码有修改,已测通。


简单顺序结构

这次得数据集比之前得简单,只包含数字型特征,没有ocean_proximity,也没有缺失值。

如果 sklearn.datasets.fetch_california_housing 报错 urllib.error.HTTPError: HTTP Error 403: Forbidden,那么下载文件cal_housing_py3.pkz放到 sklearn.datasets.get_data_home()下,这里是 C:\Users\用户名\scikit_learn_data
参考:https://blog.csdn.net/qq_44644355/article/details/107054585

from sklearn.datasets import fetch_california_housing, get_data_home
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScalerdef load_housing_data():housing = fetch_california_housing()# 默认划分是3:1,aka 75%train, 25%testX_train_full, X_test, y_train_full, y_test = train_test_split(housing.data, housing.target)X_train, X_valid, y_train, y_valid = train_test_split(X_train_full, y_train_full)scaler = StandardScaler()# fit_transform会计算数据的均值和方差,transform不会# 而且前者一般只在训练集进行,后者则是在训练集和测试集都可以用X_train = scaler.fit_transform(X_train)X_valid = scaler.transform(X_valid)X_test = scaler.transform(X_test)return X_train, X_valid, X_test, y_train, y_valid, y_test#print(get_data_home())X_train, X_valid, X_test, y_train, y_valid, y_test = load_housing_data()model = keras.models.Sequential([keras.layers.Dense(30, activation="relu", input_shape=X_train.shape[1:]),keras.layers.Dense(1)
])
model.compile(loss="mean_squared_error", optimizer="sgd")
history = model.fit(X_train, y_train, epochs=30, validation_data=(X_valid, y_valid))
mse_test = model.evaluate(X_test, y_test)
print(mse_test)
X_new = X_test[:3]
y_pred = model.predict(X_new)
print(y_pred)

复杂结构

单输入

在这里插入图片描述

input_ = keras.layers.Input(shape=X_train.shape[1:])
hidden1 = keras.layers.Dense(30, activation="relu")(input_)
hidden2 = keras.layers.Dense(30, activation="relu")(hidden1)
concat = keras.layers.concatenate(inputs=[input_, hidden2])
output = keras.layers.Dense(1)(concat)
model = keras.Model(inputs=[input_], outputs=[output])
#model.compile(loss="mse", optimizer=keras.optimizers.SGD(lr=1e-3))

多输入

在这里插入图片描述

input_A = keras.layers.Input(shape=[5], name="wide_input")
input_B = keras.layers.Input(shape=[6], name="deep_input")
hidden1 = keras.layers.Dense(30, activation="relu")(input_B)
hidden2 = keras.layers.Dense(30, activation="relu")(hidden1)
concat = keras.layers.concatenate([input_A, hidden2])
output = keras.layers.Dense(1, name="output")(concat)
model = keras.Model(inputs=[input_A, input_B], outputs=[output])
model.compile(loss="mse", optimizer=keras.optimizers.SGD(lr=1e-3))# 划分数据。0-4是输入wide_input的,2-最后是输入deep_input的
X_train_A, X_train_B = X_train[:, :5], X_train[:, 2:]
X_valid_A, X_valid_B = X_valid[:, :5], X_valid[:, 2:]
X_test_A, X_test_B = X_test[:, :5], X_test[:, 2:]
X_new_A, X_new_B = X_test_A[:3], X_test_B[:3]
# 训练
history = model.fit((X_train_A, X_train_B), y_train, epochs=30,validation_data=((X_valid_A, X_valid_B), y_valid))
mse_test = model.evaluate((X_test_A, X_test_B), y_test)
print(mse_test)
y_pred = model.predict((X_new_A, X_new_B))
print(y_pred)

如果需要增加一个输出,如下图所示,可以这样改:
在这里插入图片描述

# 其他保持不变
aux_output = keras.layers.Dense(1, name="aux_output")(hidden2)
model = keras.Model(inputs=[input_A, input_B], outputs=[output, aux_output])
model.compile(loss=["mse","mse"], loss_weights=[0.9, 0.1], optimizer=keras.optimizers.SGD(lr=1e-3))
# 假设aux_output预测的也是同样的东西
history = model.fit([X_train_A, X_train_B], [y_train, y_train], epochs=30,validation_data=([X_valid_A, X_valid_B], [y_valid, y_valid]))
# 此时有总loss和每个输出的loss
total_loss, main_loss, aux_loss = model.evaluate([X_test_A, X_test_B], [y_test, y_test])
print((total_loss, main_loss, aux_loss))
# 预测结果也会有多个
y_pred_main, y_pred_aux = model.predict([X_new_A, X_new_B])
print((y_pred_main, y_pred_aux))

如果需要动态调整网络,比如在某些情况下需要进入循环或者分支,那可以写一个新的类,如下面所示。这样的好处是网络组织更加灵活,而且summary()只能打印层的列表,不能传递层之间的连接方式;缺点是不能clone或者保存(不能用hdf5格式,只能save_weights和load_weights勉强保存一下),有时候也可能出错。

class WideAndDeepModel(keras.Model):def __init__(self, units=30, activation="relu", **kwargs):super().__init__(**kwargs) # handles standard args (e.g., name)self.hidden1 = keras.layers.Dense(units, activation=activation)self.hidden2 = keras.layers.Dense(units, activation=activation)self.main_output = keras.layers.Dense(1)self.aux_output = keras.layers.Dense(1)# 这里可以写loop if什么的def call(self, inputs):input_A, input_B = inputshidden1 = self.hidden1(input_B)hidden2 = self.hidden2(hidden1)concat = keras.layers.concatenate([input_A, hidden2])main_output = self.main_output(concat)aux_output = self.aux_output(hidden2)return main_output, aux_outputmodel = WideAndDeepModel()

保存模型可以用pickle或者joblib的dump,也可以直接:

model.save("xxx.h5")

这里使用HDF5格式保存架构、超参数和每层的参数,也会保存optimizer。等再载入到时候可以用:

model = keras.models.load_model("xxx.h5")

有时候需要早点停止,可以加Callbacks。比如ModelCheckpoint就保存了某些时间点模型的checkpoints。默认是每个epoch结束时。

# build and compile the model
checkpoint_cb = keras.callbacks.ModelCheckpoint("my_keras_model.h5", save_best_only=True)
history = model.fit(X_train, y_train, epochs=10,alidation_data=(X_valid, y_valid), callbacks=[checkpoint_cb])
model = keras.models.load_model("my_keras_model.h5") # roll back to best model

此时model只保存了最好的模型。

另外,也可以使用EarlyStopping,指如果在验证集上,若干个(patience参数)epoch没有进步了,就停止训练。也可以同时使用checkpoints和earlystopping。

early_stopping_cb = keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)history = model.fit(X_train, y_train, epochs=100, validation_data=(X_valid, y_valid),callbacks=[checkpoint_cb, early_stopping_cb])

也可以写自定义的callbacks。可以选的时间点有:on_train_begin(), on_train_end(), on_epoch_begin(), on_epoch_end(), on_batch_begin(), and on_batch_end()。还可以在test和predict种插入callbacks,前者是evaluate()调用的,后者是predict()调用的。

class PrintValTrainRatioCallback(keras.callbacks.Callback):def on_epoch_end(self, epoch, logs):print("\nval/train: {:.2f}".format(logs["val_loss"] / logs["loss"]))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_312475.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

设备指纹系列--前端篇

基础篇请看:设备指纹系列–基础篇 我们接着前文继续写关于设备指纹前端接入方面的内容。话不多说,直接步入正题。 我们会在下文展示5种前端接入的方式,包括web接入、安卓接入、ios接入、微信小程序接入以及支付宝小程序接入。 Web接入 第…

VSCode+GDB+Qemu调试ARM64 linux内核

俗话说,工欲善其事 必先利其器。linux kernel是一个非常复杂的系统,初学者会很难入门。 如果有一个方便的调试环境,学习效率至少能有5-10倍的提升。 为了学习linux内核,通常有这两个需要 可以摆脱硬件,方便的编译和…

Java基础学习+面向对象

一,基础概念介绍 1.1Java跨平台原理(一次编译,处处运行) Java 源代码经过编译,生成字节码文件,交由 Java 虚拟机来执行,不同得系统有不同得JVM,借助JVM 实现跨平台。就比如说我们在 Windows 下…

【Linux】13. 文件操作

1. 重新认识文件 经过之前的linux命令操作、进程相关概念的学习,我们对于文件也并不陌生 首先需要明确以下概念: 即使是空文件,也要在磁盘当中占据空间文件 文件内容 文件属性文件操作 对文件内容的操作 或者 对文件属性的操作 或者 二者…

ChatGPT训练一次要耗多少电?

如果开个玩笑:问ChatGPT最大的贡献是什么? “我觉得它对全球变暖是有一定贡献的。”知名自然语言处理专家、计算机科学家吴军在4月接受某媒体采访时如是说。 随着ChatGPT引爆AIGC,国内外巨头纷纷推出自己的AI大模型,大家为人工智…

跨境电商独立站搭建-跨境电商源码网站开发部署,独立站技术

跨境电商独立站是指在国际互联网上建立并拥有自己独立的电商网站,在该网站上进行跨境电商业务,包括产品展示、交易处理、支付结算、物流配送等全流程。相较于在第三方平台上开店,跨境电商独立站具有更高的自主权和品牌形象,能够更…

Redis 高级数据结构 HyperLogLog

介绍 HyperLogLog(Hyper[ˈhaɪpə(r)])并不是一种新的数据结构(实际类型为字符串类型),而是一种基数算法,通过HyperLogLog可以 利用极小的内存空间完成独立总数的统计,数据集可以是IP、Email、ID等。如果你负责开发维护一个大型的网站,有一天…

Java实现Mqtt收发消息

Java实现Mqtt收发消息 文章目录 Java实现Mqtt收发消息windows mqtt 平台服务搭建mqtt 客户端工具:mqttbox整体代码结构mqtt基础参数配置类mqtt客户端连接mqtt接收的消息处理类对应的MqttService注解和MqttTopic注解 MqttGateway 发送消息指定topic接收处理方法 java…

Servlet Cookie基本概念和使用方法

目录 Cookie 介绍 Cookie 主要有两种类型:会话 Cookie 和持久 Cookie。 Cookie使用步骤 使用Servlet和Cookie实现客户端存储的登录功能示例: LoginServlet类 index.jsp 删除Cookie 浏览器中查看Cookie的方法 Cookie 介绍 Cookie 是一种在网站和…

测试老鸟总结,自动化测试难点挑战应对方法,我的进阶之路...

目录:导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜) 前言 Python自动化测试&…

不定积分练习

不定积分练习 在看视频的时候遇到了一道比较有趣的题,在这里给大家分享一下。 题目 计算 ∫ ( 1 x − 1 x ) e x 1 x d x \int(1x-\dfrac 1x)e^{x\frac 1x}dx ∫(1x−x1​)exx1​dx 解: \qquad 原式 ∫ e x 1 x d x ∫ x ( 1 − 1 x 2 ) e x 1…

Promise-用法

目录 1.处理异步的几种方案 2.理解 3.promise状态:初始化 4.执行异步任务 5.执行异步任务成功 6.执行异步任务失败 7.执行异步任务成功-返回 8.执行异步任务失败-返回 1.处理异步的几种方案 纯粹callback,会剥夺函数return的能力promise&#xf…

【Java基础学习打卡01】计算机概述

目录 引言一、计算机是什么?1.计算机vs计算器2.计算机定义 二、计算机发展简史三、计算机分类四、计算机基本工作原理1.冯诺依曼2.冯诺依曼原理 总结 引言 其实我们在学习Java编程之前应该要对计算机有所了解,这里的了解不是说我们日常接触电脑就算是了…

postgres篇---docker安装postgres,python连接postgres数据库

postgres篇---docker安装postgres,python连接postgres数据库 一、docker安装postgres1.1 安装Docker:1.2 从Docker Hub获取PostgreSQL镜像1.3 创建PostgreSQL容器1.4 访问PostgreSQL 二. python连接postgres数据库2.1 connect连接2.2 cursor2.3 excute执…

docker-compose 搭建 zipkin 服务端

目录 基于docker-compose搭建服务端 数据库 服务器 docker-compose.yaml 问题 测试 基于docker-compose搭建服务端 数据库 我这边存储选择了Mysql存储,新建了 zipkin库,数据库脚本如下 -- -- Copyright 2015-2019 The OpenZipkin Authors -- -- Li…

Springboot3 + SpringSecurity + JWT + OpenApi3 实现认证授权

Springboot3 SpringSecurity JWT OpenApi3 实现双token 目前全网最新的 Spring Security JWT 实现双 Token 的案例!收藏就对了,欢迎各位看友学习参考。此项目由作者个人创作,可以供大家学习和项目实战使用,创作不易&#xff…

linux 部署mysql

本文介绍下Centos7中mysql的安装(Centos7以下版本中有些命令和centos7中有些不同,安时需注意下自己的linux版本) 事先准备 1、查看系统中是否自带安装mysql yum list installed | grep mysql ![在这里插入图片描述](https://img-blog.csdnimg.cn/e322b2f4036c4d9…

电力能耗监测系统是如何运作的

电力能耗监测系统数据的采集主要通过多功能仪表、通讯管理机、通讯协议实现能耗数据的采集,能耗数据上传后经大数据计算实现能耗数据的展示,满足用户对能耗监测的需求。下面对电力能耗监测系统的是怎么采集数据的展开介绍: 1.多功能仪表 对高…

中国的互联网技术有多厉害?

1 很多人没有意识到,中国的互联网技术是相当厉害的。 给大家举几个例子。 我和朋友聊天的时候,手机上的app都在“侧耳倾听”,聊天的一些关键字很快就会出现在手机浏览器的搜索栏中。 携程会给我自动推荐景点,美团会给我推荐美食&…

【FLoyd算法(Floyd到底做什么用 进来看看)】牛的旅行【进来学习做一道题的步骤-1.读题根据样例 概括出题意-2.分析算法-3.优化算法】

专注 效率 记忆 预习 笔记 复习 做题 欢迎观看我的博客,如有问题交流,欢迎评论区留言,一定尽快回复!(大家可以去看我的专栏,是所有文章的目录)   文章字体风格: 红色文字表示&#…