PyTorch学习笔记(十六)——利用GPU训练

news/2024/5/20 21:21:19/文章来源:https://blog.csdn.net/weixin_45827876/article/details/132391041

 一、方式一

网络模型、损失函数、数据(包括输入、标注)

找到以上三种变量,调用它们的.cuda(),再返回即可

if torch.cuda.is_available():mynn = mynn.cuda()
if torch.cuda.is_available():loss_function = loss_function.cuda()
for data in train_dataloader:imgs,targets = dataif torch.cuda.is_available():imgs = imgs.cuda()targets = targets.cuda()
for data in test_dataloader:imgs,targets = dataif torch.cuda.is_available():imgs = imgs.cuda()targets = targets.cuda()

完整代码:

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time
# from model import *# 准备数据集
train_data = torchvision.datasets.CIFAR10(root="../datasets",train=True,transform=torchvision.transforms.ToTensor(),download=False)
test_data = torchvision.datasets.CIFAR10(root="../datasets",train=False,transform=torchvision.transforms.ToTensor(),download=False)train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))# 利用dataloader来加载数据集
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)# 创建网络模型
class MyNN(nn.Module):def __init__(self):super(MyNN, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return x
mynn = MyNN()
if torch.cuda.is_available():mynn = mynn.cuda()# 损失函数
loss_function = nn.CrossEntropyLoss()
if torch.cuda.is_available():loss_function = loss_function.cuda()
# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(mynn.parameters(), lr=learning_rate)# 设置训练网络的一些参数
total_train_step = 0 # 记录训练次数
total_test_step = 0 # 记录测试次数
epoch = 10 # 训练的轮数# 添加tensorboard
writer = SummaryWriter("../logs_train")start_time = time.time()
for i in range(epoch):print("----------第{}轮训练开始----------".format(i+1))# 训练步骤开始mynn.train()for data in train_dataloader:imgs,targets = dataif torch.cuda.is_available():imgs = imgs.cuda()targets = targets.cuda()outputs = mynn(imgs)loss = loss_function(outputs, targets)# 优化器优化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_step += 1if total_train_step % 100 == 0:end_time = time.time()print("所用时间:{}".format(end_time - start_time))print("训练次数:{},loss:{}".format(total_train_step, loss.item()))writer.add_scalar("train_loss",loss.item(),total_train_step)# 测试步骤开始mynn.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:imgs,targets = dataif torch.cuda.is_available():imgs = imgs.cuda()targets = targets.cuda()outputs = mynn(imgs)loss = loss_function(outputs, targets)total_test_loss += lossaccuracy = (outputs.argmax(1) == targets).sum()total_accuracy += accuracyprint("整体测试集上的loss:{}".format(total_test_loss))print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))writer.add_scalar("test_loss",total_test_loss,total_test_step)writer.add_scalar("test_accuracy",total_accuracy/test_data_size,total_test_step)total_test_step += 1torch.save(mynn,"mynn_{}.pth".format(i))# torch.save(mynn.state_dict(),"mynn_{}.pth".format(i))print("模型已保存")writer.close()

 比较CPU和GPU的训练时间:

 查看GPU信息:

在 终端里输入nvidia-smi

 使用Google Colab:Google 为我们提供了一个免费的GPU

修改 ——> 笔记本设置 ——> 硬件加速器选择GPU(每周免费使用30h)

 

 

 二、方式二(更常用)

定义训练设备

device = torch.device("cpu")
# 对于单显卡来说,以下两种方式没有区别
device = torch.device("cuda")
device = torch.device("cuda:0")
# 一种语法的简写,程序在 CPU 或 GPU/cuda 环境下都能运行
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

网络模型、损失函数、数据(包括输入、标注)

找到以上三种变量,.to(device),再返回即可

mynn = MyNN()
mynn = mynn.to(device)
# 这里可以不用再赋值给mynn,直接mynn.to(device) 也可以
loss_function = nn.CrossEntropyLoss()
loss_function = loss_function.to(device)
# 这里可以不用再赋值给loss_function ,直接loss_function .to(device) 也可以
for data in train_dataloader:imgs,targets = dataimgs = imgs.to(device)targets = targets.to(device)# 这里必须赋值
for data in test_dataloader:imgs,targets = dataimgs = imgs.to(device)targets = imgs.to(device)# 这里必须赋值

完整代码:

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time
# from model import *# 定义训练的设备
# device = torch.device("cpu")
# device = torch.device("cuda")
# device = torch.device("cuda:0")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 准备数据集
train_data = torchvision.datasets.CIFAR10(root="../datasets",train=True,transform=torchvision.transforms.ToTensor(),download=False)
test_data = torchvision.datasets.CIFAR10(root="../datasets",train=False,transform=torchvision.transforms.ToTensor(),download=False)train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))# 利用dataloader来加载数据集
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)# 创建网络模型
class MyNN(nn.Module):def __init__(self):super(MyNN, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return x
mynn = MyNN()
mynn.to(device)# 损失函数
loss_function = nn.CrossEntropyLoss()
loss_function.to(device)
# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(mynn.parameters(), lr=learning_rate)# 设置训练网络的一些参数
total_train_step = 0 # 记录训练次数
total_test_step = 0 # 记录测试次数
epoch = 10 # 训练的轮数# 添加tensorboard
writer = SummaryWriter("../logs_train")start_time = time.time()
for i in range(epoch):print("----------第{}轮训练开始----------".format(i+1))# 训练步骤开始mynn.train()for data in train_dataloader:imgs,targets = dataimgs = imgs.to(device)targets = targets.to(device)outputs = mynn(imgs)loss = loss_function(outputs, targets)# 优化器优化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_step += 1if total_train_step % 100 == 0:end_time = time.time()print("所用时间:{}".format(end_time - start_time))print("训练次数:{},loss:{}".format(total_train_step, loss.item()))writer.add_scalar("train_loss",loss.item(),total_train_step)# 测试步骤开始mynn.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:imgs,targets = dataimgs = imgs.to(device)targets = targets.to(device)outputs = mynn(imgs)loss = loss_function(outputs, targets)total_test_loss += lossaccuracy = (outputs.argmax(1) == targets).sum()total_accuracy += accuracyprint("整体测试集上的loss:{}".format(total_test_loss))print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))writer.add_scalar("test_loss",total_test_loss,total_test_step)writer.add_scalar("test_accuracy",total_accuracy/test_data_size,total_test_step)total_test_step += 1torch.save(mynn,"mynn_{}.pth".format(i))# torch.save(mynn.state_dict(),"mynn_{}.pth".format(i))print("模型已保存")writer.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_536492.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

1339. 分裂二叉树的最大乘积

链接: ​​​​​​1339. 分裂二叉树的最大乘积 题解: /*** Definition for a binary tree node.* struct TreeNode {* int val;* TreeNode *left;* TreeNode *right;* TreeNode() : val(0), left(nullptr), right(nullptr) {}* …

Java性能分析中常用命令和工具

当涉及到 Java 性能分析时,有一系列强大的命令和工具可以帮助开发人员分析应用程序的性能瓶颈、内存使用情况和线程问题。以下是一些常用的 Java 性能分析命令和工具,以及它们的详细说明和示例。 以下是一些常用的性能分析命令和工具汇总: …

Nacos配置管理、Feign远程调用、Gateway服务网关

1.Nacos配置管理 1.1.将配置交给Nacos管理的步骤 1.在Nacos中添加配置 Data Id服务名称-环境名称.yaml eg&#xff1a;userservice-dev.yaml 2.引入nacos-config依赖 在user-service服务中&#xff0c;引入nacos-config的客户端依赖 <!--nacos配置管理依赖--> <dep…

redis--主从复制

redis主从复制 Redis 主从复制是一种用于实现数据复制和数据备份的机制&#xff0c;它允许将一个 Redis 服务器的数据复制到其他 Redis 服务器上。主从复制在 Redis 中通常用于构建高可用性架构、读写分离以及数据分析等场景。 主从复制的角色 主服务器&#xff08;Master&a…

系统架构设计专业技能 · 软件工程之需求工程

系列文章目录 系统架构设计高级技能 软件架构概念、架构风格、ABSD、架构复用、DSSA&#xff08;一&#xff09;【系统架构设计师】 系统架构设计高级技能 系统质量属性与架构评估&#xff08;二&#xff09;【系统架构设计师】 系统架构设计高级技能 软件可靠性分析与设计…

Cpp学习——类与对象3

目录 一&#xff0c;初始化列表 1.初始化列表的使用 2.初始化列表的特点 3.必须要使用初始化列表的场景 二&#xff0c;单参数构造函数的隐式类型转换 1.内置类型的隐式类型转换 2. 自定义类型的隐式类型转换 3.多参数构造函数的隐式类型转换 4.当你不想要发生隐式类型转换…

Unity VR:XR Interaction Toolkit 输入系统(Input System):获取手柄的输入

文章目录 &#x1f4d5;教程说明&#x1f4d5;Input System 和 XR Input Subsystem&#xff08;推荐 Input System&#xff09;&#x1f4d5;Input Action Asset⭐Actions Maps⭐Actions⭐Action Properties&#x1f50d;Action Type (Value, Button, Pass through) ⭐Binding …

数据结构<树和二叉树>顺序表存储二叉树实现堆排

✨Blog&#xff1a;&#x1f970;不会敲代码的小张:)&#x1f970; &#x1f251;推荐专栏&#xff1a;C语言&#x1f92a;、Cpp&#x1f636;‍&#x1f32b;️、数据结构初阶&#x1f480; &#x1f4bd;座右铭&#xff1a;“記住&#xff0c;每一天都是一個新的開始&#x1…

Gin+微服务实现抖音视频上传到七牛云

文章目录 安装获取凭证Gin处理微服务处理 如果你对Gin和微服务有一定了解&#xff0c;看本文较容易。 安装 执行命令&#xff1a; go get github.com/qiniu/go-sdk/v7获取凭证 Go SDK 的所有的功能&#xff0c;都需要合法的授权。授权凭证的签算需要七牛账号下的一对有效的A…

Go语言入门指南:基础语法和常用特性解析(上)

一、Go语言前言 Go是一种静态类型的编译语言&#xff0c;常常被称作是21世纪的C语言。Go语言是一个开源项目&#xff0c;可以免费获取编译器、库、配套工具的源代码&#xff0c;也是高性能服务器和应用程序的热门选择。 Go语言可以运行在类UNIX系统——比如Linux、OpenBSD、M…

Python批量爬虫下载文件——把Excel中的超链接快速变成网址

本文的背景是&#xff1a;大学关系很好的老师问我能不能把Excel中1000个超链接网址对应的pdf文档下载下来。虽然可以手动一个一个点击下载&#xff0c;但是这样太费人力和时间了。我想起了之前的爬虫经验&#xff0c;给老师分析了一下可行性&#xff0c;就动手实践了。    没…

Platypus:Quick,Cheap,and Powerful Refinement of LLMs

Platypus:Quick,Cheap,and Powerful Refinement of LLMs IntroductionMethod2.1 Curating Open- PlatypusRemoving similar&duplicate questionsContamination CheckFine-tuning & mergingResult参考Introduction 现在大模型已经取得很不错的结果,如何把大模型的能…

sh 脚本循环语句和正则表达式

目录 1、循环语句 1、for 2、while 3、until 2、正则表达式 1、元字符 2、表示次数 3、位置锚定 4、分组 5、扩展正则表达式 1、循环语句 循环含义 将某代码段重复运行多次&#xff0c;通常有进入循环的条件和退出循环的条件 重复运行次数 循环次数事先已知 循环次…

1、攻防世界第一天

1、网站目录下会有一个robots.txt文件&#xff0c;规定爬虫可以/不可以爬取的网站。 2、URL编码细则&#xff1a;URL栏中字符若出现非ASCII字符&#xff0c;则对其进行URL编码&#xff0c;浏览器将该请求发给服务端&#xff1b;服务端会可能会先对收到的url进行解码&#xff0…

使用 Amazon Redshift Serverless 和 Toucan 构建数据故事应用程序

这是由 Toucan 的解决方案工程师 Django Bouchez与亚马逊云科技共同撰写的特约文章。 带有控制面板、报告和分析的商业智能&#xff08;BI&#xff0c;Business Intelligence&#xff09;仍是最受欢迎的数据和分析使用场景之一。它为业务分析师和经理提供企业的过去状态和当前状…

字符设备驱动实例(PWM和RTC)

目录 五、PWM 六、RTC 五、PWM PWM(Pulse Width Modulation&#xff0c;脉宽调制器)&#xff0c;顾名思义就是一个输出脉冲宽度可以调整的硬件器件&#xff0c;其实它不仅脉冲宽度可调&#xff0c;频率也可以调整。它的核心部件是一个硬件定时器&#xff0c;其工作原理可以用…

抖音火山引擎推出免费域名DNS和公共DNS服务

抖音旗下的云计算服务火山引擎最近推出了"TrafficRoute DNS 套件"服务&#xff0c;其中包括两款产品&#xff0c;对软希网来说非常有用。 1.域名DNS&#xff1a; 这是一个用于网站域名的DNS服务&#xff0c;可以加速域名解析速度&#xff0c;从而提升网站的速度。如…

初出茅庐的小李博客之STM32CubeMx驱动WS2812B实现幻彩(超详)

STM32CubeMx驱动WS2812B实现幻彩&#xff08;超详&#xff09; 1.创建基于STM32F03C8T6工程 1.1配置时钟 选择外部高速时钟源HSE 1.2配置系统时钟树使其达到最大时钟72MHz&#xff08;最大系统时钟&#xff09; 由时钟树可以知道APB1上定时器时钟频率是72MHz,实验使用的硬件…

[NLP] BERT模型参数量

一 BERT_Base 110M参数拆解 BERT_base模型的110M的参数具体是如何组成的呢&#xff0c;我们一起来计算一下&#xff1a; 刚好也能更深入地了解一下Transformer Encoder模型的架构细节。 借助transformers模块查看一下模型的架构&#xff1a; import torch from transformers …