PyTorch实战3:天气识别

news/2024/5/5 21:08:37/文章来源:https://blog.csdn.net/weixin_64215932/article/details/130383794
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:365天深度学习训练营-第P3周:天气识别
  • 🍖 原作者:K同学啊|接辅导、项目定制

目录

    • 一、前期准备
        • 1、导入数据
        • 2、transforms.Compose详解
        • 3、图像处理(数据增强)
        • 4、加载数据集
        • 5、划分数据集
    • 二、构建简单的CNN网络
        • 1、torch.nn.Flatten()详解
        • 2、x.view()详解
        • 3、两者之间的区别
        • 4、案例结果展示
    • 三、指定图片进行预测
        • 1、⭐torch.squeeze()详解
        • 2、⭐torch.unsqueeze() 详解
        • 3、某张图片预测
    • 四、总结

本文案例学习主要有两点:

  • 本地读取并加载数据
  • 调用模型识别一张本地图片

一、前期准备

1、导入数据

读取指定目录下的所有图像文件,将它们的路径存储在一个列表中,并提取每个图像所属的类别名称。

首先,导入了四个Python库:os、PIL、random和pathlib。其中,os库提供了访问操作系统功能的接口;PIL库是Python图像处理库;random库提供了生成随机数的函数;pathlib库提供了以面向对象的方式操作文件路径的方法。

然后,定义了一个字符串变量"data_dir",它指定了包含图像数据集的目录。接着,使用pathlib.Path()方法将"data_dir"转换为Path对象类型,以便可以调用该对象上的方法执行一些操作。

接下来,使用data_dir.glob(‘*’)方法获取"data_dir"目录下的所有文件路径,并将这些路径存储到data_paths列表中。

在下一行代码中,使用列表推导式和split()方法从每个文件的路径中获取类别名称,并将这些名称存储到classeNames列表中。

最后,打印classeNames列表,即可查看所有类别的名称。

代码如下:

import os,PIL,random,pathlibdata_dir = './weather_photos/'
data_dir = pathlib.Path(data_dir)data_paths = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[1] for path in data_paths]
classeNames

2、transforms.Compose详解

transforms.Compose是PyTorch中一个用于数据预处理的类,它允许用户将多个数据预处理操作组合在一起。具体而言,transforms.Compose会接受多个transform函数作为输入,并返回一个新的transform函数,该函数将按照传入的顺序依次执行每个transform函数。

例如,假设我们想要对一张图片进行数据增强操作,包括随机裁剪、水平翻转和归一化。我们可以使用以下代码:

import torchvision.transforms as transforms# 定义数据增强操作
transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 对图片应用数据增强操作
img_augmented = transform(img)

在上面的示例中,我们首先定义了一个transform对象,其中包含四个数据预处理操作:随机裁剪、水平翻转、将图像转换为张量以及归一化。然后,我们将这个transform对象应用到一张图片上,得到了经过数据增强处理后的新图片。

需要注意的是,transforms.Compose只能用于序列化的数据预处理操作,也就是说每个操作必须能够被序列化为一个字符串,并且可以通过反序列化得到一个可用的操作函数。因此,如果你需要定义自己的数据预处理操作,并将其添加到transforms.Compose中,则必须确保这些操作是可序列化的。

3、图像处理(数据增强)

接下来编写用来进行图像处理的代码。

下面是每行代码的解释:

total_datadir = './weather_photos/'

设置数据集所在的目录路径。

train_transforms = transforms.Compose([transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])

定义数据处理的步骤,这里采用了三个步骤:

  • 将图像 resize 到固定大小(224x224),保证输入模型的图片尺寸一致。
  • 将 PIL Image 或 numpy.ndarray 转换为 tensor,并归一化到 [0,1] 之间。
  • 进行标准化处理,将输入的图片转换为标准正太分布(高斯分布),使得模型更容易收敛。

其中,meanstd 是从数据集中随机抽样计算得到的均值和标准差。

4、加载数据集

加载数据集的代码如下:

total_data = datasets.ImageFolder(total_datadir,transform=train_transforms)

使用 PyTorch 提供的 ImageFolder 类加载数据集,将数据集所在路径和上面定义好的数据处理方式传入。该类会自动将数据集按照文件名所在的文件夹分类,并将分类信息存储在 classes 属性中。此外,还有一个 class_to_idx 属性,保存了每个分类对应的索引,方便后续训练模型时使用。最终,该代码段返回处理好的数据集对象 total_data

5、划分数据集

使用PyTorch库中的torch.utils.data.random_split函数,将给定的数据集total_data按照指定比例(80%和20%)随机划分为训练集和测试集。

具体来说,首先计算出训练集大小train_size,即将总数据集大小乘以0.8,然后计算出测试集大小test_size,即总数据集大小减去训练集大小。接着,调用random_split函数,将总数据集和一个包含两个元素的列表作为参数传入,其中第一个元素是训练集的大小,第二个元素是测试集的大小。该函数会返回两个新的数据集对象,分别对应划分好的训练集和测试集。

代码如下:

# 计算训练集大小和测试集大小
train_size = int(0.8 * len(total_data))
test_size  = len(total_data) - train_size# 调用random_split函数进行随机划分
# total_data: 需要划分的原始数据集
# [train_size, test_size]: 包含两个元素的列表,分别指定训练集和测试集的大小
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])# 打印划分好的训练集和测试集
train_dataset, test_dataset

使用 PyTorch 的 DataLoader 对象来创建训练集和测试集的数据加载器,以便在训练和测试期间对数据进行批处理。

batch_size = 32  # 每次迭代使用的样本数# 创建训练集和测试集的数据加载器
train_dl = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,  # 在每个 epoch 开始之前打乱数据顺序num_workers=1)  # 使用一个工作线程加载数据test_dl = torch.utils.data.DataLoader(test_dataset,batch_size=batch_size,shuffle=True,num_workers=1)

train_datasettest_dataset 应该是一个 PyTorch 数据集对象,包含了训练集和测试集的样本、标签等信息。shuffle 参数指定是否在每个 epoch 开始时打乱数据,num_workers 参数指定用于加载数据的工作线程数量。

这里我们创建了两个 DataLoader 对象:train_dltest_dl,分别用于加载训练集和测试集数据,并将其分为大小为 batch_size 的小批次(batch)。在训练过程中,可以使用这些小批次来计算损失函数并更新模型参数,从而逐步提高模型性能。

二、构建简单的CNN网络

卷积层、全连接层、池化层以及批量归一化层的详解可移步至PyTorch实战2:彩色图片识别(CIFAR10)

在此详细介绍torch.nn.Flatten()x.view()

1、torch.nn.Flatten()详解

torch.nn.Flatten() 是 PyTorch 中的一个层(Layer),它可以将多维张量(tensor)展平为一维。具体来说,该层可以将大小为 (batch_size, C, H, W) 的张量展平为大小为 (batch_size, C*H*W) 的张量。

下面是一个简单的例子:

import torch
import torch.nn as nnx = torch.randn(2, 3, 4, 5) # 生成一个大小为 (2, 3, 4, 5) 的随机张量
flatten = nn.Flatten()
y = flatten(x)print(f"x: {x.size()}") # 输出 x 的大小
print(f"y: {y.size()}") # 输出 y 的大小

输出结果如下:

x: torch.Size([2, 3, 4, 5])
y: torch.Size([2, 60])

从输出结果可以看出,x 的大小为 (2, 3, 4, 5),而 y 的大小为 (2, 60),即将 x 展平为一个大小为 (2, 60) 的张量。

2、x.view()详解

x.view() 是 PyTorch 中的一个函数,它可以将一个张量 x 重塑(reshape)成另外一个形状。与 torch.flatten() 不同,x.view() 可以将张量变换为任何形状,只要元素总数保持不变即可。

具体来说,如果 x 的形状为 (n1, n2, ..., nk),那么通过 x.view(a1, a2, ..., ak),可以将 x 变换为形状为 (a1, a2, ..., ak) 的新张量。需要注意的是,x 和变换后的新张量共享存储空间,因此在修改其中一个张量的值时,另一个张量的值也会相应地发生改变。

下面是一个简单的例子:

import torchx = torch.arange(24).view(2, 3, 4) # 生成一个大小为 (2, 3, 4) 的张量
y = x.view(6, 4) # 将 x 重塑为一个大小为 (6, 4) 的新张量print(f"x: {x}")
print(f"y: {y}")x[0, 1, 2] = -1 # 修改 x 中的一个元素print(f"x: {x}")
print(f"y: {y}")

输出结果如下:

x: tensor([[[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11]],[[12, 13, 14, 15],[16, 17, 18, 19],[20, 21, 22, 23]]])
y: tensor([[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11],[12, 13, 14, 15],[16, 17, 18, 19],[20, 21, 22, 23]])
x: tensor([[[ 0,  1,  2,  3],[ 4,  5, -1,  7],[ 8,  9, 10, 11]],[[12, 13, 14, 15],[16, 17, 18, 19],[20, 21, 22, 23]]])
y: tensor([[ 0,  1,  2,  3],[ 4,  5, -1,  7],[ 8,  9, 10, 11],[12, 13, 14, 15],[16, 17, 18, 19],[20, 21, 22, 23]])

从输出结果可以看出,首先定义了一个大小为 (2, 3, 4) 的张量 x,然后通过 x.view(6, 4) 将其重塑为大小为 (6, 4) 的张量 y。接着,修改了 x 中的一个元素,可以看到 y 中相应位置的值也发生了改变,这是因为 xy 共享存储空间。

3、两者之间的区别

torch.nn.Flatten()x.view() 都可以用来将 tensor 扁平化,但是它们的使用方式和具体实现有所不同。

torch.nn.Flatten() 是一个模块,当在模型中调用时,它会自动将输入 tensor 扁平化。例如:

import torch.nn as nn
flatten = nn.Flatten()x = torch.randn(3, 4, 5)
y = flatten(x)
print(y.shape)  # 输出 torch.Size([60])

x.view() 是一个 tensor 的方法,可以用来将 tensor 重塑为指定的形状。例如:

x = torch.randn(3, 4, 5)
y = x.view(3, 20)
print(y.shape)  # 输出 torch.Size([3, 20])

虽然这两个函数都能够扁平化 tensor,但是它们的语义和用法有所不同。torch.nn.Flatten() 更适合在神经网络模型的构建中使用,而 x.view() 更适合在计算过程中手动调整 tensor 形状。

两者仅仅是一种数据集拉伸操作(将二维数据拉伸为一维),torch.flatten()方法不会改变x本身,而是返回一个新的张量。而x.view()方法则是直接在原有数据上进行操作。

  • 具体网络结构在此不做详细赘述,下一篇基于水质图像识别的水质评估会进行详细的描述
  • 另:模型训练与训练结果的可视化请移步至PyTorch实战1:实现mnist手写数字识别学习

4、案例结果展示

  • 训练过程:

在这里插入图片描述

  • 结果可视化

在这里插入图片描述

三、指定图片进行预测

1、⭐torch.squeeze()详解

torch.squeeze()是PyTorch中的一个函数,它用于移除张量中大小为1的维度。具体来说,如果张量在某个维度上只有一个元素,那么这个维度可以被“挤压”或者“挤掉”,从而生成一个新的形状更小的张量。

  • 对数据的维度进行压缩,去掉维数为1的的维度

函数原型:

torch.squeeze(input, dim=None, *, out=None) 

函数的参数包括:

  • input (Tensor):需要挤压的输入张量。
  • dim (int, optional):指定需要挤压的维度。如果不指定,将会挤压所有大小为1的维度。

函数返回一个新的张量。

以下是一些示例代码:

import torch# 示例1
x = torch.randn(2, 1, 3, 1, 4)
y = torch.squeeze(x)
print(y.shape) # 输出torch.Size([2, 3, 4])# 示例2
x = torch.randn(2, 1, 3, 1, 4)
y = torch.squeeze(x, dim=1)
print(y.shape) # 输出torch.Size([2, 3, 1, 4])

在第一个示例中,输入张量x的形状为(2, 1, 3, 1, 4),其中有两个大小为1的维度,分别位于第二和第四维。使用torch.squeeze(x),将移除这两个维度,输出的新张量y的形状为(2, 3, 4)

在第二个示例中,指定了需要挤压的维度为1,因此只移除了第二个维度上的大小为1的元素,并保留了其他维度。输出的新张量y的形状为(2, 3, 1, 4),仍然有一个大小为1的维度。

2、⭐torch.unsqueeze() 详解

torch.unsqueeze() 是 PyTorch 中用于增加张量维度的函数,可以将给定张量的维度增加一维。具体来说,它可以在指定位置插入一个大小为 1 的新维度。

  • 对数据维度进行扩充。给指定位置加上维数为一的维度

函数定义如下:

torch.unsqueeze(input, dim)

参数说明:

  • input (Tensor):输入张量。
  • dim (int):插入新维度的位置。该参数是可选的,默认为零,即在第一维插入新维度。

返回值:

  • 返回插入新维度后的新张量。

例如,对于一个形状为 (3, 4) 的二维张量 x,我们可以通过以下方式在第一维插入新维度:

import torchx = torch.randn(3, 4)
print(x.shape) # 输出 (3, 4)y = torch.unsqueeze(x, 0)
print(y.shape) # 输出 (1, 3, 4)

上述代码中,使用 torch.randn() 创建了一个形状为 (3, 4) 的二维张量 x。然后,使用 torch.unsqueeze(x, 0) 在第一维插入新维度,得到一个形状为 (1, 3, 4) 的三维张量 y。其中,第一维的大小为 1。

可以看到,torch.unsqueeze() 可以方便地增加张量的维度,这在神经网络中非常有用。例如,在图像识别任务中,通常需要将二维图像张量增加一个通道维度,以便于输入卷积层。

3、某张图片预测

代码流程:

  1. 使用PIL库加载待预测图片。
  2. 将图片按照指定的预处理方法transform转换成模型输入所需的Tensor格式。
  3. 在图片的第0维上添加一个新的维度,以满足模型输入为4维(batch_size, channels, height, width)的要求。
  4. 将图片Tensor输入到已训练的分类器模型中,获取模型输出。
  5. 从模型输出中找到最大值及其对应的索引,即为预测结果所属的类别。
  6. 根据类别索引在classes列表中查找对应的类别名称。
  7. 返回预测结果。

下面是代码解释:

from PIL import Image  # 导入PIL库,用于处理图像# 获取数据集中的类别列表
classes = list(total_data.class_to_idx)def predict_one_image(image_path, model, transform, classes):'''对单张图片进行分类预测参数:image_path:待预测图片路径model:已训练模型transform:数据预处理函数classes:数据集类别列表返回值:无'''# 打开待预测的图片,并转换为RGB格式test_img = Image.open(image_path).convert('RGB')# plt.imshow(test_img)  # 展示预测的图片# 对图片进行数据预处理test_img = transform(test_img)img = test_img.to(device).unsqueeze(0)  # 将图片加入一个batch中,这里的batch size为1model.eval()  # 设置模型为评估模式output = model(img)  # 对图片进行分类预测_,pred = torch.max(output,1)  # 获取预测结果中概率最高的类别标签pred_class = classes[pred]  # 根据类别标签获取类别名称print(f'预测结果是:{pred_class}')  # 输出预测结果

需要注意的是,该代码中需要事先定义好模型的结构和参数,并将其加载到 GPU 上进行训练。同时也需要一个数据预处理函数(transform),用于将输入的图片转换为 PyTorch 可以处理的格式。

参数说明:

  • image_path: 待预测图片路径。
  • model: 训练好的分类器模型。
  • transform: 数据预处理方法,用于将图片转换成模型输入所需格式。
  • classes: 类别名称列表,记录模型预测结果对应的类别名称。

调用函数代码如下:

predict_one_image(image_path='./4-data/Monkeypox/M01_01_00.jpg',model=model, transform=train_transforms, classes=classes)

四、总结

本次学习主要涉及了PyTorch中CNN网络的搭建、数据处理和模型预测。在准备阶段,通过导入数据、使用transforms.Compose进行数据增强、加载数据集以及划分数据集等操作为构建CNN网络做好准备。

在构建CNN网络的过程中,介绍了torch.nn.Flatten()和x.view()两个函数,并讲解了它们之间的区别。最后,还学习了如何使用指定的图片进行预测,包括torch.squeeze()和torch.unsqueeze()的详解和某张图片的预测方法。

通过本次学习,可以更好地掌握PyTorch中CNN网络的构建和使用,从而提高自己的深度学习水平。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_104026.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Python入门第五十四天】Python丨NumPy ufuncs

什么是 ufuncs? ufuncs 指的是“通用函数”(Universal Functions),它们是对 ndarray 对象进行操作的 NumPy 函数。 为什么要使用 ufuncs? ufunc 用于在 NumPy 中实现矢量化,这比迭代元素要快得多。 它们…

线程的生命周期以及sleep()方法和wait()方法

三种休眠状态:Blocked,Waiting,Timed_Waiting 注意两个Blocked态是不一样的,上面的Blocked只要睡眠时间到了马上进入运行态,下面处于Blocked的线程还需要抢到锁才能进入运行态 sleep()和wait()方法: sleep…

【hello Linux】进程间通信——匿名管道

目录 前言: 总结下上述的内容: 1. 进程间通信目的 2. 进程间通信的分类 1. 匿名管道 2. 匿名管道的使用 1. 匿名管道的创建 2. 使用匿名管道进行父子间通信 Linux🌷 前言: 进程具有独立性,拥有独立的数据、代码及其他…

论文阅读:PVO: Panoptic Visual Odometry

全景视觉里程计、同时做全景分割和视觉里程计 连接:PVO: Panoptic Visual Odometry 0.Abstract 我们提出了一种新的全景视觉里程计框架PVO,以实现对场景运动、几何和全景分割信息的更全面的建模。我们将视觉里程计(VO)和视频全景分割(VPS)在一个统一的…

STM32F4_SRAM中调试代码

目录 1. 在RAM中调试代码 2. STM32的三种存储方式 3. STM32的启动方式 4. 实验过程 通过上一节的学习,我们已经了解了SRAM静态存储器; 1. 在RAM中调试代码 一般情况下,我们在MDK中编写工程应用后,调试时都是把程序下载到芯片…

Java_异常

Java_异常 1.什么是异常 ​ 生活中的异常:感冒发烧、电脑蓝屏、手机死机等。 ​ 程序中的异常:磁盘空间不足、网络连接中断、被加载的资源不存在等。 ​ 程序异常解决办法:针对程序中非正常情况,Java语言引入了异常&#xff0…

注意力机制:基于Yolov5/Yolov7的Triplet注意力模块,即插即用,效果优于cbam、se,涨点明显

论文:https://arxiv.org/pdf/2010.03045.pdf 本文提出了可以有效解决跨维度交互的triplet attention。相较于以往的注意力方法,主要有两个优点: 1.可以忽略的计算开销 2.强调了多维交互而不降低维度的重要性,因此消除了通道和权…

日撸 Java 三百行day38

文章目录 说明day381.Dijkstra 算法思路分析2.Prim 算法思路分析3.对比4.代码 说明 闵老师的文章链接: 日撸 Java 三百行(总述)_minfanphd的博客-CSDN博客 自己也把手敲的代码放在了github上维护:https://github.com/fulisha-ok/…

VR全景图片,探究VR全景图片为何如此受欢迎?

随着科技的不断进步,虚拟现实技术逐渐渗透到我们的日常生活中,为我们带来了许多前所未有的体验和乐趣。而其中,VR全景图片作为一种基于虚拟现实技术的图片展示形式,不仅在旅游、房地产、教育等领域得到了广泛的应用,也…

c++强制类型转换:

强制类型转换:1. const属性用const_cast。 案例: 说明:该变量可以将变量的const 的属性去掉。如该案例,转换后修改x的值是合法的。2. 基本类型转换用static_cast。 案例: 说明:一般用在(1)基本类型&#xf…

学系统集成项目管理工程师(中项)系列10_立项管理

1. 系统集成项目管理至关重要的一个环节 2. 重点在于是否要启动一个项目,并为其提供相应的预算支持 3. 项目建议 3.1. Request for Proposal, RFP 3.2. 立项申请 3.3. 项目建设单位向上级主管部门提交的项目申请文件,是对拟建项目提出的总体设想 3…

基于centos7:Harbor-2.7.2部署和安装教程

基于centos7:Harbor-2.7.2部署和安装教程 1、软件资源介绍 Harbor是VMware公司开源的企业级DockerRegistry项目,项目地址为https://github.com/vmware/harbor。其目标是帮助用户迅速搭建一个企业级的Dockerregistry服务。它以Docker公司开源的registry…

WPF学习

一、了解WPF的框架结构 (第一小节随便看下就可以,简单练习就行) 1、新建WPF项目 xmlns:XML的命名空间 Margin外边距:左上右下 HorizontalAlignment:水平位置 VerticalAlignment:垂直位置 2…

Timer0/1设置时钟计算中断时间

时钟一般分为外部晶振时钟和内部时钟,相对而说,外部晶振时钟的精准度比内部系统时钟高,时间计算的更准。除非产品需要一般都不会用外部晶振时钟,因为好的东西贵啊,成本高。 本文主要介绍如何利用时钟设置Timer0/1&…

厨电新十年,不可逆的行业分化与老板电器的数字进化

“人生就像滚雪球,最重要之事是发现湿雪和长长的山坡。”股神巴菲特的这句名言,让坡是否长、雪是否厚成为人们评价一个行业、一家公司的标准之一。 家电行业,厨电曾是最后一块“坡长雪厚”之地,投资者也对相关企业给出了相当的热…

MySQL根据中文姓名排序查询

在MySQL中当说到进行排序查询时,大家的第一反应就是使用 ORDER BY 方法指定列进行排序,但是如果要指定列为中文数据按照首字母排序时,就会发现 ORDER BY 方法排序的顺序其实是有问题的。 我们先来测试下正常使用 ORDER BY 排序: 指…

35岁程序员被裁赔偿27万,公司又涨薪让我回去,前提是退还补偿金,能回吗?

在大多数人眼里,35岁似乎都是一道槛,互联网界一直都有着“程序员是吃青春饭”的说法,。 如果在35岁的时候被裁能获得27万的赔偿,公司又涨薪请你回去上班,你会怎么选? 最近,就有一位朋友在网上…

剑指 Offer 42. 连续子数组的最大和:C语言解法

剑指 Offer 42. 连续子数组的最大和 - 力扣(Leetcode) 输入一个整型数组,数组中的一个或连续多个整数组成一个子数组。求所有子数组的和的最大值。 要求时间复杂度为O(n)。 实例: 输入: nums [-2,1,-3,4,-1,2,1,-5,4] 输出: …

SOLIDWORKS认证考试流程

一、SOLIDWORKS认证考试前的准备工作 1、检查电脑硬件设备是否可以正常使用,如键盘鼠标等。 2、检查Solidworks软件是否可以正常使用。 3、关闭电脑所有杀毒软件。 4、检查电脑网络(外网)是否正常。 5.请联系我们获取考试系统软件安装包。…

Maven 下载及配置详细步骤

1、Maven 下载 Maven 官网地址:https://maven.apache.org/download.cgi(opens new window) 进入 Maven 官网,点击 archives 下载版本 3.6.2 找到下载的压缩包并解压