Pytorch深度学习——线性回归实现 04(未完)

news/2024/5/17 17:14:53/文章来源:https://blog.csdn.net/weixin_42521185/article/details/126829055

文章目录

  • 1 问题假设
  • 2 步骤
  • 3 学习使用Pytorch的API来搭建模型
    • 3.1 nn.Model
    • 3.2 优化器类
    • 3.3 评估模式和训练模式
    • 3.4 使用GPU
  • data和item的区别

1 问题假设

假设我们的基础模型就是y = wx+b,其中w和b均为参数,我们使用y = 3x+0.8来构造数据x、y,所以最后通过模型应该能够得出w和b应该分别接近3和0.8。

2 步骤

  1. 准备数据: x就是随机生成的数,y=3*x+0.8 得到的数据。(也就是说,这里使用的y真实值 是无噪声的,那么可以知道只要epoch次数够就能训练出比较真实的值)
  2. 初始化要训练的参数
  3. 带入参数计算,并且用均方差代表loss
  4. 梯度清零反向传播tensor.grad中保存的是梯度,更新参数
import torchlearning_rate = 0.01
# 1 准备数据
# y = 3x+0.8x = torch.rand([500, 1])  # 500行1列的随机数,范围是0-1
y_true = x*3 + 0.8# 初始化两个要训练的参数
w = torch.rand([1, 1], requires_grad=True)  # 1行1列
b = torch.zeros(1, requires_grad=True)
# 或者写成这样: b = torch.tensor(0, requires_grad=True, dtype=torch.float32)print(w, b)
# 4 通过循环,反向传播,更新参数for i in range(2500):# 2 通过模型计算y_pred# 3 计算lossy_predict = torch.matmul(x, w) + bloss = (y_true - y_predict).pow(2).mean()  # 均方误差if w.grad is not None:w.grad.data.zero_()  # 归零 (就地修改)if b.grad is not None:b.grad.data.zero_()loss.backward()  # 反向传播w.data = w.data - learning_rate * w.gradb.data = b.data - learning_rate * b.gradprint("w, b, loss", w.item(), b.item(), loss.item())

以上代码就是一个模型训练的过程,训练模型的本质就是训练了参数w和b

输出如下:(因为输出有2500行+,所以只截取最后的结果)
在这里插入图片描述
可以看到,最后w收敛到了2.95517635345459, b最后收敛到了0.823337733745575,非常接近w=3b=0.8

3 学习使用Pytorch的API来搭建模型

3.1 nn.Model

nn.Model 是torch.nn 提供的一个类,是Pytorch中 微我们自定义网络的一个基类,在这个类中定义了很多有用的方法,让我们在继承这个类 定义网路的时候非常简单。

  1. __init__ 需要 调用super方法,继承附列的属性和方法。
  2. forward方法必须实现,用来定义我们的网络的前向计算的过程。
  • 用y = wx+b 的模型举例如下:
import torch
from torch import nnclass Lr(nn.Module):def __init__(self):super(lr, self).__init__()  # 继承父类的init的参数self.linear = nn.Linear(1, 1)def forward(self, x):out = self.linear(x)return out       

其中:这一行代码是固定的。

 super(lr, self).__init__()  # 继承父类的init的参数
  1. nn.Linear 为 torch 预定好的线性模型,也被称为全连接层,传入的参数为输入的数量和输出的数量(in_features, out_features),是不算(batch_size)的列数的。
  2. nn.Module 定义了 __call__方法,即类Lr的实例,实现的就是调用forward方法,能够直接被传入参数调用,实际上调用的是forward方法并传入参数。

示例:

model = Lr()  # 实例化模型
pred = model(x)  # 传入数据,计算结果

3.2 优化器类

优化器(optimizer),可以理解为pytorch中封装好了的用来更新参数的方法,比如常见的随机下降(SGD)和Adam。

  • 优化器都是由torch.optim提供的:
torch.optim.SGD(参数,学习率)
torch.optim.Adam(参数,学习率)
  1. 参数可以使用model.parameters() 来获取,获取模型中所有 requires_grad=True 的参数。
  2. 优化器使用方法:
    ①实例化
    ②所有参数的梯度置零
    ③反向传播计算梯度
    ④更新参数值
optimizer = optim.SGD(model.parameters(), lr=1e-3)  # 1. 实例化
optimizer.zero_grad()  # 2. 梯度置零
loss.backward()  # 3. 计算梯度
optimizer.step()  # 4. 更新参数的值
  • 写代码——调用API来实现线性模型
import torch
import torch.nn as nn
from torch.optim import SGD# 0. 准备数据
x = torch.rand([500, 1])
y_true = 3*x + 0.8# 1. 定义模型
class MyLinear(nn.Module):def __init__(self):super(MyLinear, self).__init__()self.linear = nn.Linear(1, 1)def forward(self, x):out = self.linear(x)return out# 2. 实例化模型,优化器实例化,loss实例化
my_linear = MyLinear()
optimizer = SGD(my_linear.parameters(), 0.001)
loss_fn = nn.MSELoss()# 3. 循环,进行梯度下降,参数的更新
for i in range(5000):# 得到预测值y_predict = my_linear(x)loss = loss_fn(y_predict, y_true)# 梯度置零optimizer.zero_grad()# 反向传播loss.backward()# 参数更新optimizer.step()if i%200 == 0:print(loss.item(), list(my_linear.parameters()))

在这里插入图片描述

可以看到,用API来实现线性回归的收敛效果没有手动实现的好。(不知道为啥)

3.3 评估模式和训练模式

model.eval()  # 表示设置模型为评估模式,即预测模式
model.train(mode=True)  # 表示设置模型为训练模式

在目前的线性模型中上述没有什么区别,但是在一些训练和预测时参数不同的模型中,比如说是Dropout, BatchNorm 等存在时,就需要告诉模型是训练还是在预测。

3.4 使用GPU

  1. 判断GPU是否可用: torch.cuda.is_available()
torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print(device)
>> device(type='cuda', index=0)  # 使用GPU
>> device(type='cpu')  # 使用CPU
  1. 把模型参数和 input数据转换为cuda的支持类型 ( 要转的话要一起转,不能一个在cpu上跑 一个在GPU上跑)
model.to(device)
x_true.to(device)
  1. 在GPU上计算结果也为cuda的数据类型,需要转化为numpy或者cpu的tensor类型(就是说 要把gpu上的值转到cpu上进行一些求均值等等的操作)
predict = predict.cpu().detach().numpy()

detach() 的作用相当于 data,但是detach()是深拷贝,data是取值是浅拷贝。

  • 总结:模型放到GPU,那么输入x和输出y_true也要放在GPU,模型的参数也要GPU(如果是内部参数就不需要),最后得到的输出 y_predict 也是GPU类型的。

在GPU上执行程序:
(1)自定义的参数和数据,需要转化为cuda支持的tensor
(2)model需要转化为cuda支持的model
(3)执行的结果需要和cpu的tensor进行计算的时候:
a. tensor.cpu() 把cuda的tensor转化为CPU的tensor

data和item的区别

.data返回的是一个tensor
.item()返回的是一个具体的数值。
注意:对于元素不止一个的tensor列表,使用item()会报错

list(my_linear.parameters()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_9277.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

0.django部署(基础知识)

我们前面的代码都是在我们自己的电脑(通常是Windows操作系统)上面运行的,因为我们还处于开发过程中。 当我们完成一个阶段的开发任务后,就需要把我们开发的网站服务,给真正的用户使用了。 那就需要我们的 网站 部署在…

【二次分配问题】基于遗传算法 (GA)、粒子群优化 (PSO) 和萤火虫算法 (FA) 求解二次分配( QAP)问题(MATLAB 实现)

目录 1 概述 3 Matlab代码及文章阅读 4 运行结果 4.1 萤火虫算法 4.2 粒子群优化算法 4.3 遗传算法 5 参考文献 1 概述 目前,该问题已经得到深入的研究,进化策略(evolutionstrategies)、遗传算法(genetic algorithms)、遗传规划(geneticprogramm…

警惕利用「以太坊合并」的 3 种骗局

原文作者:茉莉 距离以太坊合并还有不到 6 小时,这条被视作下一代互联网 Web3.0 底层基础设施的区块链网络将彻底改变共识机制,从工作量证明的 PoW 机制转向权益证明的 PoS。 在合并即将到来前,去中心化安全网络市场 PolySwarm 创…

各语言转wasm-js调用

起源是 我司应该是抄袭某家player , 也用wasm做的 , 所以我也研究一下 关于标题 我估计需要大家一起完善了 , 我只会讲一下 go c 别的都不会 webassembly( wasm ) 可以编译的如图 我想起我这边应用啊 也就无非播放器~~ 本地文件压缩啊加密啊或直接就上传了, 或者在操作数据…

RestHighLevelClient创建索引时报错[299 Elasticsearch-7.12.1

RestHighLevelClient创建索引时报错[299 Elasticsearch-7.12.1出现原因 : 这是因为在使用create方法时 , 会有两个选择 , 其中一个已经过时了 client.indices().create(request, RequestOptions.DEFAULT); 其中的create方法 , 有两个版本 , 有一个显示已经过时了 , 两个方法虽然…

蜂蜜什么时候喝,才可以获得蜂蜜更大的好处?真可以治疗咳嗽?

中秋节刚过去不久,家里面的礼品多的是不是可以开超市了?中国人讲究一个“礼”字,逢年过节、探望故友病友手里不带点东西就会难受。中秋节这样带有美好祝愿的节日自然也是中国人送礼的最佳时间之一。 ​ 编辑切换为居中 添加图片注释,不超过…

Google Chrome Privacy Sandbox All In One

Google Chrome Privacy Sandbox All In OneGoogle Chrome Privacy Sandbox All In OneGoogle Chrome 隐私沙盒chrome://settings/privacySandbox With Privacy Sandbox trials, sites can deliver the same browsing experience using less of your info. That means more priv…

需要在html中加CSS,怎么加

在html中加CSS有三种方式 一种是直接写到标签上的style属性里面 <divid"mydiV"style"width:200px;border:1pxsolid#f00;margin:0;"></div> 一种是写到head标签里面的style标签里面 <styletype"text/css"> #mydiV{ width:2…

C++ 01 内存模型

内存分区的示意图。一般内存主要分为&#xff1a;代码区、常量区、静态区&#xff08;全局区&#xff09;、堆区、栈区这几个区域。 什么是代码区、常量区、静态区&#xff08;全局区&#xff09;、堆区、栈区&#xff1f; 代码区&#xff1a;存放程序的代码&#xff0c;即CPU执…

springboot 整合dubbo3开发rest应用

一、前言 作为微服务治理生态体系内的重要框架 dubbo&#xff0c;从出身到现在历经了十多年的市场检验而依旧火热&#xff0c;除了其自身优秀的设计&#xff0c;高性能的RPC性能&#xff0c;以及依托于springcloud-alibaba的这个背后强劲的开源团队支撑&#xff0c;在众多的微…

MongoDB6安装配置详解

官网下载地址&#xff1a; https://www.mongodb.com/try/download/community?tckdocs_server 打开后是这样的&#xff1a; 鼠标滑到上图红色箭头位置&#xff0c;可以看到最新版本目前是6.0.1&#xff0c;点击download下载即可&#xff0c;这里下载的是Windows版本。 下载好后…

vue插槽---作用域插槽(三)

编译作用域:模板中的变量,在模板对应的实例中查找相应的变量和数据。通俗的说就是父级模板里的所有内容都是在父级作用域中编译的;子模板里的所有内容都是在子作用域中编译的。 作用域插槽:带参数的插槽,子组件提供给父组件参数,父组件决定其展示形式替换插槽标签。 为什…

哈希原理及模拟实现并封装unordered系列关联式容器

目录一、哈希1. 哈希概念2. 哈希冲突3. 哈希函数4. 哈希冲突的解决闭散列线性探测二次探测开散列开散列与闭散列比较二、哈希表哈希表的实现三、封装unordered系列关联式容器1. 封装unordered_set2. 封装unordered_map四、哈希表的应用1. 位图概念2. 应用3. 位图的实现2. 布隆过…

springboot客户关系管理系统源码 CRM小程序源码

CRM客户关系管理系统源码 crm小程序源码 基于springbootvue MySQL数据库开发的客户关系管理系统。 客户全流程高效管理&#xff0c;客户资料管理&#xff0c;客户跟踪管理&#xff0c;订单、合同管理&#xff0c;回款及交付管理等功能。 功能介绍 1、系统管理&#xff1a;员工…

基于STM32单片机和AD9850的智能DDS函数信号发生器

CSDN话题挑战赛第2期 参赛话题&#xff1a;学习笔记 文章目录1、整体设计2、硬件方案3、软件程序4、实物验证1、整体设计 有一天&#xff0c;我在浏览CSDN时看到一篇关于 AD9850 的帖子。AD9850是一款可以产生1hz到40mhz左右正弦波的芯片。淘宝的产品经销商能够将芯片与提供 T…

第二章-使用KNN和GBDT进行收入的预测分析

本文是《从零开始学python数据分析与挖掘》的第二章学习心得&#xff0c;相关数据可以从对应的官方数据库获取。 提供给你的只有一份收入相关的xlsx&#xff0c;你需要通过里面的数据进行年收入的预测。 1.数据预处理 首先读取数据&#xff0c;查看是否存在缺失值。对于存在…

关于模糊理论及简单应用

关于模糊理论及简单应用 1.开始 最近导师让我了解一下模糊理论,思考能不能结合现有技术实现创新点.这篇博客主要记录一下这两天对模糊理论的学习,以及做的一个小demo,希望如果有研究相关方面的大佬能留言相互交流学习. 之前用模糊c均值聚类的时候了解过scikit-fuzzy,这次发现…

(14.1)Zotero常用功能:导入题录、参考文献

(14.1)Zotero常用功能&#xff1a;导入题录、参考文献 文章目录一、插件1.1、Zotfile1.2、Zotfile配置2、translators_CN3、zotero-pdf-translate4、jasminum5、zotero-better-bibtex-Sponsor&#xff08;待更新&#xff09;二、导入题录(知网为例)三、参考文献样式1、样式选择…

隐写术——PNG文件隐藏payload

0x01 PNG文件格式 PNG文件基本上由两部分组成: 文件头、文件数据块 文件头也叫署名域 用来标识这是一个PNG格式的文件&#xff0c;8字节长度&#xff0c;固定数据:89 50 4E 47 0D 0A 1A 0A 数据块: PNG定义了两种类型的数据块: 1:关键数据块(Critical Chunk):PNG文件必须包含&…

网课题库接口API—小白专用版本

网课题库接口API—小白专用版本 本平台优点&#xff1a;免费查题接口搭建 多题库查题、独立后台、响应速度快、全网平台可查、功能最全&#xff01; 1.想要给自己的公众号获得查题接口&#xff0c;只需要两步&#xff01; 2.题库&#xff1a;题库后台http://daili.jueguangzh…