pytorch 入门 (四)案例二:人脸表情识别-VGG16实现

news/2024/5/18 19:35:20/文章来源:https://blog.csdn.net/qq_33489955/article/details/133993996

实战教案二:人脸表情识别-VGG16实现

本文为🔗小白入门Pytorch内部限免文章
参考本文所写记录性文章,请在文章开头注明以下内容,复制粘贴即可

  • 🍨 本文为🔗小白入门Pytorch中的学习记录博客
  • 🍦 参考文章:【小白入门Pytorch】人脸表情识别-VGG16实现
  • 🍖 原作者:K同学啊

数据集下载:
链接:https://pan.baidu.com/s/1RvlpOx8v6MudY65Oi78-kQ?pwd=zhfo
提取码:zhfo
–来自百度网盘超级会员V4的分享

目录

  • 实战教案二:人脸表情识别-VGG16实现
    • 一、导入数据
    • 二、VGG-16算法模型
      • 1. 优化器与损失函数
      • 2. 模型的训练
    • 三、可视化

一、导入数据

from torchvision.datasets   import CIFAR10 # CIFAR10是一个用于计算机视觉的经典数据集,其中包含60000张32x32的彩色图像,分为10个类别,每个类别有6000张图像。
from torchvision.transforms import transforms # 这是一个常用的模块,用于图像的预处理和增强。
from torch.utils.data       import DataLoader # 可以将数据集转化为迭代器的工具,方便在训练循环中加载数据。
from torchvision            import datasets # 导入了torchvision下的所有数据集,但实际上这与前面导入CIFAR10是重复的,可能是不必要的。
from torch.optim            import Adam # 导入了Adam优化器。Adam是一个常用的、表现良好的深度学习优化器。
import torchvision.models   as models # 这个模块提供了各种预训练模型,例如ResNet、VGG、DenseNet等。
import torch.nn.functional  as F # 提供了各种激活函数、损失函数和其他的功能函数。
import torch.nn             as nn # 这个模块提供了构建神经网络所需的各种工具,如层、损失函数等。
import torch,torchvision # torch是PyTorch的核心库,提供了基础的张量操作;torchvision则是与计算机视觉相关的库,提供了数据集、预处理方法和预训练模型。
train_datadir = '/home/mw/input/kzb324321357/2-Emotion_Images/2-Emotion_Images/train'
test_datadir  = '/home/mw/input/kzb324321357/2-Emotion_Images/2-Emotion_Images/test'train_transforms = transforms.Compose([transforms.Resize([48, 48]),    # 将输入图片resize成统一尺寸transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])test_transforms = transforms.Compose([transforms.Resize([48, 48]),    # 将输入图片resize成统一尺寸transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])# 使用 datasets.ImageFolder 加载训练数据集和测试数据集
# ImageFolder假定所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类别的名字。
# 同时,为加载的数据应用了之前定义的预处理流程。
train_data = datasets.ImageFolder(train_datadir, transform=train_transforms)
test_data = datasets.ImageFolder(test_datadir, transform=test_transforms)

torch.utils.data.DataLoader详解

torch.utils.data.DataLoader是Pytorch自带的一个数据加载器,结合了数据集和取样器,并且可以提供多个线程处理数据集。

函数原型:

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False, pin_memory_device=‘’)

参数说明:

  • dataset(string) :加载的数据集
  • batch_size (int,optional) :每批加载的样本大小(默认值:1)
  • shuffle(bool,optional) : 如果为True,每个epoch重新排列数据。
  • sampler (Sampler or iterable, optional) : 定义从数据集中抽取样本的策略。 可以是任何实现了 len 的 Iterable。 如果指定,则不得指定 shuffle 。
  • batch_sampler (Sampler or iterable, optional) : 类似于sampler,但一次返回一批索引。与 batch_size、shuffle、sampler 和 drop_last 互斥。
  • num_workers(int,optional) : 用于数据加载的子进程数。 0 表示数据将在主进程中加载(默认值:0)。
  • pin_memory (bool,optional) : 如果为 True,数据加载器将在返回之前将张量复制到设备/CUDA 固定内存中。 如果数据元素是自定义类型,或者collate_fn返回一个自定义类型的批次。
  • drop_last(bool,optional) : 如果数据集大小不能被批次大小整除,则设置为 True 以删除最后一个不完整的批次。 如果 False 并且数据集的大小不能被批大小整除,则最后一批将保留。 (默认值:False)
  • timeout(numeric,optional) : 设置数据读取的超时时间 , 超过这个时间还没读取到数据的话就会报错。(默认值:0)
  • worker_init_fn(callable,optional) : 如果不是 None,这将在步长之后和数据加载之前在每个工作子进程上调用,并使用工作 id([0,num_workers - 1] 中的一个 int)的顺序逐个导入。 (默认:None)
# 创建训练数据加载器(data loader),用于将数据分成小批次进行训练
train_loader = torch.utils.data.DataLoader(train_data,batch_size=16,      # 每个批次包含的图像数量shuffle=True,       # 随机打乱数据num_workers=4)      # 使用多少个子进程来加载数据# 创建测试数据加载器(data loader),用于将测试数据分成小批次进行测试
test_loader = torch.utils.data.DataLoader(test_data,batch_size=16,      # 每个批次包含的图像数量shuffle=True,       # 随机打乱数据num_workers=4)      # 使用多少个子进程来加载数据# 打印数据集的信息
# 请注意,这里使用len(train_loader) * 16来计算图像总数是基于批次大小为16的假设。
# 实际上,最后一个批次的图像数量可能少于16。
print("The number of images in a training set is: ", len(train_loader) * 16)  # 计算训练集中的图像总数
print("The number of images in a test set is: ", len(test_loader) * 16)      # 计算测试集中的图像总数
print("The number of batches per epoch is: ", len(train_loader))             # 计算每个 epoch 中的批次数# 定义数据集的类别标签
classes = ('Angry', 'Fear', 'Happy', 'Surprise')
The number of images in a training set is:  18480
The number of images in a test set is:  2320
The number of batches per epoch is:  1155

二、VGG-16算法模型

device = "cuda" if torch.cuda.is_available() else "cpu"print("Using {} device".format(device))# 直接调用官方封装好的VGG16模型
model = models.vgg16(pretrained = True)
model
Using cuda device

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /home/mw/.cache/torch/hub/checkpoints/vgg16-397923af.pth

HBox(children=(FloatProgress(value=0.0, max=553433881.0), HTML(value='')))

VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace=True)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace=True)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

1. 优化器与损失函数

optimizer = Adam(model.parameters(),lr = 0.0001,weight_decay = 0.0001)
loss_model = nn.CrossEntropyLoss()
import torch
from torch.autograd import Variable
# 定义训练函数
def train(model,train_loader,loss_model,optimizer):# 将模型移动到指定设备(如:GPU)model = model.to(device)# 将模型设置为训练模式(启用梯度计算)model.train()for i,(images,labels) in enumerate(train_loader,0):# 将输入数据和标签移动到指定设备images = Variable(images.to(device))labels = Variable(labels.to(device))# 梯度清零optimizer.zero_grad()# 前向传播得到模型输出outputs = model(images)# 计算损失loss = loss_model(outputs,labels)# 反向传播loss.backward()# 更新模型参数optimizer.step()# 每隔1000个批次输出一次损失if i%1000 == 0:print('[%5d] loss: %.3f' % (i,loss))# 定义测试函数
def test(model,test_loader,loss_model):# 获取测试数据集大小size = len(test_loader.dataset)# 获取测试数据批次数num_batches = len(test_loader)# 将模型设置为评估模式(不进行梯度计算)model.eval()test_loss,correct = 0,0# 在不计算梯度的上下文中执行测试with torch.no_grad():for X,y in test_loader:X,y = X.to(device),y.to(device)# 使用模型进行前向传播得到预测pred = model(X)# 计算损失值,并将其累积到test_loss中test_loss += loss_model(pred,y).item()# 计算正确预测的数量correct += (pred.argmax(1)==y).type(torch.float).sum().item()# 计算平均测试损失和准确率test_loss /= num_batchescorrect /= sizeprint(f"Test Error:\n Accuracy:{(100*correct):>0.1f}%,Avg loss:{test_loss:>8f} \n")return correct,test_loss

提问:梯度清零有三种放置方法吗,放到循环最前面或者最后面,或者反向传播之前

是的,您理解得很准确。在PyTorch中,optimizer.zero_grad()的位置有一定的灵活性。以下是这三种常见的放置方法:

  1. 循环开始时:在每次迭代的最开始,即前向传播之前,清零梯度。
  2. 反向传播之前:在前向传播之后、反向传播之前,清零梯度。这也是您提供的代码中使用的方法。
  3. 循环结束时:在每次迭代的最后,即更新参数之后,清零梯度。

这三种方法在大多数情况下都是等效的,因为关键是确保在进行下一次反向传播之前梯度是清零的。

但是,如果您在一个迭代中进行多次反向传播(例如,当您想要累加梯度时),那么您必须在每次反向传播之前清零梯度。在这种特定情境中,第二种方法(反向传播之前)是最合适的。

对于大多数常规的训练循环,选择哪种方法主要是根据个人偏好或代码的可读性来决定的。

2. 模型的训练

# 创建一个空列表用于存储每个epoch的测试集准确率
test_acc_list = []
# 定义训练的总论数
epochs = 10# 开始训练循环,每个epoch 都会执行一下操作
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")# 在训练数据上训练模型train(model,train_loader,loss_model,optimizer)# 在测试数据集上测试模型的性能,并获取测试准确率和测试损失test_acc,test_loss = test(model,test_loader,loss_model)# 将测试准确率添加到列表中,以便后续分析test_acc_list.append(test_acc)# 所有epoch完成后打印完成消息
print("Done!")
Epoch 1
-------------------------------
[    0] loss: 0.129
[ 1000] loss: 0.005
Test Error:Accuracy:77.4%,Avg loss:1.069592 Epoch 2
-------------------------------
[    0] loss: 0.028
[ 1000] loss: 0.055
Test Error:Accuracy:78.7%,Avg loss:0.976879 Epoch 3
-------------------------------
[    0] loss: 0.033
[ 1000] loss: 0.050
Test Error:Accuracy:77.9%,Avg loss:1.202651 Epoch 4
-------------------------------
[    0] loss: 0.051
[ 1000] loss: 0.356
Test Error:Accuracy:79.0%,Avg loss:1.080943 Epoch 5
-------------------------------
[    0] loss: 0.001
[ 1000] loss: 0.183
Test Error:Accuracy:78.7%,Avg loss:1.248081 Epoch 6
-------------------------------
[    0] loss: 0.003
[ 1000] loss: 0.127
Test Error:Accuracy:78.4%,Avg loss:1.129110 Epoch 7
-------------------------------
[    0] loss: 0.003
[ 1000] loss: 0.076
Test Error:Accuracy:77.6%,Avg loss:1.200314 Epoch 8
-------------------------------
[    0] loss: 0.042
[ 1000] loss: 0.071
Test Error:Accuracy:78.0%,Avg loss:1.149877 Epoch 9
-------------------------------
[    0] loss: 0.002
[ 1000] loss: 0.212
Test Error:Accuracy:78.0%,Avg loss:1.353625 Epoch 10
-------------------------------
[    0] loss: 0.001
[ 1000] loss: 0.001
Test Error:Accuracy:78.5%,Avg loss:1.249242 Done!
test_acc_list
[0.773552290406223,0.7869490060501296,0.7791702679343129,0.7904062229904927,0.7869490060501296,0.783923941227312,0.7757130509939498,0.780466724286949,0.780466724286949,0.7852203975799481]

三、可视化

import numpy as np
import matplotlib.pyplot as pltx = [i for i in range(1,11)]plt.plot(x,test_acc_list,label="line ACC",alpha = 0.8)plt.xlabel("epoch")
plt.ylabel("acc")plt.legend()
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_187579.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

系统性认知网络安全

前言:本文旨在介绍网络安全相关基础知识体系和框架 目录 一.信息安全概述 信息安全研究内容及关系 信息安全的基本要求 保密性Confidentiality: 完整性Integrity: 可用性Availability: 二.信息安全的发展 20世纪60年代&…

Linux tmux使用总结

文章目录 1 tmux介绍2 tmux概念会话Sessions、窗口Windows、面板Panesstatus line中字段含义 3 Sessions会话管理新建会话断开当前会话进入之前的会话关闭会话查看所有的会话 4 tmux快捷指令系统指令窗口(Windows)指令面板(Panes)…

免费领取!TikTok Shop “全托管”黑五大促官方备战指南来啦!

黑五网一大促即将来袭,自“全托管”模式上线以来,TikTok for Business在沙特阿拉伯和英国市场开展了古尔邦节大促、夏季大促、返校季大促等活动,今年更是会借着黑五网一大促之际,首次覆盖美国市场,为全托管商家带来全球…

什么是接口测试?三分钟带你全面认识接口测试、带你学会接口测试~

目录 1、接口是什么? 2、接口的类型 3、接口测试初识 3.1、什么是接口测试 3.2、原理 3.3、特点 3.4、什么是自动化接口测试 4、接口测试流程 5、传统风格接口与RESTful风格接口 6、接口文档 6.1、什么是接口文档 6.2、接口文档作用 6.3、展现形式 6.4…

formData对象打印不出来

用el-upload上传图片 以流的形式传给后台 所以用formData对象带数据 let formData new FormData() formData.append(name,monkey7) console.log(formData) 明明已经把数据append进去了 console.log在控制台却打印不出 后来发现他得用formData.get("xxx"…

最全的图床集合(国内外,站长必备)

“heosu每月不定时更新嗷,防止错过消息推送,建议小伙伴添加到星标⭐喔” 为了减少服务器的压力不少站长还是选择图床存放图片的。所以就搜集一些比较好用的免费的图床(收费的在最后标出)以及我目前在用的图床。 为什么需要图床&am…

Linux系统CH347应用—SPI功能

Linux/安卓系统使用CH347转接SPI功能有三种应用方式: 1. 使用CH34X_MPHSI_Master总线驱动为系统扩展原生SPI Master,此方式无需进行单独的应用层编程; 2. 使用CH341PAR_LINUX字符设备驱动,此方式需要配合使用厂商提供的库文件&a…

【springcloud-config】配置中心客户端导入依赖spring-cloud-config-server后,maven一直爆红问题解决

问题描述 配置中心客户端导入了 spring-cloud-config-server 后&#xff0c;导入依赖爆红&#xff1b; 解决办法&#xff1a; 参考官网中文文档&#xff1a;spring-cloud -config 配置中心 中文文档 补充导入 spring-config-starter-config 配置即可 <!--springcloud-c…

跨境商城源码可以支持多种营销推广方式吗?

一、多种营销推广方式的重要性 跨境商城源码作为现代电商领域的重要工具&#xff0c;其支持多种营销推广方式对于吸引用户、增加销量以及提升品牌影响力都至关重要。通过采用多种营销推广方式&#xff0c;商家可以全方位地宣传和推广产品&#xff0c;吸引更多的潜在顾客&#x…

JS多选答题时,选项互斥时的情况

在做答题类的项目时&#xff0c;应该会比较常见多选题选相互斥的问题&#xff0c;例如&#xff1a; 你喜欢什么颜色&#xff1f;&#xff08;&#xff09;A、红色B、紫色C、蓝色D、灰色E、均无如该题&#xff0c;当选择选项E时&#xff0c;明显与其他选项互斥。这个时候经常会…

SAP 公司间销售

一、 概述 很多项目中&#xff0c;特别是集团型公司&#xff0c;生产总部在某地&#xff0c;但是在各个省会城市&#xff0c;乃至国外都有相应的贸易公司&#xff0c;特别是国外&#xff0c;此时贸易公司接到客户采购订单&#xff0c;但是贸易公司没有库存&#xff0c;甚至没有…

Zoho Mail荣登福布斯小型企业企业邮箱排行榜

在过去的数十载里&#xff0c;电子邮件已成为电子通信领域中不可或缺的一环&#xff0c;而在未来的岁月里&#xff0c;它有望继续在全球范围内普及应用。尽管如今市场上有许多免费的企业邮箱供用户和企业选用&#xff0c;但其中许多产品在特定场景下的专业化功能尚显不足&#…

selenium多窗口、多iframe切换、alert切换

多标签/多窗口之间的切换 场景&#xff1a; 在页面操作过程中有时候点击某个链接会弹出新的窗口&#xff0c;这时就需要切换到新打开的窗口上进行操作。这种情况下&#xff0c;需要识别多标签或窗口的情况。 操作方法&#xff1a; switch_to.window()方法&#xff1a;切换窗口…

如何高效的开展app的性能测试?

APP性能测试是什么 从网上查了一下&#xff0c;貌似也没什么特别的定义&#xff0c;我这边根据自己的经验给出一个自己的定义&#xff0c;如有巧合纯属雷同。 客户端性能测试就是&#xff0c;从业务和用户的角度出发&#xff0c;设计合理且有效的性能测试场景&#xff0c;制定…

js给一段话,遇到的第一个括号处加上换行符

list.forEach((item,index0)>{const productName item.name;const index productName.indexOf(&#xff08;);if (index -1) {return productName;}const before productName.slice(0, index);const after productName.slice(index);item.namebefore \n after;});

吃透Spring源码分析专题

想说的话 本人在互联网摸爬滚打至今(23年)6年了&#xff0c;平时有写博客的习惯&#xff0c;这个习惯是从大学的时候开始的&#xff0c;目前主要关注java领域相关的技术&#xff0c;python也有涉及&#xff0c;写Spring专题是因为Spring确实很重要&#xff0c;在目前这个开发模…

【C++】二叉树进阶 -- 详解

一、二叉搜索树概念 二叉搜索树 又称二叉排序树&#xff0c;它或者是一棵空树&#xff0c;或者是具有以下性质的二叉树&#xff1a; 若它的左子树不为空&#xff0c;则左子树上所有节点的值都小于根节点的值 若它的右子树不为空&#xff0c;则右子树上所有节点的值都大于根节点…

AI驱动的图纸数据提取

推荐&#xff1a;用 NSDT编辑器 快速搭建可编程3D场景 你是否曾经需要组合来自两个不同来源&#xff08;例如图像和文本&#xff09;的对象数据&#xff1f; 我们在工作的过程中经常面临这样的挑战。 在这里&#xff0c;我们展示了技术绘图领域的一个示例。 此类图纸用于许多领…

Python数据挖掘 | 升级版自动查核酸

&#x1f4d5;作者简介&#xff1a;热爱跑步的恒川&#xff0c;致力于C/C、Java、Python等多编程语言&#xff0c;热爱跑步&#xff0c;喜爱音乐的一位博主。 &#x1f4d7;本文收录于恒川的日常汇报系列&#xff0c;大家有兴趣的可以看一看 &#x1f4d8;相关专栏C语言初阶、C…

Mac电脑怎么在Dock窗口预览,Dock窗口预览工具DockView功能介绍

DockView是一款Mac电脑上的软件&#xff0c;它可以增强Dock的功能&#xff0c;让用户更方便地管理和切换应用程序。 DockView的主要功能是在 DockQ&#xff0c;栏上显示每个窗口的缩略图&#xff0c;并提供了一些相关的操作选项。当用户将鼠标悬停在Dock栏上的应用程序图标上时…