基于tensorflow2.x的多GPU并行训练

news/2024/3/29 12:59:51/文章来源:https://blog.csdn.net/weixin_45885232/article/details/130266596
由于最近训练transformer,在单卡上显存不够,另外一块卡上也无法加载,故尝试使用双卡并行的策略。将基本的流程、遇见的难题汇总在这里。

双卡满载

分布策略解释

使用官方给出的tf.distribute.MirroredStrategy作为分布策略。这个策略通过如下的方式运行:
1)所有变量和模型计算图都会在副本之间复制。
2)输入都均匀分布在副本中。
3)每个副本在收到输入后计算输入的损失和梯度。
4)通过求和,每一个副本上的梯度都能同步。
5)同步后,每个副本上的复制的变量都可以同样更新。

正文

初始化分布策略

可以使用如下的命令,查看当前设备有几块GPU可以供使用。

strategy = tf.distribute.MirroredStrategy()
print(strategy.num_replicas_in_sync)

一、数据加载

使用分布式训练,会将总的batch分散到多块GPU上。我这里有两块GPU,使用的batch是32,那么在每个上面就是16。这里,在数据加载的时候就需要做处理,具体处理过程如下:

1)创建一个总的batchsize
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
2) 加载数据集
train_ds = DataLoader().make_batch(PARA['train_'], GLOBAL_BATCH_SIZE, PARA['max_len_sequence'])
valid_ds = DataLoader().make_batch(PARA['vaild_'], GLOBAL_BATCH_SIZE, PARA['max_len_sequence'])
test_ds = DataLoader().make_batch(PARA['testt_'],  GLOBAL_BATCH_SIZE, PARA['max_len_sequence'])

3)对数据做分发

train_ds = strategy.experimental_distribute_dataset(train_ds)
valid_ds = strategy.experimental_distribute_dataset(valid_ds)
test_ds = strategy.experimental_distribute_dataset(test_ds)

经过上面这些操作,数据已经处理好了,接下来处理训练策略。

二、 定义损失函数

:这里有几个地方需要特别注意,tf.losses/tf.keras.losses 中的损失函数通常会返回输入最后一个维度的平均值。损失类封装这些函数。在创建损失类的实例时传递 reduction=Reduction.NONE,表示“无额外缩减”。对于样本输入形状为 [batch, W, H, n_classes] 的类别损失,会缩减 n_classes 维度。对于类似 losses.mean_squared_errorlosses.binary_crossentropy 的逐点损失,应包含一个虚拟轴,使 [batch, W, H, 1] 缩减为 [batch, W, H]。如果没有虚拟轴,则 [batch, W, H] 将被错误地缩减为 [batch, W]
增加虚拟轴的方式也很简单,labels = labels[:, tf.newaxis]如果没有这个,回归模型是跑不起来的!!!

1)使用 tf.distribute.Strategy 时应如何计算损失?

例如,假设有 2 个 GPU,批次大小为 64。一个批次的输入会分布在各个副本(2 个 GPU)上,每个副本获得一个大小为 32 的输入。

每个副本上的模型都会使用其各自的输入进行前向传递,并计算损失。现在,不将损失除以其相应输入中的样本数 (BATCH_SIZE_PER_REPLICA = 32),而应将损失除以 GLOBAL_BATCH_SIZE (64)

之所以需要这样做,是因为在每个副本上计算完梯度后,会通过对梯度求和在副本之间同步梯度。

2)计算方法

如果使用自定义训练循环,则应将每个样本的损失相加,然后将总和除以 GLOBAL_BATCH_SIZE: scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE),或者使用 tf.nn.compute_average_loss,它会将每个样本的损失、可选样本权重和 GLOBAL_BATCH_SIZE 作为参数,并返回经过缩放的损失。比较而言,选择tf.nn.compute_average_loss这个会好一些。

由于我这里使用的是 tf.keras.losses 类,则需要将损失归约显式指定NONE 或 SUM。与 tf.distribute.Strategy 一起使用时,不允许使用 AUTO 和 SUM_OVER_BATCH_SIZE不允许使用 AUTO,因为用户应明确考虑他们想要的归约量,以确保在分布式情况下归约量正确。不允许使用 SUM_OVER_BATCH_SIZE,因为当前它只能按副本批次大小进行划分,而将按副本数量划分留给用户,这可能很容易遗漏。因此,您需要自己显式执行归约操作。

我做的是回归任务,具体的代码如下,可以看到,loss损失里面使用了reduction=tf.keras.losses.Reduction.NONE,返回损失值的时候使用了tf.nn.compute_average_loss

GLOBAL_BATCH_SIZE = PARA['batch_size']*strategy.num_replicas_in_sync
with strategy.scope():# Set reduction to `NONE` so you can do the reduction afterwards and divide by# global batch size.loss_object = tf.keras.losses.Huber(reduction=tf.keras.losses.Reduction.NONE)def compute_loss(labels, predictions):# 这里有个坑,见最开始的注# 使用Reduction.NONE之后,回归损失会减少一个维度,故要在后面添加一列# https://tensorflow.google.cn/tutorials/distribute/custom_training?hl=zh-cnlabels = labels[:,tf.newaxis]predictions = predictions[:, tf.newaxis]per_example_loss = loss_object(labels, predictions)return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

三、定义评价指标

评价指标根据自己的实际情况来,我这里使用了loss跟rmse

with strategy.scope():test_loss = tf.keras.metrics.Mean(name='test_loss')train_rmse = tf.keras.metrics.RootMeanSquaredError(name='train_rmse')test_rmse = tf.keras.metrics.RootMeanSquaredError(name='test_rmse')

四、初始化模型

模型、优化器和checkpoint务必要放在strategy.scope

with strategy.scope():model = Transformer(PARA['num_layers'], PARA['input_vocab_size'], PARA['target_vocab_size'],PARA['target_class'],PARA['max_len_sequence'],PARA['d_model'],PARA['num_heads'],PARA['dff'],rate=PARA['dropout_rate'])# 加载优化器:learning_rate = CustomizedSchedule(PARA['d_model'])optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)# 记录模型check = tf.train.Checkpoint(model=model, optimizer=optimizer)check_manager = tf.train.CheckpointManager(check, PARA['model_save'], max_to_keep=5)if check_manager.latest_checkpoint:check.restore(check_manager.latest_checkpoint)

五、构建训练策略

1) 先构建并行的策略,再构建train_step

with strategy.scope():# `run` replicates the provided computation and runs it# with the distributed input.@tf.functiondef distributed_train_step(dataset_inputs):per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)@tf.functiondef distributed_test_step(dataset_inputs):return strategy.run(test_step, args=(dataset_inputs,))

2) 构建train_step

def train_step(inputs):train_rmse.reset_states()sequence, tm, label = inputswith tf.GradientTape() as tape:predictions = model(sequence, training=True)loss = compute_loss(tm, predictions)gradients = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(gradients, model.trainable_variables))train_rmse.update_state(tm, predictions)return lossdef test_step(inputs):sequence, tm, label = inputspredictions = model(sequence, training=False)t_loss = loss_object(tm, predictions)test_loss.update_state(t_loss)test_rmse.update_state(tm, predictions)

六、自定义训练过程

def fit(train_ds, valid_ds, test_ds):steps = 0start = time.time()for epoch in range(PARA['EPOCH']):# TRAIN LOOPtotal_loss, num_batches, batch = 0.0, 0, 0for (batch, x) in enumerate(train_ds):# 这里返回每一个批次的损失值per_loss= distributed_train_step(x)total_loss += per_losssteps += 1# 这是自定义的记录函数,可以直接print当前值save_smurry('train','-', epoch, batch, steps, [per_loss, train_rmse.result()])if batch % (PARA['REPORT_STEP']*2) == 0 and batch:# 每次处理完之后,需要对test_loss及test_rmse做重置for (batch, x) in enumerate(valid_ds):distributed_test_step(x)# 这里需要得到的是在整个验证集上的结果save_smurry('vaild','-', epoch, batch, steps, [test_loss.result(), test_rmse.result()])test_loss.reset_states()test_rmse.reset_states()# 每50次做一次benchmark验证if batch % (PARA['REPORT_STEP']*5) == 0 and batch:for x in test_ds:distributed_test_step(x)save_smurry('test','-', epoch, batch, steps, [test_loss.result(), test_rmse.result()])test_loss.reset_states()test_rmse.reset_states()time_used = 'Time take for 1 epoch:{} secs\n'.format(time.time()-start)fout(time_used)

至此,分布程序构建完成。欢迎一起讨论

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_102513.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【架构设计】什么是CAP理论?

1、理论 CAP理论是指计算机分布式系统的三个核心特性:一致性(Consistency)、可用性(Availability)和分区容错性(Partition Tolerance)。 在CAP理论中,一致性指的是多个节点上的数据…

宝安日报:联诚发跨界创新“追光”十九载!

世界一流声光电智造一体化服务商、国家级高新技术企业、国家级专精特新“小巨人”企业、博士后创新实践基地、深圳自主创新百强企业……这些熠熠生辉的关键词,是位于宝安区航城街道的深圳市联诚发科技股份有限公司(以下简称:联诚发&#xff0…

KingSCADA3.8保姆级安装教程

大家好,我是雷工! 最近开始学习KingSCADA,今天这篇详细记录安装KingSCADA3.8的过程。 首先下载需要的安装版本,此处以从官网下载的最新版本KingSCADA3.8为例,双击:Setup.exe ; 一、安装主程序 1、点击“…

AutoSAR内存映射

总目录链接>> AutoSAR入门和实战系列总目录 总目录链接>> AutoSAR BSW高阶配置系列总目录 文章目录 为了防止不必要的内存缺口(RAM 中未使用的空间),不同大小(8、16 和 32 位)的变量根据其大小映射到特…

工业树莓派远程I/O控制套装—更高效、更灵活、更便捷

一、背景 在完整的生产过程中,许多传感器设备和执行设备不完全安装在同一位置,大多分散部署在各个生产环节中。如果采用本地控制的方式,就需要用到多个控制器,但是成本较高,且不利于管理,所以最理想的解决…

Vue表单基本操作-收集表单数据

收集表单数据 使用vue中的v-model收集表单里面的数据,不同的表单元素配合v-model会有不同的写法和技巧 本次的表单元素包括:文本框,单选,多选,下拉框,文本域 编写表单元素 首先编写表单元素,…

ROS学习第三十七节——机器人运动控制以及里程计信息显示

https://download.csdn.net/download/qq_45685327/87719766 https://download.csdn.net/download/qq_45685327/87719873 gazebo 中已经可以正常显示机器人模型了,那么如何像在 rviz 中一样控制机器人运动呢?在此,需要涉及到ros中的组件: ros…

camunda的service task如何使用

在 Camunda 中,使用 Service Task 节点可以执行各种类型的业务逻辑,例如计算、数据转换、数据格式化等。在 Service Task 节点中,可以使用不同的编程语言来实现业务逻辑,例如 Java、JavaScript、Python 等。 下面是使用 Java 实现…

状态压缩DP-蒙德里安的梦想

题意 求把 NM 的棋盘分割成若干个 12 的长方形,有多少种方案。 例如当 N2,M4 时,共有 5 种方案。当 N2,M3 时,共有 3 种方案。 如下图所示: 输入格式 输入包含多组测试用例。 每组测试用例占一行&#xff0…

这份最新阿里、腾讯、华为、字节等大厂的薪资和职级对比,你看过没?

互联网大厂新入职员工各职级薪资对应表(技术线)~ 最新阿里、腾讯、华为、字节跳动等大厂的薪资和职级对比 上面的表格不排除有很极端的收入情况,但至少能囊括一部分同职级的收入。这个表是“技术线”新入职员工的职级和薪资情况,非技术线(如产品、运营、…

【Linux】环境变量与进程优先级知识点

目录 环境变量1.基本概念2.常见环境变量3.我们写的程序和命令行指令有什么区别?4.自己的程序为什么要用 ./ 执行,而命令行指令可以直接执行?5.如何追加环境变量?6.Linux如何查看环境变量7.如何在代码层面获取环境变量main函数的参…

ubuntu 3060显卡驱动+cuda+cudnn+pytorch+pycharm+vscode

文章目录 运行环境:适用:思路:1.1 3060显卡驱动自动安装2.1 CUDA11.1.11)下载CUDA Toolkit 11.1 Update 1 Downloads2)contunue , 然后accept3)回车取消Driver安装,然后install4)添加环境变量5)确认是否安装成功 3.1 cudnn 8.1.11…

【Cartopy基础入门】如何更好的确定边界显示

原文作者:我辈理想 版权声明:文章原创,转载时请务必加上原文超链接、作者信息和本声明。 Cartopy基础入门 【Cartopy基础入门】Cartopy的安装 【Cartopy基础入门】Geojson数据的加载 【Cartopy基础入门】如何更好的确定边界显示 文章目录 Ca…

【边缘计算】登临(Goldwasser-UL64)BW-BR2边缘设备配置指南

目录 开箱配置激活SDK环境测试cuda兼容性 开箱配置 更改盒子root用户密码: sudo passwd root(密码同为root) 切换到root用户身份: su root查看ssh的状态,没有返回说明没有启动 sudo ps -e|grep ssh此时说明ssh服务已启动。 更改ssh配置文…

java定位系统源码,通过独特的射频处理,配合先进的位置算法,可以有效计算出复杂环境下的人员与物品的活动信息

智慧工厂人员定位系统源码,区域电子围栏管控源码 文末获取联系! 在工厂日常生产活动中,企业很难精准地掌握访客和承包商等各类人员的实际位置,且无法实时监控巡检人员的巡检路线,当厂区发生灾情或其他异常状况时&#…

postman安装

目录 下载、安装 Postman是一款功能强大的网页调试与发送网页HTTP请求的Chrome插件。 Postman原是Chrome浏览器的插件,可以模拟浏览器向后端服务器发起任何形式(如:get、post)的HTTP请求 使用Postman还可以在发起请求时,携带一些请求参数、请求头等信息…

WebSocket+Vue+SpringBoot实现语音通话

参考文章 整体思路 前端点击开始对话按钮后,将监听麦克风,获取到当前的音频,将其装化为二进制数据,通过websocket发送到webscoket服务端,服务端在接收后,将消息写入给指定客户端,客户端拿到发送…

日本PSE认证日本的電気用品安全法METI备案

日本的電気用品安全法(PSE认证)法规要求日本的采购商在购进商品后一个月内必须向日本METI注册申报,并必须将采购商名称或ID标在产品上,以便在今后产品销售过程中进行监督管理,完成后将获得電気用品製造事業届出書&…

Java基础学习(10)

Java基础学习 一、JDK8时间类1.1 Zoneld时区1.2 Instant时间戳1.3 ZonedDateTime1.4 DateTimeFormatter1.5 日历类时间表示1.6 工具类1.7 包装类JDK5提出的新特性Integer成员方法 二、集合进阶2.1 集合的体系结构2.1.1 Collection 2.2collection的遍历方式2.2.1 迭代器遍历2.2.…

元宇宙场景下的实时互动RTI技术能力构建

元宇宙可谓是处在风口浪尖,无数的厂商都对元宇宙未来抱有非常美好的憧憬。正因如此,许许多多厂商都在用他们自己的方案,为元宇宙更快、更好的实现,在自己的领域贡献力量。LiveVideoStack 2022北京站邀请到了 ZEGO 即构科技的解决方…