【mT5多语言翻译】之五——训练:中央日志、训练可视化、PEFT微调

news/2024/4/30 6:48:41/文章来源:https://blog.csdn.net/qq_43592352/article/details/137616549

·请参考本系列目录:【mT5多语言翻译】之一——实战项目总览

[1] 模型训练与验证

  在上一篇实战博客中,我们讲解了访问数据集中每个batch数据的方法。接下来我们介绍如何训练mT5模型进行多语言翻译微调。

  首先加载模型,并把模型设置为训练状态,然后定义优化器、学习率衰减,并设置一些初始状态值。

model = AutoModelForSeq2SeqLM.from_pretrained(conf.pretrained_path)
model.train()optimizer = AdamW(model.parameters(), lr=config.learning_rate)
# 学习率指数衰减,每次epoch:学习率 = gamma * 学习率
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=len(train_iter) * config.num_epochs)
total_batch = 0  # 记录进行到多少batch
losses = []

【注】设置了model.train()后,会激活dropout。而设置model.eval()后,会关闭dropout,防止影响模型的推理结果。

  接着,我们就可以访问数据加载器中的数据进行模型的微调了:

    for epoch in range(config.num_epochs):for i, (input_batch, label_batch) in enumerate(train_iter):optimizer.zero_grad()model_out = model.forward(input_ids=input_batch, labels=label_batch)loss = model_out.losslosses.append(loss.item())loss.backward()optimizer.step()scheduler.step()# 打印步if (total_batch + 1) % config.print_step == 0:avg_loss = np.mean(losses[-config.print_step:])logger.info('Epoch: {} | Step: {} | Train Avg. loss: {:.3f} | lr: {} | Time: {}'.format(epoch + 1,total_batch + 1, avg_loss, scheduler.get_last_lr()[0], get_time_dif(start_time)))# 验证步if (total_batch + 1) % config.checkpoint_step == 0:test_loss = evaluate(model, dev_iter)torch.save(model.state_dict(), config.save_path)time_dif = get_time_dif(start_time)logger.info('Test Avg. loss: {:.3f} | Time: {} '.format(test_loss, time_dif))model.train()total_batch += 1torch.save(model.state_dict(), config.save_path)

  如上,训练的代码还是很少的,里面有一些注意事项需要详细说明。

  1、由于我们不是每一个batch都要输出一下损失到控制台。所有我们需要设置打印步,来控制每隔多少个batch输出一次模型的loss。因此需要一个数组将每次的训练损失记录下来,然后打印时求下平均值。

  2、同理,在验证集上测试性能也是这个道理。但是由于我们的数据集很大,即使模型在验证集上的loss不再下降,也不应该主动把模型停止。因为大模型的训练可以抽象为“压缩数据”的概念,它没见过的数据就是不会产生相应的知识,所以最好还是让模型一直训练下去,直到把数据集训练完。

  模型的验证代码属于训练的阉割版,比较简单,如下:

def evaluate(model, data_iter):model.eval()eval_losses = []with torch.no_grad():for input_batch, label_batch in data_iter:model_out = model.forward(input_ids=input_batch, labels=label_batch)eval_losses.append(model_out.loss.item())return np.mean(eval_losses)

[2] 中央日志

  由于模型训练的时间较长,一旦断联可能就无法再继续观测模型在控制台打印的输出。因此,我们需要设计一个日志功能,让模型即可以实时的打印输出,又能同时记录输出到文件中,以便于我们后期查看。

  还有一个待解决的问题是,项目中不同文件的输出日志,我们需要将其定位到同一个log日志中。

  所以需要在项目中配置根日志器。配置代码如下:

import logging.config# 定义日志配置
LOGGING_CONFIG = {'version': 1,'formatters': {'default': {'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s',},},'handlers': {'console': {'class': 'logging.StreamHandler',  # 输出到控制台'level': 'INFO','formatter': 'default',},'file': {'class': 'logging.FileHandler',  # 输出到文件'filename': 'train.log','level': 'DEBUG','formatter': 'default',},},'loggers': {'': {  # root logger'handlers': ['console', 'file'],'level': 'DEBUG',},},
}# 在根日志器(没有名称的日志器)上应用日志配置,那么所有子日志器都会继承这些设置
logging.config.dictConfig(LOGGING_CONFIG)
# 获取日志器
logger = logging.getLogger(__name__)

  我们在项目的任意一个py文件中写下这样的配置。然后在其他py文件中只需要2行代码即可让输出流即展示在控制台也能保存在同一个log日志文件中:

# 获取日志器
logger = logging.getLogger(__name__)
# 记录日志
logger.info("xxxxxx")

[3] 训练可视化

  我们在训练过程中,会有实时观测损失波动、评价指标波动的需求。因此项目集成了tensorBoardX来进行训练可视化。

  先上效果图:
在这里插入图片描述

  首先安装tensorboard,输入命令:

pip install tensorboard 

  然后在代码中使用write即可,代码demo:

import numpy as np
from torch.utils.tensorboard import SummaryWriter  # 也可以使用 tensorboardX
# from tensorboardX import SummaryWriter  # 也可以使用 pytorch 集成的 tensorboardwriter = SummaryWriter('log') # 配置生成的数据保存的地址
for epoch in range(100):writer.add_scalar('test/squared', np.square(epoch), epoch)writer.close()

  执行上述代码后在本文件更目录下生成一个logs文件,且包含了一个事件文件。

  在pycharm中terminal终端输入:

tensorboard --logdir=logs

  一定要注意起初配置的生成文件保存地址,你在terminal终端中命令的地址要能够访问的到!!!

  输入命令后,会生成一个地址,访问即可。

在这里插入图片描述

  在本实战项目的实际训练中,通过is_write变量来控制是否要进行训练可视化,因此实际的代码如下:

	model.train()...if config.is_write:writer = SummaryWriter(log_dir="{0}/{1}".format(config.tb_log_path, time.strftime('%m-%d_%H.%M', time.localtime())))for epoch in range(config.num_epochs):for i, (input_batch, label_batch) in enumerate(train_iter):...# 验证步...if config.is_write:writer.add_scalar('loss/train', loss.item(), total_batch)...if config.is_write:writer.close()

  上述的代码判断is_write变量如果为True,那么会创建一个writer对象,并且每个batch都会记录训练loss。

[4] PEFT微调

  PEFT微调可能是大家最需要的方法。

  官方已经提供了peft库,直接安装即可:

pip install peft 

  不过在使用之前,我们需要明确2点:1)使用PEFT如何修改模型结构?2)使用PEFT如何保存模型?

  其实非常简单!

  使用PEFT如何修改模型结构?

  常规加载模型的代码为:

model = AutoModelForSeq2SeqLM.from_pretrained(conf.pretrained_path)

  使用PEFT方法修改加载模型后的代码为:

from peft import get_peft_model, LoraConfigmodel = AutoModelForSeq2SeqLM.from_pretrained(conf.pretrained_path)
peft_config = LoraConfig(task_type="SEQ_CLS", inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1)
model = get_peft_model(model, peft_config)

【注】朋友们,没看错!使用PEFT只需要在原来的基础上加两行代码即可,其他模型训练阶段的代码完全不需要改变。我这里演示的是使用LoRA进行微调,peft里目前还集成了Prefix Tuning等其他方法。

  使用PEFT如何保存模型?

  我们知道,使用PEFT微调模型是不会修改大模型本身的参数的,因此我们保存只要保存PEFT方法添加的那部分参数即可,PEFT模型保存方法如下:

model.save_pretrained(config.peft_save)

  而传统保存全量参数的代码如下:

torch.save(model.state_dict(), config.save_path)

【注】peft使用起来真的太方便了。

[5] 集成了所有以上功能的模型训练与验证代码

# 获取日志器
logger = logging.getLogger(__name__)def train(config, model, train_iter, dev_iter):start_time = time.time()model.train()optimizer = AdamW(model.parameters(), lr=config.learning_rate)# 学习率指数衰减,每次epoch:学习率 = gamma * 学习率scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0,num_training_steps=len(train_iter) * config.num_epochs)total_batch = 0  # 记录进行到多少batchlosses = []if config.is_write:writer = SummaryWriter(log_dir="{0}/{1}".format(config.tb_log_path, time.strftime('%m-%d_%H.%M', time.localtime())))for epoch in range(config.num_epochs):for i, (input_batch, label_batch) in enumerate(train_iter):optimizer.zero_grad()model_out = model.forward(input_ids=input_batch, labels=label_batch)loss = model_out.losslosses.append(loss.item())loss.backward()optimizer.step()scheduler.step()# 打印步if (total_batch + 1) % config.print_step == 0:avg_loss = np.mean(losses[-config.print_step:])logger.info('Epoch: {} | Step: {} | Train Avg. loss: {:.3f} | lr: {} | Time: {}'.format(epoch + 1,total_batch + 1, avg_loss, scheduler.get_last_lr()[0], get_time_dif(start_time)))# 验证步if (total_batch + 1) % config.checkpoint_step == 0:test_loss = evaluate(model, dev_iter)if config.is_peft:model.save_pretrained(config.peft_save)else:torch.save(model.state_dict(), config.save_path)time_dif = get_time_dif(start_time)logger.info('Test Avg. loss: {:.3f} | Time: {} '.format(test_loss, time_dif))if config.is_write:writer.add_scalar('loss/train', loss.item(), total_batch)model.train()total_batch += 1if config.is_peft:model.save_pretrained(config.peft_save)else:torch.save(model.state_dict(), config.save_path)if config.is_write:writer.close()def evaluate(model, data_iter):model.eval()eval_losses = []with torch.no_grad():for input_batch, label_batch in data_iter:model_out = model.forward(input_ids=input_batch, labels=label_batch)eval_losses.append(model_out.loss.item())return np.mean(eval_losses)

[6] 进行下一篇实战

  【mT5多语言翻译】之六——推理:多语言翻译与第三方接口设计

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_1045919.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

网络安全指南:安全访问 Facebook 的技巧

在当今数字化时代,网络安全问题越来越受到人们的关注。尤其是在社交媒体平台上,如 Facebook 这样的巨头,用户的个人信息和隐私更容易受到威胁。为了保护自己的在线安全,我们需要采取一些措施来确保在使用 Facebook 时能够安全可靠…

C语言进阶|顺序表

✈顺序表的概念及结构 线性表(linear list)是n个具有相同特性的数据元素的有限序列。 线性表是一种在实际中广泛使 用的数据结构,常见的线性表:顺序表、链表、栈、队列、字符串.. 线性表在逻辑上是线性结构,也就说是连…

大话设计模式——23.备忘录模式(Memento Pattern)

简介 又称快照模式,在不破坏封装性的前提下,捕获一个对象的内部状态,并且该对象之外保存这个状态。这样以后就可将该对象恢复到原先保存的状态 UML图 应用场景 允许用户取消不确定或者错误的操作,能够恢复到原先的状态游戏存档、…

深度学习架构(CNN、RNN、GAN、Transformers、编码器-解码器架构)的友好介绍。

一、说明 本博客旨在对涉及卷积神经网络 (CNN)、递归神经网络 (RNN)、生成对抗网络 (GAN)、转换器和编码器-解码器架构的深度学习架构进行友好介绍。让我们开始吧!! 二、卷积神经网络…

【动手学深度学习】15_汉诺塔问题

注: 本系列仅为个人学习笔记,学习内容为《算法小讲堂》(视频传送门),通俗易懂适合编程入门小白,需要具备python语言基础,本人小白,如内容有误感谢您的批评指正 汉诺塔(To…

c/c++ |游戏后端开发之skynet

作者眼中的skynet 有一点要说明的是,云风至始也没有公开说skynet专门为游戏开发,换句话,skynet 引擎也可以用于web 开发 贴贴我的笔记 skynet 核心解决什么问题 愿景:游戏服务器能够充分利用多核优势,将不同的业务放在…

【随笔】Git 高级篇 -- 本地栈式提交 rebase | cherry-pick(十七)

💌 所属专栏:【Git】 😀 作  者:我是夜阑的狗🐶 🚀 个人简介:一个正在努力学技术的CV工程师,专注基础和实战分享 ,欢迎咨询! 💖 欢迎大…

QT Creator概览

🐌博主主页:🐌​倔强的大蜗牛🐌​ 📚专栏分类:QT❤️感谢大家点赞👍收藏⭐评论✍️ 目录 一、Qt Creator 概览 ①:菜单栏 ②:模式选择 ③:构建套件选择器…

在图片上画出mask和pred

画出论文中《Variance-aware attention U-Net for multi-organ segmentation》的图1,也就是在原图上画出mask和pred的位置。 新建一个文件夹 然后运行代码: import cv2 import os from os.path import splitext####第一次:把GT&#xff08…

天书奇谈_源码_搭建架设_3D最新天启版_自带假人

本教程仅限学习使用,禁止商用,一切后果与本人无关,此声明具有法律效应!!!! 一. 效果演示 天书奇谈_源码_搭建架设 环境: centos7.6 , 放开所有端口 源码获取 https://…

Unity Pro 2019 for Mac:专业级游戏引擎,助力创意无限延伸!

Unity Pro 2019是一款功能强大的游戏开发引擎,其特点主要体现在以下几个方面: 强大的渲染技术:Unity Pro 2019采用了新的渲染技术,包括脚本化渲染流水线,能够轻松自定义渲染管线,通过C#代码和材料材质&…

Rust面试宝典第1题:爬楼梯

题目 小乐爬楼梯,一次只能上1级或者2级台阶。楼梯一共有n级台阶,请问总共有多少种方法可以爬上楼? 解析 这道题虽然是一道编程题,但实际上更是一道数学题,着重考察应聘者的逻辑思维能力和分析解决问题的能力。 当楼梯只…

华为2024年校招实习硬件-结构工程师机试题(四套)

华为2024年校招&实习硬件-结构工程师机试题(四套) (共四套)获取(WX: didadidadidida313,加我备注:CSDN 华为硬件结构题目,谢绝白嫖哈) 结构设计工程师,结…

HTTP与HTTPS:深度解析两种网络协议的工作原理、安全机制、性能影响与现代Web应用中的重要角色

HTTP (HyperText Transfer Protocol) 和 HTTPS (Hypertext Transfer Protocol Secure) 是互联网通信中不可或缺的两种协议,它们共同支撑了全球范围内的Web内容传输与交互。本文将深度解析HTTP与HTTPS的工作原理、安全机制、性能影响,并探讨它们在现代Web…

泰迪智能科技高职人工智能专业人才培养方案

人工智能行业近年来得到了快速发展,全球科技公司都在竞相投入人工智能的研发,从硅谷到北京,都在人工智能上取得了显著的进步。人工智能已经从学术研究转变为影响制造业、医疗保健、交通运输和零售等多个行业的关键因素。我国政策的积极推动下…

CentOS 7与MySQL 5.7.25主从复制实践

本文主要记录mysql主从复制的详细步骤,如果你还没来得及安装MySQL请参考CentOS 7实战:轻松实现MySQL 5.7.25的tar包离线安装 ProcessOn源文件地址 主从复制应用场景: 从服务器作为主服务器的实时备份主从服务器实现读写分离(主…

南京航空航天大学-考研科目-513测试技术综合 高分整理内容资料-01-单片机原理及应用分层教程-单片机有关常识部分

系列文章目录 高分整理内容资料-01-单片机原理及应用分层教程-单片机有关常识部分 文章目录 系列文章目录前言总结 前言 单片机的基础内容繁杂,有很多同学基础不是很好,对一些细节也没有很好的把握。非常推荐大家去学习一下b站上的哈工大 单片机原理及…

Java快速入门系列-9(Spring框架与Spring Boot —— 深度探索及实践指南)

第九章:Spring框架与Spring Boot —— 深度探索及实践指南 9.1 Spring框架概述9.2 Spring IoC容器9.3 Spring AOP9.4 Spring MVC9.5 Spring Data JPA/Hibernate9.6 Spring Boot快速入门与核心特性9.7 Spring Boot的自动配置与启动流程详解9.8 创建RESTful服务与数据库交互实践…

RTSP/Onvif视频安防监控平台EasyNVR调用接口返回匿名用户名和密码的原因排查

视频安防监控平台EasyNVR可支持设备通过RTSP/Onvif协议接入,并能对接入的视频流进行处理与多端分发,包括RTSP、RTMP、HTTP-FLV、WS-FLV、HLS、WebRTC等多种格式。平台拓展性强、支持二次开发与集成,可应用在景区、校园、水利、社区、工地等场…

03-JAVA设计模式-组合模式

组合模式 什么是组合模式 组合模式(Composite Pattern)允许你将对象组合成树形结构以表示“部分-整体”的层次结构,使得客户端以统一的方式处理单个对象和对象的组合。组合模式让你可以将对象组合成树形结构,并且能像单独对象一…