Pytorch网络模型训练

news/2024/5/5 8:27:59/文章来源:https://blog.csdn.net/qq_47896523/article/details/134203988

现有网络模型的使用与修改

vgg16_false = torchvision.models.vgg16(pretrained=False)        # 加载一个未预训练的模型
vgg16_true = torchvision.models.vgg16(pretrained=True)
# 把数据分为了1000个类别print(vgg16_true)

以下是vgg16预训练模型的输出 

VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace=True)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace=True)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

预训练模型的输出从1000类别转为10类别

import torchvision
from torch import nn
# 因为数据集过大,所以注释掉此行代码
# train_data = torchvision.datasets.ImageNet("./data_image_net", split='train', download=True,
#                                            transform=torchvision.transforms.ToTensor())vgg16_false = torchvision.models.vgg16(pretrained=False)        # 加载一个未预训练的模型
vgg16_true = torchvision.models.vgg16(pretrained=True)
# 把数据分为了1000个类别print(vgg16_true)# vgg16_true.add_module("add_linear", nn.Linear(1000, 10))
vgg16_true.classifier.add_module("add_linear", nn.Linear(1000, 10))
# 在预训练模型的最后添加了一个新的全连接层,用于将最后的输出转化为10个类别
print(vgg16_true)print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096, 10)
# 未预训练模型的最后一层的输出特征数更改为了10
print(vgg16_false)

网络模型的保存与读取

加载未预训练的模型

vgg16 = torchvision.models.vgg16(pretrained=False)

方式一

# 保存方式1  保存的模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pyth")#读取方式1
model = torch.load("vgg16_method1.pth")

方式二

# 保存方式2  不再保存模型结构,而是保存模型的参数为字典结构    推荐
torch.save(vgg16.state_dict(), "vgg16_method2.pyth")# 方式2,加载模型
# model = torch.load("vgg16_method2.pth")     #这样输出的是字典类型
# print(model)
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))      # 将其恢复为网络模型
print(vgg16)

完整的模型训练套路

准备数据集

# 准备数据集
train_data = torchvision.datasets.CIFAR10("../data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(),download=True)train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为{}".format(train_data_size))    # 50000
print("测试数据集的长度为{}".format(test_data_size))     # 10000# 利用Dataloader来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

创建网络模型

# 创建网络模型  神经网络的代码在train_module文件
tudui = Tudui()

train_module文件

# 搭建神经网络
class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()# 简化操作,并且按顺序进行操作self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return x

构建损失函数

# 损失函数
loss_fn = nn.CrossEntropyLoss()

构建优化器

# 优化器
# 如果学习率过大,模型可能会在最小值附近震荡而无法收敛;如果学习率过小,模型训练可能会过于缓慢
learning_rate = 0.01
# 使用随机梯度下降算法来更新模型的权重
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)

设置训练集参数

# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练的轮数
epoch = 10

添加tensorboard

# 将数据写入 TensorBoard 可视化的日志文件中
writer = SummaryWriter("../logs_train")

训练步骤

# tudui.train()
for data in train_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)# 优化器优化模型optimizer.zero_grad()# 将优化器中的梯度缓存(如果有的话)清零loss.backward()# 计算损失函数(loss)相对于模型参数的梯度optimizer.step()total_train_step = total_train_step + 1if total_train_step % 100 == 0:# .item()是将tensor张量变为正常的数字print("训练次数:{},Loss:{}".format(total_train_step, loss.item()))# loss.item()是当前步骤的损失值writer.add_scalar("train_loss", loss.item(), total_train_step)# 使用add_scalar可以将一个标量添加到之前的所有标量值中,# 这样就可以在TensorBoard中绘制一个标量随时间变化的图表

测试步骤

# 测试步骤开始
# tudui.eval()
total_test_loss = 0
total_accuracy = 0
# 不会对以下的代码进行调优
with torch.no_grad():for data in test_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)total_test_loss = total_test_loss + loss.item()# argmax(1)是横向看,argmax(0)是纵向看accuracy = (outputs.argmax(1) == targets).sum()# argmax在找到模型预测的最大概率对应的类别# 预测正确的个数total_accuracy = total_accuracy + accuracyprint("整体测试集上的Loss:{}".format(total_test_loss))
print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))
# 测试集上的总损失
writer.add_scalar("test_loss", total_test_loss, total_test_step)
writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
total_test_step = total_test_step + 1

利用GPU训练

# 定义训练的设备
# device = torch.device("cpu/cuda")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# GPU训练的关键
tudui = tudui.cuda()
# tudui = tudui.to(device)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_190631.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【FastCAE源码阅读6】C++与Python的集成,实现相互调用

分析FastCAE代码之前先看看C与Python如何相互调用的。 一、C调用Python 先写个C调用Python的例子&#xff0c;然后再来看FastCAE集成Python就比较简单了。直接上代码&#xff1a; #include <iostream> #include "python.h"int main() {Py_Initialize();PyRu…

uniapp+uview2.0+vuex实现自定义tabbar组件

效果图 1.在components文件夹中新建MyTabbar组件 2.组件代码 <template><view class"myTabbarBox" :style"{ backgroundColor: backgroundColor }"><u-tabbar :placeholder"true" zIndex"0" :value"MyTabbarS…

关闭EasyConnect进程详细步骤

1、不关闭导致的问题 nacos浏览器可以正常访问&#xff0c;但idea启动的时候连不上nacos&#xff0c;而且第二次启动都启动不了&#xff0c;一直卡在那里&#xff0c;排查了半天&#xff0c;怀疑是装的EasyConnect的VPN导致的&#xff0c;于是停止掉相关服务即可。但直接结束进…

Git 基础知识回顾及 SVN 转 Git 自测

背景 项目开发过程中使用的版本控制工具是 SVN&#xff0c;Git 多有耳闻&#xff0c;以前也偶尔玩过几次&#xff0c;但是工作中不用&#xff0c;虽然本地也有环境&#xff0c;总是不熟练。 最近看一本网络开源技术书时&#xff0c;下载源码部署了一下&#xff0c;又温故了一…

js调整table表格上下相邻元素顺序

有时候我们会遇到要通过箭头控制table表格上下顺序的需求,如下: 点击向下就将该元素下移一位,下面的一位元素就移上来,点击向上就将该元素上移一位,上面的一位元素就移下来,也就是相邻元素互换位置顺序: <el-table :data="targetTable" border style=&quo…

Sui发布RPC2.0 Beta,拥抱GraphQL并计划弃用JSON-RPC

为了解决现有RPC存在的许多已知问题&#xff0c;Sui正在准备推出一个基于GraphQL的新RPC服务&#xff0c;名为Sui RPC 2.0。GraphQL是一种开源数据查询和操作语言&#xff0c;旨在简化需要复杂数据查询的API和服务。 用户目前可以访问Sui主网和测试网网络的Beta版本的只读快照…

nacos的部署与配置中心

文章目录 一、nacos部署安装的方式单机模式:集群模式:多集群模式: 二、安装的步骤1、预备环境准备2、载安装包以及安装2.1、Nacos有以下两种安装方式:2.2、更换数据源数据源切换为MySQL 2.3、开启控制台授权登录&#xff08;可选&#xff09; 3、配置中心的使用3.1、创建配置信…

3.27每日一题(常系数线性非齐次方程的特解)

常系数非齐次线性方程的特解如何假设&#xff08;两种&#xff09;形式&#xff1a; 1、题目中 e 的 x 次幂以及 1&#xff0c;都是第一种&#xff1a;1可以看成为e的0次幂 注&#xff1a;题目给的多项式是特殊的形式&#xff0c;我们要设为一般的形式的多项式 2、题目中sin…

竞赛 深度学习疫情社交安全距离检测算法 - python opencv cnn

文章目录 0 前言1 课题背景2 实现效果3 相关技术3.1 YOLOV43.2 基于 DeepSort 算法的行人跟踪 4 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; **基于深度学习疫情社交安全距离检测算法 ** 该项目较为新颖&#xff0c;适合作为竞赛…

搭建二维码系统,轻松实现固定资产的一物一码管理

固定资产管理中普遍存在盘点难、家底不清、账实不一致、权责不清晰等问题&#xff0c;可以在草料上搭建固定资产管理系统&#xff0c;通过组合功能模块实现资产信息展示、领用登记、出入库管理、故障报修等功能&#xff0c;对固定资产进行一物一码规范化管理。 比如张掖公路事业…

【webrtc】 对视频质量的码率控制的测试与探索

目录 环境设置 transport-cc goog-remb (webrtc中的两种码率算法&#xff09; 修改成remb算法 测试 效果 后续 可参考工程 环境设置 要到meshx上操作 telnet 112 然后执行factory_env show |grep meshx_ip 之后telnet meshx_ip 用户名admin 密码****.119 执行一下r…

self.register_buffer方法使用解析(pytorch)

self.register_buffer就是pytorch框架用来保存不更新参数的方法。 列子如下&#xff1a; self.register_buffer("position_emb", torch.randn((5, 3)))第一个参数position_emb传入一个字符串&#xff0c;表示这组参数的名字&#xff0c;第二个就是tensor形式的参数…

JavaEE平台技术——MyBatis

JavaEE平台技术——MyBatis 1. 对象关系映射框架——Hibernate、MyBatis2. 对象关系模型映射3. MyBatis的实现机制4. MyBatis的XML定义5. Spring事务 在观看这个之前&#xff0c;大家请查阅前序内容。 &#x1f600;JavaEE的渊源 &#x1f600;&#x1f600;JavaEE平台技术——…

大数据毕业设计选题推荐-设备环境监测平台-Hadoop-Spark-Hive

✨作者主页&#xff1a;IT毕设梦工厂✨ 个人简介&#xff1a;曾从事计算机专业培训教学&#xff0c;擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Py…

AI:57-基于机器学习的番茄叶部病害图像识别

🚀 本文选自专栏:AI领域专栏 从基础到实践,深入了解算法、案例和最新趋势。无论你是初学者还是经验丰富的数据科学家,通过案例和项目实践,掌握核心概念和实用技能。每篇案例都包含代码实例,详细讲解供大家学习。 📌📌📌在这个漫长的过程,中途遇到了不少问题,但是…

Web前端—网页制作(以“学成在线”为例)

版本说明 当前版本号[20231105]。 版本修改说明20231105初版 目录 文章目录 版本说明目录day07-学成在线01-项目目录02-版心居中03-布局思路04-header区域-整体布局HTML结构CSS样式 05-header区域-logo06-header区域-导航HTML结构CSS样式 07-header区域-搜索布局HTML结构CSS…

挑战100天 AI In LeetCode Day02(1)

挑战100天 AI In LeetCode Day02&#xff08;1&#xff09; 一、LeetCode介绍二、LeetCode 热题 HOT 100-32.1 题目2.2 题解 三、面试经典 150 题-33.1 题目3.2 题解 一、LeetCode介绍 LeetCode是一个在线编程网站&#xff0c;提供各种算法和数据结构的题目&#xff0c;面向程序…

使用Objective-C和ASIHTTPRequest库进行Douban电影分析

概述 Douban是一个提供图书、音乐、电影等文化内容的社交网站&#xff0c;它的电影频道包含了大量的电影信息和用户评价。本文将介绍如何使用Objective-C语言和ASIHTTPRequest库进行Douban电影分析&#xff0c;包括如何获取电影数据、如何解析JSON格式的数据、如何使用代理IP技…

【JavaEE】JVM 剖析

JVM 1. JVM 的内存划分2. JVM 类加载机制2.1 类加载的大致流程2.2 双亲委派模型2.3 类加载的时机 3. 垃圾回收机制3.1 为什么会存在垃圾回收机制?3.2 垃圾回收, 到底实在做什么?3.3 垃圾回收的两步骤第一步: 判断对象是否是"垃圾"第二步: 如何回收垃圾 1. JVM 的内…

计算机网络第4章-网络层(1)

引子 网络层能够被分解为两个相互作用的部分&#xff1a; 数据平面和控制平面。 网络层概述 路由器具有截断的协议栈&#xff0c;即没有网络层以上的部分。 如下图所示&#xff0c;是一个简单网络&#xff1a; 转发和路由选择&#xff1a;数据平面和控制平面 网络层的作用…