【神经网络手写数字识别-最全源码(pytorch)】

news/2024/5/11 15:00:47/文章来源:https://blog.csdn.net/qq_60498436/article/details/132130888

Torch安装的方法

在这里插入图片描述

学习方法

  • 1.边用边学,torch只是一个工具,真正用,查的过程才是学习的过程
  • 2.直接就上案例就行,先来跑,遇到什么来解决什么

Mnist分类任务:

  • 网络基本构建与训练方法,常用函数解析

  • torch.nn.functional模块

  • nn.Module模块

读取Mnist数据集

  • 会自动进行下载
# 查看自己的torch的版本
import torch
print(torch.__version__)
%matplotlib inline
# 前两步,不用管是在网上下载数据,后续的我们都是在本地的数据进行操作
from pathlib import Path
import requestsDATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"PATH.mkdir(parents=True, exist_ok=True)URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"if not (PATH / FILENAME).exists():content = requests.get(URL + FILENAME).content(PATH / FILENAME).open("wb").write(content)
import pickle
import gzipwith gzip.open((PATH / FILENAME).as_posix(), "rb") as f:((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

784是mnist数据集每个样本的像素点个数

from matplotlib import pyplot
import numpy as nppyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
print(x_train.shape)

在这里插入图片描述
全连接神经网络的结构
在这里插入图片描述在这里插入图片描述注意数据需转换成tensor才能参与后续建模训练

import torchx_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid)
)
n, c = x_train.shape
x_train, x_train.shape, y_train.min(), y_train.max()
print(x_train, y_train)
print(x_train.shape)
print(y_train.min(), y_train.max())

torch.nn.functional 很多层和函数在这里都会见到

torch.nn.functional中有很多功能,后续会常用的。那什么时候使用nn.Module,什么时候使用nn.functional呢?一般情况下,如果模型有可学习的参数,最好用nn.Module,其他情况nn.functional相对更简单一些

import torch.nn.functional as Floss_func = F.cross_entropydef model(xb):return xb.mm(weights) + bias
bs = 64
xb = x_train[0:bs]  # a mini-batch from x
yb = y_train[0:bs]
weights = torch.randn([784, 10], dtype = torch.float,  requires_grad = True) 
bs = 64
bias = torch.zeros(10, requires_grad=True)print(loss_func(model(xb), yb))

创建一个model来更简化代码

  • 必须继承nn.Module且在其构造函数中需调用nn.Module的构造函数
  • 无需写反向传播函数,nn.Module能够利用autograd自动实现反向传播
  • Module中的可学习参数可以通过named_parameters()或者parameters()返回迭代器
from torch import nnclass Mnist_NN(nn.Module):# 构造函数def __init__(self):super().__init__()self.hidden1 = nn.Linear(784, 128)self.hidden2 = nn.Linear(128, 256)self.out  = nn.Linear(256, 10)self.dropout = nn.Dropout(0.5)#前向传播自己定义,反向传播是自动进行的def forward(self, x):x = F.relu(self.hidden1(x))x = self.dropout(x)x = F.relu(self.hidden2(x))x = self.dropout(x)#x = F.relu(self.hidden3(x))x = self.out(x)return x

在这里插入图片描述

net = Mnist_NN()
print(net)

在这里插入图片描述
可以打印我们定义好名字里的权重和偏置项

for name,parameter in net.named_parameters():print(name, parameter,parameter.size())

在这里插入图片描述

使用TensorDataset和DataLoader来简化

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoadertrain_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)
def get_data(train_ds, valid_ds, bs):return (DataLoader(train_ds, batch_size=bs, shuffle=True),DataLoader(valid_ds, batch_size=bs * 2),)
  • 一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout
  • 测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout
import numpy as npdef fit(steps, model, loss_func, opt, train_dl, valid_dl):for step in range(steps):model.train()  # 训练的时候需要更新权重参数for xb, yb in train_dl:loss_batch(model, loss_func, xb, yb, opt)model.eval() # 验证的时候不需要更新权重参数with torch.no_grad():losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)print('当前step:'+str(step), '验证集损失:'+str(val_loss))

zip的用法

a = [1,2,3]
b = [4,5,6]
zipped = zip(a,b)
print(list(zipped))
a2,b2 = zip(*zip(a,b))
print(a2)
print(b2)
from torch import optim
def get_model():model = Mnist_NN()return model, optim.SGD(model.parameters(), lr=0.001)
def loss_batch(model, loss_func, xb, yb, opt=None):loss = loss_func(model(xb), yb)if opt is not None:loss.backward()opt.step()opt.zero_grad()return loss.item(), len(xb)

三行搞定!

train_dl,valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(100, model, loss_func, opt, train_dl, valid_dl)

在这里插入图片描述

correct = 0
total = 0
for xb,yb in valid_dl:outputs = model(xb)_,predicted = torch.max(outputs.data,1)total += yb.size(0)correct += (predicted == yb).sum().item()
print(f"Accuracy of the network the 10000 test imgaes {100*correct/total}")

![在这里插入图片描述](https://img-blog.csdnimg.cn/89e5e749b680426c9700aac9f93bf76a.png

后期有兴趣的小伙伴们可以比较SGD和Adam两种优化器,哪个效果更好一点

-SGD 20epoch 85%
-Adam 20epoch 85%

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_152424.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Linux命令详解 | cd命令】Linux系统中用于更改当前工作目录的命令

文章标题 简介一,参数列表二,使用介绍1. 使用cd命令切换到特定目录2. 使用cd命令与路径相关的特殊字符3. 使用cd命令切换到包含空格的目录4. 使用cd命令切换到前一个和后一个目录5. 使用cd命令切换到用户的主目录6. 使用cd命令与绝对路径和相对路径 总结…

【项目流程】前端项目的开发流程

1. 项目中涉及的所有角色及其职责 - PM 产品经理 产品经理(Product Manager,简称PM)负责明确和定义产品的愿景和战略,与客户、用户、业务部门和其他利益相关者进行沟通,收集并分析他们的需求和期望。负责制定产品的详…

TCP三次握手,四次挥手理解

1. 三次握手 *三次握手(Three-way Handshake)*其实就是指建立一个TCP连接时,需要客户端和服务器总共发送3个包。进行三次握手的主要作用就是为了确认双方的接收能力和发送能力是否正常、指定自己的初始化序列号为后面的可靠性传送做准备。实…

前端学习---vue2--选项/数据--data-computed-watch-methods-props

写在前面: vue提供了很多数据相关的。 文章目录 data 动态绑定介绍使用使用数据 computed 计算属性介绍基础使用计算属性缓存 vs 方法完整使用 watch 监听属性介绍使用 methodspropspropsData data 动态绑定 介绍 简单的说就是进行双向绑定的区域。 vue实例的数…

MPU6050

偏航角(Yaw) 横滚角(ROll) 俯仰角(Pit) 误差 mpu6050里面有一个受力的东西 受重力影响的电容 某个导体就往下一点 根据fma就可以算出当前的加速度值 加速度传感器只输出加速度 知道重力加速度和重力的角度可…

C++入门之stl六大组件--List源码深度剖析及模拟实现

文章目录 前言 一、List源码阅读 二、List常用接口模拟实现 1.定义一个list节点 2.实现一个迭代器 2.2const迭代器 3.定义一个链表,以及实现链表的常用接口 三、List和Vector 总结 前言 本文中出现的模拟实现经过本地vs测试无误,文件已上传gite…

java: 非法字符: ‘\ufeff‘

遇到这种情况是编码转换问题 解决办法: 单个文件:可以先将格式转换为utf-16,然后在转换回utf-8 多个文件:在setting-file encodings将乱码的这个文件夹里的所有Java文件都设置utf-8格式就可以了

小成本大幅度增幅CNN鲁棒性,完美的结合GLCM+CNN

本文以实验为导向,使用vgg16GLCM实现一场精彩的新冠肺炎的分类识别,并且对比不加GLCM后的效果。在这之前,我们需要弄明白一些前缀知识和概念问题: GLCM(Gray-Level Co-occurrence Matrix),中文称…

比特鹏哥-数据类型和变量【自用笔记】

这里写目录标题 1.数据类型介绍字符,整型,浮点型,布尔类型 2.signed 和unsigned3.数据类型的取值范围sizeof 展示字节大小--- 计算机中单位:字节 4.变量 常量4.1 变量创建变量(数据类型 变量名)创建变量的时…

基于react-native的简单消息确认框showModel

基于react-native的简单消息确认框showModel 效果示例图组件代码ShowModel/index.jsx使用案例device.js安装线性渐变色 效果示例图 组件代码ShowModel/index.jsx import React, {forwardRef, useImperativeHandle, useState} from react; import {View,Text,Modal,TouchableOp…

2023,哪些大厂不再值钱?

2023年,摘下口罩的第一年,虽然经济复苏没那么强劲,但对于在资本寒冬中熬了许久的互联网科技股来说,春天的步伐好像越来越近了。今年以来,主要互联网科技公司的股价基本都涨了不少,尤其美国那边,…

ROS添加发布者和订阅者机制实现

一. ROS的节点和包 ✨Node: ROS的基本单位,实现某个功能的节点。比如实现超声波传感器就是一个节点,雷达传感器就可以是一个节点 ✨Package: 多个有联系的节点组成的单位,比如你要控制无人机姿态,可能需要…

【Linux命令详解 | pwd命令】Linux系统中用于显示当前工作目录的命令

文章标题 简介一,参数列表二,使用介绍1. pwd命令的基本使用2. pwd命令中的参数3. pwd命令的工作机制4. pwd命令的实际应用 总结 简介 pwd命令是Linux中的基础命令之一,使用该命令可以快速查看当前工作目录。在掌握Linux命令时,pw…

在Raspberry Pi 4上安装Ubuntu 20.04 + ROS noetic(不带显示器)

在Raspberry Pi 4上安装Ubuntu 20.04 ROS noetic(不带显示器) 1. 所需设备 所需设备: 树莓派 4 B 型 wifi microSD 卡:最小 32GB MicroSD 转 SD 适配器 (可选)显示器,鼠标等 2. 树莓派…

CDN安全面临的问题及防御架构

CDN安全 SQL注入攻击(各开发小组针对密码和权限的管理,和云安全部门的漏洞扫描和渗透测试) Web Server的安全(运营商和云安全部门或者漏洞纰漏第三方定期发布漏洞报告修复,例如:nginx版本号和nginx resol…

Spring5.2.x 源码使用Gradle成功构建

一 前置准备 1 Spring5.2.x下载 1.1 Spring5.2.x Git下载地址 https://gitcode.net/mirrors/spring-projects/spring-framework.git 1.2 Spring5.2.x zip源码包下载,解压后倒入idea https://gitcode.net/mirrors/spring-projects/spring-framework/-/…

《数据同步-NIFI系列》Nifi配置UpdateAttribute实现字符串时间戳转日期

Nifi配置UpdateAttribute实现字符串时间戳转日期 数据处理流程如下:查询源数据库,将Avro转为Json格式,然后使用EvaluateJsonPath修改字段名,最后使用replaceText将参数组成SQL,最后PutSQL。 一、字段串时间戳导致无法插…

转运相关的征兆,大家可以来看看

转运是一种喜讯,意味着运势将逐渐好转,人生会迎来一系列积极的变化。 虽然没有确切的科学根据可以证明转运的存在, 但是在许多传统文化和民俗中,人们都相信转运的征兆是实实在在的。 虽然无法确保这些征兆会在每种情况下都适用&am…

MySQL索引3——Explain关键字和索引使用规则(SQL提示、索引失效、最左前缀法则)

目录 Explain关键字 索引性能分析 Id ——select的查询序列号 Select_type——select查询的类型 Table——表名称 Type——select的连接类型 Possible_key ——显示可能应用在这张表的索引 Key——实际用到的索引 Key_len——实际索引使用到的字节数 Ref ——索引命…

测试工程师的工作

目录 1.何为软件测试工程师? 2.软件测试工程师的职责? 3.为什么要做软件测试? 4.软件测试的前途如何? 5.工具和思维谁更重要? 6.测试和开发相差大吗? 7.成为测试工程师的必备条件 8.测试的分类有哪…