PyTorch搭建LeNet训练集详细实现

news/2024/7/27 15:53:35/文章来源:https://blog.csdn.net/2301_78820199/article/details/136511115

一、下载训练集

导包

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

 ToTensor()函数:

把图像[heigh x width x channels] 转换为 [channels x height x width]

Normalize() 数据标准化函数:

最后一行是标准化数值计算公式

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 50000张训练图片
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)

参数解释: 

root='./data':数据集下载的路径,我下载到当前目录下的data文件夹,下载完成后会自动创建 

train=True:当前为训练集

download=True:下载数据集时设置为True,下载完成后改为False

transform=transform :设置对图像进行预处理的函数

运行下载数据集结果为: 

 下载完成后生成了data文件夹

二、导入训练集 

# 导入训练集
trainloader = torch.utils.data.DataLoader(trainset, batch_size=36,shuffle=True, num_workers=0)

参数解释: 

        trainset:把刚刚下载的数据导入进来

        batch_size=36:一批数据的大小

        shuffle=True:训练集中的数据是否打乱(一般默认打乱)

        num_workers=0:载入数据的现成数,在lunix操作系统下,可以设置为别的参数,在windows操作系统系统下,默认为0.

三、下载测试集

# 10000张测试图片
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=10000,shuffle=False, num_workers=0)
test_data_iter = iter(testloader)
test_image, test_lable = test_data_iter.next()classes = ('plane', 'car', 'bird', 'cat',   # 数据集中的分类,设置为元组,不可变类'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

参数解释:

test_data_iter = iter(testloader):通过iter()函数把testloader转化成可迭代的迭代器
test_image, test_lable = test_data_iter.next():通过next()方法可以获得测试的图像和图像对应的标签值。

 四、查看导入的图片

在中间过程打印图片进行查看,后续会注释掉

def imshow(img):img = img / 2 + 0.5nping = img.numpy()plt.imshow(np.transpose(nping, (1, 2, 0)))plt.show()# print labels
print(' '.join('%5s' % classes[test_lable[j]] for j in range(4)))
# show images
imshow(torchvision.utils.make_grid(test_image))

运行结果

 图片很模糊,因为像素很低。

上面识别出来的结果都对了。

 我遇到的问题:
一开始有结果但是没有图片,我以为时matplotlib的问题,我重新安装并且更新了版本,但是我再运行后报错更多了,报错提示我 AttributeError: module 'numpy' has no attribute 'bool',我就知道是numpy的问题了,我重新安装并且更新了版本结果还是不行,我百度了一下,发现不是越新的版本越好,我重新下载了1.23.2这个版本的numpy,下载完成后运行就出来结果了。

pip install numpy==1.23.2

这个也只是中间过程,后续会注释或者删了。


五、将创建的模型实例化

创建模型请看PyTorch搭建LeNet神经网络-CSDN博客

# 将创建的模型实例化
net = LeNet()  # 实例化
loss_fuction = nn.CrossEntropyLoss()  # 定义损失函数# 通过优化器将所有可训练的参数都进行训练,lr是learningrate学习率
optimizer = optim.Adam(net.parameters(), lr=0.001)#通过for循环实现训练过程,循环几次就是将训练集迭代多少次
for epoch in range(5):running_loss = 0.0  # 用来累加在学习过程中的损失for step, data in enumerate(trainloader, start=0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad()   # 历时损失梯度清零。# forward + backward + optimizeoutputs = net(inputs)loss = loss_fuction(outputs, labels)  # 计算神经网络的预测值和真实标签之间的损失loss.backward()optimizer.step()  # step()函数实现参数更新# print statistics  打印数据的过程running_loss += loss.item()if step % 500 == 499:  # 每隔500步打印一次数据的信息with torch.no_grad():  # 上下文管理器outputs = net(test_image)predict_y = torch.max(outputs, dim=1)[1]accuracy = (predict_y == test_lable).sum().item() / test_lable.size(0)print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))running_loss = 0.0print('Finished Training')# 将模型保存到文件夹中
save_path = './Lenet.pth'
torch.save(net.state_dict(), save_path)

 详细解释:

比较重点的单独解释了,其他的在注释中。

 optimizer.zero_grad()   # 历时损失梯度清零。

 ? 为什么每计算一个batch,就要调用一次 optimizer.zero_grad()函数

=> 通过清楚历史梯度,就会对计算的历史梯度进行累加。通过这个特性,能变相的实现一个很大的batch数值的训练(因为batch数值越大,训练效果越好)

 with torch.no_grad():  # 上下文管理器

 上下文管理器: 在接下来的计算过程中,不再去计算每个节点的误差损失梯度。

如果不调用这个函数,将会在测试过程中占用更多的算力,消耗更多的资源和占用更多的内存资源,导致内存容易崩。

print函数中打印参数解释:

print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))

epoch + 1:迭代到第几轮了

step + 1:某一轮的第几步

running_loss / 500:训练过程中500步平均训练误差

accuracy:准确率

运行结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_1005650.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

大载重无人机基础技术,研发一款50KG负重六旋翼无人机技术及成本分析

六旋翼无人机是一种多旋翼无人机,具有六个旋翼,通常呈“X”形布局。它采用电动串列式结构,具有垂直起降、悬停、前飞、后飞、侧飞、俯仰、翻滚等多种飞行动作的能力。六旋翼无人机通常被用于航拍、农业植保、环境监测、地形测绘等领域。 六旋…

【JavaScript】数据类型转换 ① ( 隐式转换 和 显式转换 | 常用的 数据类型转换 | 转为 字符串类型 方法 )

文章目录 一、 JavaScript 数据类型转换1、数据类型转换2、隐式转换 和 显式转换3、常用的 数据类型转换4、转为 字符串类型 方法 一、 JavaScript 数据类型转换 1、数据类型转换 在 网页端 使用 HTML 表单 和 浏览器输入框 prompt 函数 , 接收的数据 是 字符串类型 变量 , 该…

Linux本地搭建FastDFS系统

文章目录 前言1. 本地搭建FastDFS文件系统1.1 环境安装1.2 安装libfastcommon1.3 安装FastDFS1.4 配置Tracker1.5 配置Storage1.6 测试上传下载1.7 与Nginx整合1.8 安装Nginx1.9 配置Nginx 2. 局域网测试访问FastDFS3. 安装cpolar内网穿透4. 配置公网访问地址5. 固定公网地址5.…

uniapp封装统一请求(get和post)

uniapp封装请求 request.js文件 import Vue from vue // 全局配置 import settings from ./settings.js function computedBaseUrl(url) {// console.log(url);return (url.indexOf(http) -1 ? settings.baseUrl : ) url }// 发送请求 export default (options) > {const…

1688平台官方开发平台API接口接入|发布商品|订单查询|跨境API接口

《财经十一人》获悉,阿里巴巴(BABA.N)旗下中国B2B平台1688正布局跨境业务。 举措主要有二:一是提供跨境版API接口,可将1688的货盘导入各类有流量的平台,比如各国代采网站、服务商SaaS(软件服务…

30m二级分类土地利用数据Arcgis预处理及获取

本篇以武汉市为例,主要介绍将土地利用数据转换成武汉市内各区土地利用详情的过程以及分区统计每个区内各地类面积情况,后面还有制作过程中遇到的面积制表后数据过小的解决方法以及一些相关的知识点: 示例数据下载链接:数据下载链…

NCDA设计大赛中游戏美术设计命题参赛指南

未来设计师・全国高校数字艺术设计大赛(NCDA)正在如火如荼的进行中,各个命题单位也在陆续发出自己的参赛要求。不知道大学生们准备得怎么样。相信大家一定都在全网搜索往年的获奖作品, 今天我为大家整理了一些 NCDA 前几年的获奖作…

基于Redis自增实现全局ID生成器(详解)

本博客为个人学习笔记,学习网站与详细见:黑马程序员Redis入门到实战 P48 - P49 目录 全局ID生成器介绍 基于Redis自增实现全局ID 实现代码 全局ID生成器介绍 背景介绍 当用户在抢购商品时,就会生成订单并保存到数据库的某一张表中&#…

鸿蒙视频播放的实现

文章目录 前言播放效果视频播放的实现总结 一、前言 现在市面上很多应用都跟视频有关,那么在鸿蒙系统上怎么来播放视频呢,今天就讲解视频播放控件,让你也能快速地进行视频播放功能开发。 最后呢,我会提供一个鸿蒙中涉及的主要…

系统安全保证措施-word

【系统安全保证措施-各支撑材料直接套用】 一、 身份鉴别 二、 访问控制 三、 通信完整性、保密性 四、 抗抵赖 五、 数据完整性 六、 数据保密性 七、 应用安全支撑系统设计 软件全套资料下载进主页。

数字化转型导师坚鹏:科技创新产业发展研究及科技金融营销创新

科技创新产业发展研究及科技金融营销创新 课程背景: 很多银行存在以下问题: 不清楚科技创新产业的发展现状? 不知道科技金融有哪些成功的案例? 不知道科技金融如何进行营销创新? 课程特色: 以案例…

说说JVM的垃圾回收机制

简介 垃圾回收机制英文为Garbage Collection, 所以我们常常称之为GC。那么为什么我们需要垃圾回收机制呢?如果大家有了解过Java虚拟机运行时区域的组成(JVM运行时存在,本地方法栈,虚拟机方法栈,程序计数器,堆&#xf…

Uni-app跟学笔记(四):图片API、跨端兼容、组件注册、组件通信

文章目录 1)图片API1:图片上传2:图片预览 2)跨端兼容3)导航跳转4)组件及其生命周期1:注册组件2:组件生命周期 5)组件的通信1:父组件给子组件传值1-父组件2-子…

打造你的HTML5打地鼠游戏:零基础入门教程

🌟 前言 欢迎来到我的技术小宇宙!🌌 这里不仅是我记录技术点滴的后花园,也是我分享学习心得和项目经验的乐园。📚 无论你是技术小白还是资深大牛,这里总有一些内容能触动你的好奇心。🔍 &#x…

计算机组成原理实验报告1 | 实验1.1 运算器实验(键盘方式)

本文整理自博主大学本科《计算机组成原理》课程自己完成的实验报告。 —— *实验环境为学校机房实验箱。 目录 一、实验目的 二、实验内容 三、实验步骤及实验结果 Ⅰ、单片机键盘操作方式实验 1、实验连线(键盘实验) 2、实验过程 四、实验结果的…

CountDownLatch介绍和使用

1. CountDownLatch是什么 CountDownLatch 是 Java.util.concurrent 包中的一个同步工具类,用于控制线程的执行顺序。它的主要作用是让一个或多个线程等待其他线程完成操作后再继续执行。 2. CountDownLatch 类常用方法 CountDownLatch(int count) 是 CountDownLa…

String 底层是如何实现的?

1、典型回答 String 底层是基于数组实现的,并且数组使用了 final 修饰,不同版本中的数组类型也是不同的: JDK9 之前(不含JDK9) String 类是使用 char[ ](字符数组)实现的但 JDK9 之后&#xf…

安装PyTorch详细过程

安装anaconda 登录anaconda的官网下载,anaconda是一个集成的工具软件不需要我们再次下载。anaconda官网 跳转到这个页面如果你的Python版本正好是3.8版,那便可以直接根据系统去选择自己相应的下载版本就可以了。 但是如果你的Python版本号不是当前页面…

美摄科技对抗网络数字人解决方案

在数字化浪潮的推动下,企业对于高效、创新且具备高度真实感的数字化解决方案的需求日益迫切。美摄科技凭借其在人工智能和计算机视觉领域的深厚积累,推出了一款全新的对抗网络数字人解决方案,该方案能够为企业构建出表情和动作都极为逼真的数…

Unity Shader实现UI流光效果

效果: shader Shader "UI/Unlit/Flowlight" {Properties{[PerRendererData] _MainTex("Sprite Texture", 2D) "white" {}_Color("Tint", Color) (1, 1, 1, 1)[MaterialToggle] PixelSnap("Pixel snap", float…