卷积神经网络--猫狗系列【VGG16】

news/2024/5/15 0:07:17/文章来源:https://blog.csdn.net/qq_53968319/article/details/131502829

数据集:【文末】

数据集预处理

定义读取数据辅助类(继承torch.utils.data.Dataset)

import osimport PILimport torchimport torchvisionimport matplotlib.pyplot as pltimport torch.utils.dataimport PIL.Image
# 数据集路径train_path = './train'test_path = './test'device = torch.device("cuda" if torch.cuda.is_available() else "cpu")class MyDataset(torch.utils.data.Dataset):    def __init__(self, data_path: str, train=True, transform=None):        self.data_path = data_path        self.train_flag = train        if transform is None:            self.transform = torchvision.transforms.Compose(                [                    torchvision.transforms.Resize(size=(224, 224)),  # 尺寸规范                    torchvision.transforms.ToTensor(),  # 转化为tensor                    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # 归一化                ])        else:            self.transform = transform        self.path_list = os.listdir(data_path)  # 列出所有图片命名    def __getitem__(self, idx: int):        img_path = self.path_list[idx]        if self.train_flag is True:            # 例如 img_path 值 cat.10844.jpg -> label = 0            if img_path.split('.')[0] == 'dog':                label = 1            else:                label = 0        else:            label = int(img_path.split('.')[0])  # 获取test数据的编号        label = torch.tensor(label, dtype=torch.int64)  # 把标签转换成int64        img_path = os.path.join(self.data_path, img_path)  # 合成图片路径        img = PIL.Image.open(img_path)  # 读取图片        img = self.transform(img)  # 把图片转换成tensor        return img, label    def __len__(self) -> int:        return len(self.path_list)  # 返回图片数量train_datas = MyDataset(train_path)test_datas = MyDataset(test_path, train=False)

(原本数据有25000张,由于设备的原因,训练完之后我删掉了很多图片,训练集+测试集只有2000张)

查看读取的数据

# 展示读取的图片数据,因为做了归一化,所有图片显示不正常。Img_PIL_Tensor = train_datas[20][0]new_img_PIL = torchvision.transforms.ToPILImage()(Img_PIL_Tensor).convert('RGB')plt.imshow(new_img_PIL)plt.show(block=True)

训练集和测试集分组,数据分batch

(根据自己的设备来,好的就设32,不好就4吧)

# 70%训练集  30%测试集train_size = int(0.7 * len(train_datas))validate_size = len(train_datas) - train_sizetrain_datas,validate_datas = torch.utils.data.random_split(train_datas,[train_size, validate_size])# 数据分批# batch_size=32 每一个batch大小为32# shuffle=True 打乱分组# pin_memory=True 锁页内存,数据不会因内存不足,交换到虚拟内存中,能加快数据读入到GPU显存中.# num_workers 线程数。num_worker设置越大,加载batch就会很快,训练迭代结束可能下一轮batch已经加载好# win10 设置会多线程可能会出现问题,一般设置0.train_loader = torch.utils.data.DataLoader(train_datas, batch_size=4,                                            shuffle=True, pin_memory=True, num_workers=0)validate_loader = torch.utils.data.DataLoader(validate_datas, batch_size=4,                                            shuffle=True, pin_memory=True, num_workers=0)test_loader = torch.utils.data.DataLoader(test_datas, batch_size=4,                                            shuffle=False, pin_memory=True, num_workers=0)

VGG网络:

def vgg_block(num_convs, in_channels, out_channels):    layers = []    for _ in range(num_convs):        layers.append(torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))        layers.append(torch.nn.ReLU())        in_channels = out_channels    # ceil_mode=False 输入的形状不是kernel_size的倍数,直接不要。    # ceil_mode=True 输入的形状不是kernel_size的倍数,单独计算。    layers.append(torch.nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=False))    return torch.nn.Sequential(*layers)def vgg(conv_arch):    conv_blks = []    # 数据输入是几个通道    in_channels = 3    # 卷积层部分    for (num_convs, out_channels) in conv_arch:        conv_blks.append(vgg_block(num_convs, in_channels, out_channels))        in_channels = out_channels    return torch.nn.Sequential(        *conv_blks, torch.nn.Flatten(),        torch.nn.Linear(out_channels * 7 * 7, 4096), torch.nn.ReLU(), torch.nn.Dropout(0.5),        torch.nn.Linear(4096, 4096), torch.nn.ReLU(), torch.nn.Dropout(0.5),        torch.nn.Linear(4096, 2))

VGG神经网络定义和参数初始化

# VGG11,VGG13,VGG16,VGG19 可自行更换。conv_arch = ((2, 64), (2, 128), (3, 256), (3, 512), (3, 512))  # vgg16#conv_arch = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))  # vgg11#conv_arch = ((2, 64), (2, 128), (2 , 256), (2, 512), (2, 512))  # vgg13#conv_arch = ((2, 64), (2, 128), (4, 256), (4, 512), (4, 512))  # vgg19net = vgg(conv_arch)   # 定义网络net = net.to(device)   # 把网络加载到GPU上# Xavier方法 初始化网络参数,最开始没有初始化一直训练不起来。def init_normal(m):    if type(m) == torch.nn.Linear:        # Xavier初始化        torch.nn.init.xavier_uniform_(m.weight)        torch.nn.init.zeros_(m.bias)    if type(m) == torch.nn.Conv2d:        # Xavier初始化        torch.nn.init.xavier_uniform_(m.weight)        torch.nn.init.zeros_(m.bias)net.apply(init_normal)learn_rate = 1e-5#momentum = 0.9#optimizer = torch.optim.SGD(net.parameters(), learn_rate, momentum = momentum) #定义梯度优化算法optimizer = torch.optim.Adam(net.parameters(), learn_rate) #开始使用SGD没有训练起来,才更换的Adamcost = torch.nn.CrossEntropyLoss(reduction='sum')     # 定义损失函数,返回batch的loss和。print(net)    # 打印模型架构

训练VGG神经网络

epoch = 10  # 迭代10次def train_model(net, train_loader, validate_loader, cost, optimezer):    net.train()  # 训练模式    now_loss = 1e9  # flag 计算当前最优loss    train_ls = []  # 记录在训练集上每个epoch的loss的变化情况    train_acc = []  # 记录在训练集上每个epoch的准确率的变化情况    for i in range(epoch):        loss_epoch = 0.  # 保存当前epoch的loss和        correct_epoch = 0  # 保存当前epoch的正确个数和        for j, (data, label) in enumerate(train_loader):            data, label = data.to(device), label.to(device)            pre = net(data)            # 计算当前batch预测正确个数            correct_epoch += torch.sum(pre.argmax(dim=1).view(-1) == label.view(-1)).item()            loss = cost(pre, label)            loss_epoch += loss.item()            optimezer.zero_grad()            loss.backward()            optimezer.step()            if j % 100 == 0:                print(                    f'batch_loss:{loss.item()}, batch_acc:{torch.sum(pre.argmax(dim=1).view(-1) == label.view(-1)).item() / len(label)}%')        train_ls.append(loss_epoch / train_size)        train_acc.append(correct_epoch / train_size)        # 每一个epoch结束后,在验证集上验证实验结果。        with torch.no_grad():            loss_validate = 0.            correct_validate = 0            for j, (data, label) in enumerate(validate_loader):                data, label = data.to(device), label.to(device)                pre = net(data)                correct_validate += torch.sum(pre.argmax(dim=1).view(-1) == label.view(-1)).item()                loss = cost(pre, label)                loss_validate += loss.item()            # print(f'validate_sum:{loss_validate},  validate_Acc:{correct_validate}')            print(f'validate_Loss:{loss_validate / validate_size},  validate_Acc:{correct_validate / validate_size}%')            # 保存当前最优模型参数            if now_loss > loss_validate:                now_loss = loss_validate                print("保存模型参数。。。。。。。。。。。")                torch.save(net.state_dict(), 'model.params')    # 画图    plt.plot(range(epoch), train_ls, color='b', label='loss')    plt.plot(range(epoch), train_acc, color='g', label='acc')    plt.legend()    plt.show(block=True)  # 显示 labletrain_model(net, train_loader, validate_loader, cost, optimizer)

资料分享栏目

数据集之猫狗系列(VGG16)

链接:https://pan.baidu.com/s/1MoJPs-BQ6GP1PrXjo-wKsQ

提取码:dgna

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_325502.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

nohup命令解决SpringBoot/java -jar命令启动项目运行一段时间自动停止问题

问题描述: 在centos7上部署多个springcloud项目。出现了服务莫名其妙会挂掉一两个的问题,重新启动挂掉的服务之后又会出现其他服务挂掉的情况,查看启动日志也并没有发现有异常抛出。令人费解的是所有的服务都是通过nohup java -jar xxx.jar …

强化学习路径优化:基于Q-learning算法的机器人路径优化(MATLAB)

一、强化学习之Q-learning算法 Q-learning算法是强化学习算法中的一种,该算法主要包含:Agent、状态、动作、环境、回报和惩罚。Q-learning算法通过机器人与环境不断地交换信息,来实现自我学习。Q-learning算法中的Q表是机器人与环境交互后的…

图像视频基础

图像视频基础 文章目录 图像视频基础图像颜色深度分辨率 视频帧率比特率帧类型 YUV模型色度子采样 图像 颜色深度 存储颜色的强度,需要占用一定大小的数据空间,这个大小被称为颜色深度。假如每个颜色的强度占用 8 bit(取值范围为 0 到 255&…

nginx+tomcat负载均衡和动静分离

目录 1.部署nginx 2.部署两台tomcat 3.配置nginx 1.部署nginx vim /vim/lib/systemd/system/nginx.service 2.部署两台tomcat 进入第一台装第一个tomcat vim /etc/profile vim /usr/local/tomcat/webapps/test/index.jsp 重启 进入第二台安装第二台tomcat vim /usr/local/tom…

(0021) H5-Vuejs配合 mint-ui 开发移动端web

mint-ui 初衷 element-ui主打pcweb,导致移动端上UI适配问题突出,趟了很多坑。这次更加理智些,选择了饿了么团队的主打移动端的mint-ui,目前来说体验很好。 认识Mint-ui 首先在手机上体验其demo,扫描链接:…

在 Jetpack Compose 中创建 Drawer

Jetpack Compose 是一个现代的构建 Android UI 的工具集,它使得构建 UI 变得更加简单快速。在本篇博客中,我们将讨论如何在 Jetpack Compose 中创建 Drawer,也就是我们常见的侧边抽屉。 什么是 Drawer? Drawer 是一个提供导航选项…

基于Transformer视觉分割综述

基于Transformer视觉分割综述 SAM (Segment Anything )作为一个视觉的分割基础模型,在短短的 3 个月时间吸引了很多研究者的关注和跟进。如果你想系统地了解 SAM 背后的技术,并跟上内卷的步伐,并能做出属于自己的 SAM…

GC回收器演进之路

目录 未来演进方向 历经之路 引用计数法 标记清除法 复制法 标记整理 分代式 三色标记法的诞生 三色标记法的基本概念 产生的问题 问题 1:浮动垃圾 问题 2:对象消失 遍历对象图不需要 STW 的解决方案 屏障机制 插入屏障(Dijks…

Autosar诊断系列介绍17 - 物理寻址及功能寻址详解

本文框架 前言1. 物理寻址及功能寻址基本概念1.1物理寻址及功能寻址-定义1.2两种寻址方式区别1.3不同诊断服务寻址方式配置 2.不同寻址方式的应用场景 前言 UDS(Unified Diagnostic Services)协议,即统一的诊断服务,是面向整车所…

基于SQLI的SQL字符型报错注入

基于SQLI的SQL字符型报错注入 一. 实验目的 理解数字型报错SQL注入漏洞点的定位方法,掌握利用手工方式完成一次完整SQL注入的过程,熟悉常见SQL注入命令的操作。 二. 实验环境 渗透主机:KALI平台 用户名: college 密码: 360College 目标网…

JAVA麻将胡牌算法深度解析

目录 麻将的基本概念 麻将牌的构成 麻将的碰,杠,吃,听,胡 麻将胡牌条件 胡牌算法简介 选将拆分法 算法数据结构 构建数据结构 数据结构使用 牌花色的获取 获取某一花色的牌值 获取某一张牌相邻牌 算法代码实现 基础代…

Web3.0 应用开发:选择合适的框架和工具至关重要

随着 Web3.0 时代的到来,区块链技术的普及和应用让去中心化的应用开发变得更加可行。然而,要开发出高效、稳定和安全的 Web3.0 应用,选择合适的框架和工具至关重要。本文将介绍 Web3.0 应用开发的关键因素,帮助开发者做出明智的选…

Hive Metastore 表结构

Hive MetaStore 的ER 图如下。 部分表结构和说明。 CTLGS(CATALOGS) catalogs 可以隔离元数据。默认只有1行。一个 CATALOG 可以有多个数据库。 mysql> DESC CTLGS; -------------------------------------------------------- | Field | Type | Null |…

海康明眸设备SDK二次开发NET_DVR_SetupAlarmChan_V41老是报109错误

请仔细阅读图2中的文件,这里详细介绍了怎么样 放置DLL,务必按照图3中的说明步骤进行放置。HCNetSDKCom文件夹一定也要拷贝到debug目录,否则就会出现类似于109的错误提示。

NR 吞吐量测试

前言 参考文档: 5G NR TBS (Transport Block size) Calculator | 5G-Tools.com 5G NR Transport Block Size (TBS) Calculation - Techplayon 5G MCS _ 搜索结果_哔哩哔哩_Bilibili 4/5G无线资源和数据调度流程:CQI上报、基站AMC调度、调度信息DCI下发、CQI到MCS的对…

网联V2X视频事件检测相机使用说明书

1 产品概览 网联 V2X视频事件检测相机 视频事件检测相机 ,内置 1/1.8″逐行扫描 800万像素传感器;视 万像素传感器;视 频编码协议支持 H.265、H.264、MJPEG;具有 1个 10M/100M/1000M自适应以 太网 RJ45接口、 1路 RS485接口&#…

Windows基于WSL搭建Python数据分析环境

最近配置了一台较为不错的台式机,记录下自己配置环境的过程。 安装WSL,提供Linux环境 如果你发现后续的命令无法运行或者说软件商城中找不到,这可能意味着你的操作系统不符合要求。WSL安装要求 Windows 10 version 2004(Build 19…

Go程序结构- package和import

1、包和文件 在Go语言中包的作用和其他语言中的库或模块的作用类似,用于支持模块化、封装、编译隔离和重用。关键点如下: (1)包中保存一个或者多个.go结尾的文件,而包的目录就是包的导入路径 (2)中Go中通过一条简单的规则来管理标识符是否对外…

SpringBoot 3.1 新版HTTP调用

在SpringBoot3版本发布后 官方便声明了推荐使用了内置声明式的HTTP客户端。 一、声明式HTTP客户端使用(依赖引入) <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-webflux</artifactId></depende…

Gradio库中的HighlightedText组件

❤️觉得内容不错的话&#xff0c;欢迎点赞收藏加关注&#x1f60a;&#x1f60a;&#x1f60a;&#xff0c;后续会继续输入更多优质内容❤️ &#x1f449;有问题欢迎大家加关注私戳或者评论&#xff08;包括但不限于NLP算法相关&#xff0c;linux学习相关&#xff0c;读研读博…