深度学习 - PyTorch基本流程 (代码)

news/2024/4/29 15:02:33/文章来源:https://blog.csdn.net/BSCHN123/article/details/137010206

直接上代码

import torch 
import matplotlib.pyplot as plt 
from torch import nn# 创建data
print("**** Create Data ****")
weight = 0.3
bias = 0.9
X = torch.arange(0,1,0.01).unsqueeze(dim = 1)
y = weight * X + bias
print(f"Number of X samples: {len(X)}")
print(f"Number of y samples: {len(y)}")
print(f"First 10 X & y sample: \n X: {X[:10]}\n y: {y[:10]}")
print("\n")# 将data拆分成training 和 testing
print("**** Splitting data ****")
train_split = int(len(X) * 0.8)
X_train = X[:train_split]
y_train = y[:train_split]
X_test = X[train_split:]
y_test = y[train_split:]
print(f"The length of X train: {len(X_train)}")
print(f"The length of y train: {len(y_train)}")
print(f"The length of X test: {len(X_test)}")
print(f"The length of y test: {len(y_test)}\n")# 显示 training 和 testing 数据
def plot_predictions(train_data = X_train,train_labels = y_train,test_data = X_test,test_labels = y_test,predictions = None):plt.figure(figsize = (10,7))plt.scatter(train_data, train_labels, c = 'b', s = 4, label = "Training data")plt.scatter(test_data, test_labels, c = 'g', label="Test data")if predictions is not None:plt.scatter(test_data, predictions, c = 'r', s = 4, label = "Predictions")plt.legend(prop = {"size": 14})
plot_predictions()# 创建线性回归
print("**** Create PyTorch linear regression model by subclassing nn.Module ****")
class LinearRegressionModel(nn.Module):def __init__(self):super().__init__()self.weight = nn.Parameter(data = torch.randn(1,requires_grad = True,dtype = torch.float))self.bias = nn.Parameter(data = torch.randn(1,requires_grad = True,dtype = torch.float))def forward(self, x):return self.weight * x + self.biastorch.manual_seed(42)
model_1 = LinearRegressionModel()
print(model_1)
print(model_1.state_dict())
print("\n")# 初始化模型并放到目标机里
print("*** Instantiate the model ***")
print(list(model_1.parameters()))
print("\\n")# 创建一个loss函数并优化
print("*** Create and Loss function and optimizer ***")
loss_fn = nn.L1Loss()
optimizer = torch.optim.SGD(params = model_1.parameters(),lr = 0.01)
print(f"loss_fn: {loss_fn}")
print(f"optimizer: {optimizer}\n")# 训练
print("*** Training Loop ***")
torch.manual_seed(42)
epochs = 300
for epoch in range(epochs):# 将模型加载到训练模型里model_1.train()# 做 Forwardy_pred = model_1(X_train)# 计算 Lossloss = loss_fn(y_pred, y_train)# 零梯度optimizer.zero_grad()# 反向传播loss.backward()# 步骤优化optimizer.step()### 做测试if epoch % 20 == 0:# 将模型放到评估模型并设置上下文model_1.eval()with torch.inference_mode():# 做 Forwardy_preds = model_1(X_test)# 计算测试 losstest_loss = loss_fn(y_preds, y_test)# 输出测试结果print(f"Epoch: {epoch} | Train loss: {loss:.3f} | Test loss: {test_loss:.3f}")# 在测试集上对训练模型做预测
print("\n")
print("*** Make predictions with the trained model on the test data. ***")
model_1.eval()
with torch.inference_mode():y_preds = model_1(X_test)
print(f"y_preds:\n {y_preds}")
## 画图
plot_predictions(predictions = y_preds) # 保存训练好的模型
print("\n")
print("*** Save the trained model ***")
from pathlib import Path 
## 创建模型的文件夹
MODEL_PATH = Path("models")
MODEL_PATH.mkdir(parents = True, exist_ok = True)
## 创建模型的位置
MODEL_NAME = "trained model"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME 
## 保存模型到刚创建好的文件夹
print(f"Saving model to {MODEL_SAVE_PATH}")
torch.save(obj = model_1.state_dict(), f = MODEL_SAVE_PATH)
## 创建模型的新类型
loaded_model = LinearRegressionModel()
loaded_model.load_state_dict(torch.load(f = MODEL_SAVE_PATH))
## 做预测,并跟之前的做预测
y_preds_new = loaded_model(X_test)
print(y_preds == y_preds_new)

结果如下

**** Create Data ****
Number of X samples: 100
Number of y samples: 100
First 10 X & y sample: X: tensor([[0.0000],[0.0100],[0.0200],[0.0300],[0.0400],[0.0500],[0.0600],[0.0700],[0.0800],[0.0900]])y: tensor([[0.9000],[0.9030],[0.9060],[0.9090],[0.9120],[0.9150],[0.9180],[0.9210],[0.9240],[0.9270]])**** Splitting data ****
The length of X train: 80
The length of y train: 80
The length of X test: 20
The length of y test: 20**** Create PyTorch linear regression model by subclassing nn.Module ****
LinearRegressionModel()
OrderedDict([('weight', tensor([0.3367])), ('bias', tensor([0.1288]))])*** Instantiate the model ***
[Parameter containing:
tensor([0.3367], requires_grad=True), Parameter containing:
tensor([0.1288], requires_grad=True)]*** Create and Loss function and optimizer ***
loss_fn: L1Loss()
optimizer: SGD (
Parameter Group 0dampening: 0differentiable: Falseforeach: Nonelr: 0.01maximize: Falsemomentum: 0nesterov: Falseweight_decay: 0
)*** Training Loop ***
Epoch: 0 | Train loss: 0.757 | Test loss: 0.725
Epoch: 20 | Train loss: 0.525 | Test loss: 0.454
Epoch: 40 | Train loss: 0.294 | Test loss: 0.183
Epoch: 60 | Train loss: 0.077 | Test loss: 0.073
Epoch: 80 | Train loss: 0.053 | Test loss: 0.116
Epoch: 100 | Train loss: 0.046 | Test loss: 0.105
Epoch: 120 | Train loss: 0.039 | Test loss: 0.089
Epoch: 140 | Train loss: 0.032 | Test loss: 0.074
Epoch: 160 | Train loss: 0.025 | Test loss: 0.058
Epoch: 180 | Train loss: 0.018 | Test loss: 0.042
Epoch: 200 | Train loss: 0.011 | Test loss: 0.026
Epoch: 220 | Train loss: 0.004 | Test loss: 0.009
Epoch: 240 | Train loss: 0.004 | Test loss: 0.006
Epoch: 260 | Train loss: 0.004 | Test loss: 0.006
Epoch: 280 | Train loss: 0.004 | Test loss: 0.006*** Make predictions wit the trained model on the test data. ***
y_preds:tensor([[1.1464],[1.1495],[1.1525],[1.1556],[1.1587],[1.1617],[1.1648],[1.1679],[1.1709],[1.1740],[1.1771],[1.1801],[1.1832],[1.1863],[1.1893],[1.1924],[1.1955],[1.1985],[1.2016],[1.2047]])*** Save the trained model ***
Saving model to models/trained model
tensor([[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True]])

第一个结果图
第二个结果图

点个赞支持一下咯~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_1027787.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Day24:私信列表、私信详情、发送私信

测试用户:用户名aaa 密码aaa 查询当前用户的会话列表;每个会话只显示一条最新的私信;支持分页显示。 首先看下表结构: conversation_id: 用from_id和to_id拼接,小的放前面去(因为两个人的对话应该在一个会…

物联网实战--入门篇之(一)物联网概述

目录 一、前言 二、知识梳理 三、项目体验 四、项目分解 一、前言 近几年很多学校开设了物联网专业,但是确却地讲,物联网属于一个领域,包含了很多的专业或者说技能树,例如计算机、电子设计、传感器、单片机、网…

誉天华为认证云计算课程如何

HCIA-Cloud Computing 5.0 课程介绍:掌握华为企业级虚拟化、桌面云部署,具备企业一线部署实施及运维能力 掌握虚拟化技术、网络基础、存储基础等内容,拥有项目实施综合能力 满足企业虚拟化方案转型需求,应对企业日益多样的业务诉求…

快速上手Spring Cloud 十四:璀璨物联网之路

快速上手Spring Cloud 一:Spring Cloud 简介 快速上手Spring Cloud 二:核心组件解析 快速上手Spring Cloud 三:API网关深入探索与实战应用 快速上手Spring Cloud 四:微服务治理与安全 快速上手Spring Cloud 五:Spring …

JAVA面试八股文之集合

JAVA集合相关 集合?说一说Java提供的常见集合?hashmap的key可以为null嘛?hashMap线程是否安全, 如果不安全, 如何解决?HashSet和TreeSet?ArrayList底层是如何实现的?ArrayList listnew ArrayList(10)中的li…

【计算机网络】第 11、12 问:流量控制和可靠传输机制有哪些?

目录 正文流量控制的基本方法停止-等待流量控制基本原理滑动窗口流量控制基本原理 可靠传输机制1. 停止-等待协议2. 后退 N 帧协议(GBN)3. 选择重传协议(SR) 正文 流量控制涉及对链路上的帧的发送速率的控制,以使接收…

云数据仓库Snowflake论文完整版解读

本文是对于Snowflake论文的一个完整版解读,对于从事大数据数据仓库开发,数据湖开发的读者来说,这是一篇必须要详细了解和阅读的内容,通过全文你会发现整个数据湖设计的起初原因以及从各个维度(架构设计、存算分离、弹性…

FPGA高端项目:解码索尼IMX327 MIPI相机转HDMI输出,提供FPGA开发板+2套工程源码+技术支持

目录 1、前言2、相关方案推荐本博主所有FPGA工程项目-->汇总目录我这里已有的 MIPI 编解码方案 3、本 MIPI CSI-RX IP 介绍4、个人 FPGA高端图像处理开发板简介5、详细设计方案设计原理框图IMX327 及其配置MIPI CSI RX图像 ISP 处理图像缓存HDMI输出工程源码架构 6、工程源码…

fastadmin学习04-一键crud

FastAdmin 默认内置一个 test 表,可根据表字段名、字段类型和字段注释通过一键 CRUD 自动生成。 create table fa_test (id int unsigned auto_increment comment ID primary key,user_id int(10) default 0 null…

吴恩达深度学习笔记:浅层神经网络(Shallow neural networks)3.1-3.5

目录 第一门课:神经网络和深度学习 (Neural Networks and Deep Learning)第三周:浅层神经网络(Shallow neural networks)3.1 神经网络概述(Neural Network Overview)3.2 神经网络的表示(Neural Network Representation…

YOLOv8全网独家改进: 红外小目标 | 注意力改进 | 多膨胀通道精炼(MDCR)模块,红外小目标暴力涨点| 2024年3月最新成果

💡💡💡本文独家改进:多膨胀通道精炼(MDCR)模块,解决目标的大小微小以及红外图像中通常具有复杂的背景的问题点,2024年3月最新成果 💡💡💡红外小目标实现暴力涨点,只有几个像素的小目标识别率大幅度提升 改进结构图如下: 收录 YOLOv8原创自研 https://blog…

Day24 代码随想录(1刷) 回溯

目录 77. 组合 216. 组合总和 III 17. 电话号码的字母组合 给定两个整数 n 和 k,返回范围 [1, n] 中所有可能的 k 个数的组合。 你可以按 任何顺序 返回答案。 示例 1: 输入:n 4, k 2 输出: [[2,4],[3,4],[2,3],[1,2],[1,3],[1…

Vue生命周期,从听说到深入理解(全面分析)

每个 Vue 组件实例在创建时都需要经历一系列的初始化步骤,比如设置好数据侦听,编译模板,挂载实例到 DOM,以及在数据改变时更新 DOM。在此过程中,它也会运行被称为生命周期钩子的函数,让开发者有机会在特定阶…

基于深度学习的OCR,如何解决图像像素差的问题?

基于深度学习的OCR技术在处理图像像素差的问题时确实面临一定的挑战。图像像素差可能导致OCR系统无法准确识别文本,从而影响其精度和可靠性。尽管已经有一些方法如SRN-Deblur、超分SR和GAN系列被尝试用于解决这个问题,但效果并不理想。然而,这…

论文阅读-Policy Optimization for Continuous Reinforcement Learning

摘要 我们研究了连续时间和空间环境下的强化学习( RL ),其目标是一个具有折扣的无限时域,其动力学由一个随机微分方程驱动。基于连续RL方法的最新进展,我们提出了占用时间(专门针对一个折现目标)的概念,并展示了如何有效地利用它…

【超图 SuperMap3D】【基础API使用示例】51、超图SuperMap3D - 绘制圆|椭圆形面标注并将视角定位过去

前言 引擎下载地址:[添加链接描述](http://support.supermap.com.cn/DownloadCenter/DownloadPage.aspx?id2524) 绘制圆形或者椭圆形效果 核心代码 entity viewer.entities.add({// 圆中心点position: { x: -1405746.5243351874, y: 4988274.8462937465, z: 370…

关于异业联盟模式做成小程序的可行性分析

随着移动互联网的快速发展,小程序作为一种轻量级应用,受到了越来越多企业和用户的青睐。而异业联盟模式则是一种有效的商业合作方式,能够实现资源共享、优势互补和共同发展。将异业联盟模式做成小程序,不仅可以提高用户体验&#…

Pytorch的hook函数

hook函数是勾子函数,用于在不改变原始模型结构的情况下,注入一些新的代码用于调试和检验模型,常见的用法有保留非叶子结点的梯度数据(Pytorch的非叶子节点的梯度数据在计算完毕之后就会被删除,访问的时候会显示为None&…

react-navigation:

我的仓库地址:https://gitee.com/ruanjianbianjing/bj-hybrid react-navigation: 学习文档:https://reactnavigation.org 安装核心包: npm install react-navigation/native 安装react-navigation/native本身依赖的相关包: react-nativ…

时序预测 | Matlab实现SSA-BP麻雀算法优化BP神经网络时间序列预测

时序预测 | Matlab实现SSA-BP麻雀算法优化BP神经网络时间序列预测 目录 时序预测 | Matlab实现SSA-BP麻雀算法优化BP神经网络时间序列预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.Matlab实现SSA-BP麻雀算法优化BP神经网络时间序列预测(完整源码和数据…