Torch分布式训练

news/2024/4/23 14:41:37/文章来源:https://blog.csdn.net/frighting_ing/article/details/130322905

介绍

torch.nn.DataParallel

torch.nn.DataParallel 是 PyTorch 中的一个模块,可以用于在多个 GPU 上并行地训练神经网络。具体来说,它可以将单个模型复制到多个 GPU 上,并且在每个 GPU 上运行相同的操作,最后将各个 GPU 上的梯度进行求和并更新模型参数。这样,可以显著加速神经网络的训练过程。

使用 torch.nn.DataParallel 很简单。只需在定义模型时,将模型包装在 torch.nn.DataParallel 中即可。例如:

import torch.nn as nnmodel = nn.DataParallel(MyModel())

这将会将 MyModel() 复制到多个 GPU 上,并且在每个 GPU 上并行运行相同的操作。

需要注意的是,如果你使用的是 PyTorch 1.6 及以上版本,则不必使用 torch.nn.DataParallel,因为 PyTorch 已经内置了更高级别的分布式训练模块,如 torch.nn.parallel.DistributedDataParallel。这些模块提供了更好的性能和更灵活的配置选项,可以更好地满足各种分布式训练的需求。

torch.nn.parallel.DistributedDataParallel

torch.nn.parallel.DistributedDataParallel 是 PyTorch 中的一个模块,可以用于在分布式环境中并行地训练神经网络。与 torch.nn.DataParallel 不同,torch.nn.parallel.DistributedDataParallel 可以支持跨进程、跨机器的分布式训练,可以在多个计算机上同时训练神经网络,可以显著加速训练过程。

使用torch.nn.parallel.DistributedDataParallel需要进行以下步骤:

  1. 启动进程组:在分布式训练中,需要使用进程组(process group)来进行进程之间的通信。可以使用torch.distributed.init_process_group()函数来启动进程组,需要指定进程组的类型(如 torch.distributed.Backend.GLOO 或 torch.distributed.Backend.NCCL)、进程组的名称、进程组中进程的数量、当前进程的编号等参数。

  2. 加载数据集:在分布式训练中,每个进程需要读取一部分数据集,并且需要对数据集进行划分,以保证每个进程读取到的数据不重复、不遗漏。可以使用 PyTorch 提供的 DistributedSampler 来实现数据集的划分,还可以使用 DataLoader 加载数据集。

  3. 定义模型:在分布式训练中,需要确保模型在每个进程中都能够被正确地初始化。可以在每个进程中定义相同的模型,或者在主进程中定义模型,然后使用 PyTorch 提供的torch.nn.parallel.DistributedDataParallel对模型进行封装。

  4. 训练模型:在分布式训练中,需要确保每个进程都能够并行地进行前向传播、反向传播和参数更新。可以使用 PyTorch 提供的 backward() 和 step() 函数实现反向传播和参数更新,还可以使用 all_reduce() 函数将各个进程的梯度进行求和。

  5. 结束训练:在分布式训练中,需要确保进程组能够正确地结束。可以使用 torch.distributed.destroy_process_group() 函数来关闭进程组。

需要注意的是,使用 torch.nn.parallel.DistributedDataParallel 需要对代码进行一定的修改,例如需要添加启动进程组、加载数据集、定义模型等步骤,同时需要考虑数据划分、梯度同步等问题。因此,使用 torch.nn.parallel.DistributedDataParallel 需要一定的分布式编程知识和经验。

实例

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel# 启动进程组
dist.init_process_group(backend='gloo', init_method='file:///tmp/some_file', world_size=4, rank=0)# 加载数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
train_sampler = DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=False, num_workers=2, sampler=train_sampler)# 定义模型
model = nn.Sequential(nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(64 * 16 * 16, 10)
)
model = DistributedDataParallel(model)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)# 训练模型
for epoch in range(10):train_sampler.set_epoch(epoch)for i, (inputs, labels) in enumerate(train_loader):optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 结束训练
dist.destroy_process_group()

这个例子中使用了 CIFAR10 数据集,定义了一个简单的卷积神经网络模型,并使用 torch.nn.parallel.DistributedDataParallel 将模型进行了封装。然后使用 DistributedSampler 对数据集进行了划分,并使用 DataLoader 加载数据集。在训练过程中,使用了 backward() 和 step() 函数进行反向传播和参数更新,并使用 all_reduce() 函数将各个进程的梯度进行求和。最后使用 torch.distributed.destroy_process_group() 函数结束进程组。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_102268.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

研究生考试 之 计算机网络第七版(谢希仁) 第一章 课后答案

研究生考试 之 计算机网络第七版(谢希仁) 第一章 课后答案 目录 研究生考试 之 计算机网络第七版(谢希仁) 第一章 课后答案 一、简单介绍 二、计算机网络第七版(谢希仁) 第一章 课后答案 1、 计算机网络向用户可以提供哪些服务? 2、 试简述分组交换的要点。 3…

Kali下部署-Nessus漏扫工具

Nessus 是全世界最多人使用的系统漏洞扫描与分析软件。总共有超过75,000个机构使用Nessus 作为扫描该机构电脑系统的软件。 特点: 1、提供完整的电脑漏洞扫描服务,并随时更新漏洞库。 2、可以在本机或者是远端上进行遥控,进行系统的漏洞扫…

常见的四种排名函数的用法(sql)

四个排名函数: 1.row_number 2.rank 3.dense_rank 4.ntile 1. ROW_NUMBER(排名场景推荐) 1.1 介绍 在 SQL 中,ROW_NUMBER() 是一个窗口函数,它为结果集中的每一行分配一个唯一的序号。该函数的语法如下: …

JavaSE-part1

文章目录 Day01 面向对象特性1.java继承注意点2.多态2.1多态概述2.2多态中成员的特点:star::star:2.3多态的转型:star::star: 3.Super4.方法重写:star::star:5.Object类:star::star: Day02 面向对象特性1.代码块:star:(主要是初始化变量,先于构造器)2.单例设计模式:…

【移动端网页布局】移动端网页布局基础概念 ⑦ ( 在 PhotoShop 中使用 Cutterman 切二倍图 | 使用二倍图作为背景图像 )

文章目录 一、在 PhotoShop 中使用 Cutterman 切二倍图二、使用二倍图作为背景图像 一、在 PhotoShop 中使用 Cutterman 切二倍图 参考 【CSS】PhotoShop 切图 ③ ( PhotoShop 切图插件 - Cutterman | 下载、安装、启动、注册、登录 Cutterman - 切图神奇 插件 | 使用插件进行切…

3自由度并联绘图机器人实现写字功能(一)

1. 功能说明 本文示例将实现R305样机3自由度并联绘图机器人写字的功能。 2. 电子硬件 在这个示例中,采用了以下硬件,请大家参考: 主控板 Basra主控板(兼容Arduino Uno) 扩展板Bigfish2.1扩展板电池7.4V锂电池 3. 功能…

远程访问及控制ssh

SSH远程管理 OpenSSH服务器 SSH(Secure Shell) 协议 是一种安全通道协议。主要用来实现字符界面的远程登录、远程复制等功能。对通信数据进行了加密处理,用于远程管理其中包括用户登录时输入的用户口令。因此SSH协议具有很好的安全性------------(同样…

d2l Transformer

终于到变形金刚了,他的主要特征在于多头自注意力的使用,以及摒弃了rnn的操作。 目录 1.原理 2.多头注意力 3.逐位前馈网络FFN 4.层归一化 5.残差连接 6.Encoder 7.Decoder 8.训练 9.预测 1.原理 主要贡献:1.纯使用attention的Enco…

Android程序员向音视频进阶,有前景吗

随着移动互联网的普及和发展,Android开发成为了很多人的就业选择,希望在这个行业能获得自己的一席之地。然而,随着时间的推移,越来越多的人进入到了Android开发行业,就导致目前Android开发的工作越来越难找&#xff0c…

EFI Driver Model(下)-USB 驱动设计

1、USB简介 通用串行总线(英语:Universal Serial Bus,缩写:USB)是一种串口总线标准,也是一种输入输出接口的技术规范,被广泛地应用于个人电脑和移动设备等信息通讯产品,并扩展至摄影…

我看谁没看过

vue在新窗口打开页面方法 const { href } this.$router.resolve({path: "/officePlatform/addPrompt"});window.open(href, "_blank"); 添加圆形标志 h3::before {content: "";display: inline-block;width: 13px;height: 13px;background: va…

NFT介绍及监管规则

什么是NFT NFT是Non-Fungible Token(非同质化代币)的缩写。 NFT是“Non-Fungible Token”的缩写,即非同质化代币。不同于FT(Fungible Token,同质化代币),每一个NFT都是独一无二且不可相互替代的…

第二章 Maven 核心程序解压和配置

第一节 Maven核心程序解压与配置 1、Maven 官网地址 首页: Maven – Welcome to Apache Maven(opens new window) 下载页面: Maven – Download Apache Maven(opens new window) 下载链接: 具体下载地址:https://dlcdn.apac…

【云原生】Java 应用程序在 Kubernetes 上棘手的内存管理

文章目录 引言JVM 内存模型简介非 Heap 内存Heap 堆内存Kubernetes 内存管理JVM 和 Kubernetes场景 1 — Java Out Of Memory 错误场景 2 — Pod 超出内存 limit 限制场景 3 — Pod 超出节点的可用内存场景 4 — 参数配置良好,应用程序运行良好 结语 引言 如何结合…

三月、四月总计面试碰壁15次,作为一个27岁的测试工程师.....

3年测试经验原来什么都不是,只是给你的简历上画了一笔,一直觉得经验多,无论在哪都能找到满意的工作,但是现实却是给我打了一个大巴掌!事后也不会给糖的那种... 先说一下自己的个人情况,普通二本计算机专业…

JVM调优最佳参数

项目背景 C端的项目,用户量比较多,请求比较多。 启动参数表 Xmx指定应用程序可用的最大堆大小。 Xms指定应用程序可用的最小堆大小。 (一般情况下,需要设置Xmx和Xms为相等的值,且为一个固定的值) 如果该值…

图像处理:均值滤波算法

目录 前言 概念介绍 基本原理 Opencv实现中值滤波 Python手写实现均值滤波 参考文章 前言 在此之前,我曾在此篇中推导过图像处理:推导五种滤波算法(均值、中值、高斯、双边、引导)。这在此基础上,我想更深入地研…

使用状态机实现幂等性

文章目录 背景幂等概念适用场景示例代码上述代码状态流转 背景 在某些场景下,可以使用状态机来实现幂等性。将业务流程抽象为一个状态机,定义各个状态之间的转换规则。当收到一个请求时,根据当前状态和请求类型来判断是否允许执行操作&#x…

数学知识四

容斥原理 S表示面积,下面公式可求出不相交的面积 2个圆的公式是这样 4个圆的面积是 总面积-所有俩俩相交的面积所有三三相交的面积-四四相交的面积,公式里加和减互相出现。 从n个集合里面挑一个一直到从n个集合里面挑n个 1-10中,能被2&#x…

【 SpringBoot单元测试 和 Mybatis 增,删,改 操作 】

文章目录 一、Spring-Boot单元测试(了解)1.1 概念1.2 单元测试引用1.3 单元测试的实现1.4 简单的断言说明1.5 单元测试优点 二、Mybatis 增,删,改 操作2.1 增加⽤户操作2.2 修改⽤户操作2.3 删除⽤户操作 一、Spring-Boot单元测试(了解) 1.1 概念 单元测…