深度学习之“制作自定义数据”--torch.utils.data.DataLoader重写构造方法。

news/2024/5/9 4:59:10/文章来源:https://blog.csdn.net/shine_Lee_/article/details/129189105

深度学习之“制作自定义数据”–torch.utils.data.DataLoader重写构造方法。

前言:

​ 本文讲述重写torch.utils.data.DataLoader类的构造方法对自定义图片制作类似MNIST数据集格式(image, label),用于自己的Pytorch神经网络模型运行,代码已整理打包上传网盘,文末下载。tensor数据格式(N,C,H,W)

  • N:Batch,批处理大小,表示一个batch中的图像数量

  • C:Channel,通道数,表示一张图像中的通道数

  • H:Height,高度,表示图像垂直维度的像素数

  • W:Width,宽度,表示图像水平维度的像素数

  • 例如下图输出一个批次的训练集数据就是一批次64张图片(N),3维通道数(C),一张图片高度32像素(H),一张图片宽度32像素(W)

在这里插入图片描述

步骤一

​ 对图片整理分类(python代码os库进行对文件夹创建和图片的移动到文件夹),以文件夹名为图片的种类名,如下图所示:

在这里插入图片描述

步骤二

​ 对所有种类文件夹进行遍历读入,将每个(图片的文件路径 )和(对应的标签)写入到txt文本中,结果为trian.txt 和 test.txt,作为训练集合测试集的数据准备。代码为CreateDataset01.py

# -*- coding: utf-8 -*-
# @Time : 2023/1/26/026 18:48
# @Author : LeeSheel
# @File : CreateDataset01.py
# @Project : 深度学习'''
生成训练集和测试集,保存在txt文件中本地电脑,只选取出3000张图片为训练集进行模型运行数据
'''import os
import random
train_ratio = 0.6
test_ratio = 1-train_ratio
train_list, test_list = [],[]  #创建两个个列表,里面存放  图片路径+‘\t’+图片标签
data_list = []rootdata = r"D:\FreeDesk\大创项目\手写藏文字母识别\手写藏文字母数据\总数据"for root,dirs,files in os.walk(rootdata):# print(root)# print(dirs)# print(files)#拼接每个图片的绝对文件路径:for i in range(int(len(files)*train_ratio)):# print(files[i])#输出的是每个图片的名称# print(root+"---"+files[i])  #shu输出每个每个图片的文件夹路径----图片名称# print(os.path.join(root, files[i]))  #拼接路径,# print(str(root).split("/")[-1])   #dui对root进行字符串切割,获得最后一个元素,代表每个图片的标签。class_flag = str(root).split("\\")[-1]  #biaoqain标签data = os.path.join(root, files[i]) + '\t' + str(class_flag) + '\n'train_list.append(data)for i in range(int(len(files) * train_ratio),len(files)):# print(i)class_flag = str(root).split("\\")[-1]  # biaoqain标签# print(class_flag)# print(files[i])data = os.path.join(root, files[i]) + '\t' + str(class_flag) + '\n'test_list.append(data)# print(train_list)
random.shuffle(train_list)
random.shuffle(test_list)with open('train.txt','w',encoding='UTF-8') as f:for train_img in train_list:f.write(str(train_img))with open('test.txt','w',encoding='UTF-8') as f:for test_img in test_list:f.write(test_img)## 随机抽取3000个作为本地train.txt   以及1000个作为本地test.txt# from random import sample
#
# print(sample(train_list, 30000)) # 随机抽取5个元素
# local_train_list = sample(train_list, 30000)
# print("dsdfsdfs")
# print(len(local_train_list))
# local_test_list = sample(test_list, 10000)
#
# with open('localtrain.txt','w',encoding='UTF-8') as f:
#     for train_img in local_train_list:
#             f.write(str(train_img))
#
# with open('localtest.txt','w',encoding='UTF-8') as f:
#     for test_img in local_test_list:
#         f.write(test_img)

得到txt结果:(文件路径与标签以空格隔开):

在这里插入图片描述

步骤三

​ 将步骤二得到的train.txt 和 test.txt 转化为train_loader 和 test_loader,重写LoadData类的构造方法,将train.txt文本转为train_dataset ,将test.txt转为test_dataset,最后再使用torch.utils.data.DataLoader()进行转为train_loader 和 test_loader: 就可以用于调用模型训练了。

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=64,shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=64,shuffle=True)

重写LoadData类的构造方法代码(这里的transforms.Normalize()图像标准化,可以使用下文的python代码求出mean和std,填入标准化数值。),步骤三代码为 CreateDataloader02.py

# -*- coding: utf-8 -*-
# @Time : 2023/1/26/026 18:56
# @Author : LeeSheel
# @File : CreateDataloader02.py
# @Project : 深度学习
import torch
from PIL import Image
import torchvision.transforms as transforms
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from torch.utils.data import Datasetclass LoadData(Dataset):def __init__(self, txt_path, train_flag=True):self.imgs_info = self.get_images(txt_path)self.train_flag = train_flagself.train_tf = transforms.Compose([# 随机旋转图片transforms.RandomHorizontalFlip(),# 将图片尺寸resize到32x32transforms.Resize((32, 32)),# 将图片转化为Tensor格式transforms.ToTensor(),# 正则化(当模型出现过拟合的情况时,用来降低模型的复杂度)transforms.Normalize((0.96934927, 0.9696228, 0.9695143), (0.124204025, 0.12326231, 0.12356147))  # 图像标准化])self.val_tf = transforms.Compose([# 将图片尺寸resize到32x32transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.96934927, 0.9696228, 0.9695143), (0.124204025, 0.12326231, 0.12356147))])def get_images(self, txt_path):with open(txt_path, 'r', encoding='utf-8') as f:imgs_info = f.readlines()imgs_info = list(map(lambda x:x.strip().split('\t'), imgs_info))return imgs_infodef __getitem__(self, index):img_path, label = self.imgs_info[index]img = Image.open(img_path)img = img.convert('RGB')if self.train_flag:img = self.train_tf(img)else:img = self.val_tf(img)label = int(label)return img, labeldef __len__(self):return len(self.imgs_info)train_dataset = LoadData("train.txt", True)print("训练接数据个数:", len(train_dataset))
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=64,shuffle=True)
for image, label in train_loader:print(image.shape)print(image)# img = transform_BZ(image)# print(img)print(label)breaktest_dataset = LoadData("test.txt", False)
print("测试集数据个数:", len(test_dataset))
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=64,shuffle=True)

求图片标准化transforms.Normalize()参数 代码

# -*- coding: utf-8 -*-
# @Time : 2023/1/31/031 18:18
# @Author : LeeSheel
# @File : 计算std和mea.py
# @Project : 深度学习
import numpy as np
import cv2
import os# img_h, img_w = 32, 32
img_h, img_w = 32, 32  # 经过处理后你的图片的尺寸大小
means, stdevs = [], []
img_list = []imgs_path = "D:\\0"  # 数据集的路径采用绝对引用
imgs_path_list = os.listdir(imgs_path)len_ = len(imgs_path_list)
i = 0
for item in imgs_path_list:img = cv2.imread(os.path.join(imgs_path, item))img = cv2.resize(img, (img_w, img_h))img = img[:, :, :, np.newaxis]img_list.append(img)i += 1print(i, '/', len_)imgs = np.concatenate(img_list, axis=3)
imgs = imgs.astype(np.float32) / 255.for i in range(3):pixels = imgs[:, :, i, :].ravel()  # 拉成一行means.append(np.mean(pixels))stdevs.append(np.std(pixels))# BGR --> RGB , CV读取的需要转换,PIL读取的不用转换
means.reverse()
stdevs.reverse()print("normMean = {}".format(means))
print("normStd = {}".format(stdevs))

代码下载:

链接:https://pan.baidu.com/s/1fa_gdLYXagu65P2uYpepqA?pwd=xx78
提取码:xx78

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_73194.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

推荐系统从入门到入门(3)——基于MapReuduce与Spark的分布式推荐系统构建

本系列博客总结了不同框架、不同算法、不同界面的推荐系统,完整阅读需要大量时间(又臭又长),建议根据目录选择需要的内容查看,欢迎讨论与指出问题。 目录 系列文章梗概 系列文章目录 三、MapReduce 1.MapReduce详…

【视频】海康摄像头、NVR网络协议简介

1、软硬件整体架构 2、涉及的网络协议 3、协议简介 3.1 海康私有协议 设备发现SADP:进行设备的发现、激活、修改网络参数、忘记密码等; SDK:4200、系统平台的接入前端设备,协议不对外开放,但对外提供接口库; ISAPI:Intelligent Security API(智能安全API),基于HTTP传输…

2023新的一年软件测试还值得学习吗?

最近因为疫情等各种原因,大厂裁员,失业等等频频受到关注。不解释,确实存在,各行各业都很难,但是,说软件测试行业不吃香,我还真不认同(不是为培训机构说好话,大环境不好&a…

Odoo丨Odoo框架源码研读三:异常处理与定制化开发

Odoo丨Odoo框架源码研读三:异常处理与定制化开发 Odoo源码研读的第三期内容:异常处理与定制化开发。 *异常处理* Odoo中的Exception是对Python内置异常做了继承和封装,设定了自己核心的几个Exception。 而对异常的处理和Python内置异常的…

Spring 之bean的生命周期

文章目录IOCBean的生命周期运行结果实例演示实体类实例化前后置代码初始化的前后置代码application.xml总结今天我们来聊一下Spring Bean的生命周期,这是一个非常重要的问题,Spring Bean的生命周期也是比较复杂的。IOC IOC,控制反转概念需要…

Flutter+【三棵树】

定义 在Flutter中和Widgets一起协同工作的还有另外两个伙伴:Elements和RenderObjects;由于它们都是有着树形结构,所以经常会称它们为三棵树。 这三棵树分别是:Widget、Element、RenderObject Widget树:寄存烘托内容…

SigmaPlot科学绘图工具:ROC曲线分析及AUC组间差异的显著性分析

目的 初步使用SigmaPlot科学绘图工具;进行ROC曲线绘制并分析检验变量AUC组间差异性是否显著 软件下载及安装 SigmaPlot下载安装按照这个教程即可:https://www.hhkxxw.com/24799.html 快速通道:SigmaPlot下载链接:百度网盘链接…

DC220V冲击继电器RCJ-3

系列型号 RCJ-2型冲击继电器; RCJ-2/48VDC冲击继电器 RCJ-2/110VDC冲击继电器 RCJ-2/220VDC冲击继电器 RCJ-2/100VAC冲击继电器 RCJ-2/127VAC冲击继电器 RCJ-2/220VAC冲击继电器 RCJ-3/220VAC冲击继电器 RCJ-3型冲击继电器 RCJ-3/127VAC冲击继电器 RCJ-3/100VAC冲…

FastCGI sent in stderr: "PHP message: PHP Fatal error

服务器php7.2卸载安装7.4之后,打开网站一直无法访问,查看nginx错误日志发现一直报这个错误:2023/02/23 11:12:55 [error] 4735#0: *21 FastCGI sent in stderr: "PHP message: PHP Fatal error: Uncaught ReflectionException: Class translator does not exist in …

Python四大主题之一【 Web】 编程框架

目前Python的网络编程框架已经多达几十个,逐个学习它们显然不现实。但这些框架在系统架构和运行环境中有很多共通之处,本文带领读者学习基于Python网络框架开发的常用知识,及目前的4种主流Python网络框架:Django、Tornado、Flask、Twisted。 …

100%BIM学员的疑惑:不会CAD可以学Revit吗?

在新一轮科技创新和产业变革中,信息化与建筑业的融合发展已成为建筑业发展的方向,将对建筑业发展带来战略性和全局性的影响。 建筑业是传统产业,推动建筑业科技创新,加快推进信息化发展,激发创新活力,培育…

web客户端-websocket

1、websocket简介 WebSocket是HTML5开始提供的一种在单个TCP连接上进行全双工通讯的协议。 WebSocket使得客户端和服务器之间的数据交换变得更加简单,允许服务端主动向客户端推送数据。在WebSocket API中,浏览器和服务器只需要完成一次握手&#xff0c…

python3.11.2安装 + pycharm安装

下载 :https://www.python.org/ 2.双击下载的软件: 3.进入安装界面 下一步,点击 是 上一步点击后就看到如下: 安装成功了,接下来检测一下:cmd 安装pycharm PyCharm是一种Python IDE(Integr…

Apifox-比postman更优秀的接口自动化测试平台

一、Apifox介绍 Apifox 是 API 文档、API 调试、API Mock、API 自动化测试一体化协作平台,定位 Postman Swagger Mock JMeter。通过一套系统、一份数据,解决多个系统之间的数据同步问题。只要定义好 API 文档,API 调试、API 数据 Mock、A…

你真的需要文档管理软件吗?

什么是文档管理软件? 文档管理软件 (DMS) 是一种数字解决方案,可帮助组织处理、捕获、存储、管理和跟踪文档。 通过严格管理您的关键业务信息,您可以开发以稳定、可预测、可衡量的方式启动、执行和完成的流程。 如果没有功能齐全的文档管理软…

从事Python自动化测试,30岁熬到月薪20K+,分享我的多年面试经…

年少不懂面试经,读懂已是测试人。 大家好,我是小码哥,一名历经沧桑,看透互联网行业百态的测试从业者,经过数年的勤学苦练,精钻深研究,终于从初出茅庐的职场新手成长为现在的测试老鸟&#xff0…

zabbix4.0安装部署

目录 1.1、添加 Zabbix 软件仓库 1.2、安装 Server/proxy/前端 1.3、创建数据库 1.4、导入数据 1.5、为 Zabbix server/proxy 配置数据库 1.6、 启动 Zabbix server 进程 1.7、zabbix前端配置 SELinux 配置 1.8、安装 Agent 1.9、启动zabbix 2.0、访问zabbix 1.1、添加…

ThinkPHP ^6图片操作进阶

图片裁剪、缩略、水印不再是TP框架系统内置的功能,需要安装。 目录 安装 图片处理 1.创建图片对象 2.获取图片属性 3.裁剪图像 4.生成缩略图 6.保存图像 7.水印 安装 使用composer在项目根目录打开命令行执行: composer require topthink/think…

mybatis-plus分页方式

拦截器(分页插件) 一 方式1:XxxMapper.selectPage 1 selectPage(page, null) 概述 MyBatisPlus中提供的(自带的)分页插件,非常简单,只需要简单的配置就可以实现分页功能。详细步骤: 第一步:&…

【Tips】通过背数据了解业务

学习资料:做了三年数据分析,给你的几点建议 1. 通过背数据了解业务 原文: 总结: 方法:每天早上去到公司第一件事情就是先背一遍最新的各种指标。原理: 数据敏感性就是建立在对数据的了解和熟悉上。业务的…