AI夏令营笔记——任务2

news/2024/5/10 7:31:20/文章来源:https://blog.csdn.net/m0_51940505/article/details/132439140

文章目录

  • 任务说明
  • 实现思路
  • 优化方向

任务说明

任务要求与任务1一样:

从论文标题、摘要作者等信息,判断该论文是否属于医学领域的文献。
可以将任务看作是一个文本二分类任务。机器需要根据对论文摘要等信息的理解,将论文划分为医学领域的文献和非医学领域的文献两个类别之一。

实现思路

使用预训练的大语言模型进行建模,在这里使用的是BERT。具体步骤如下:

  1. 数据预处理:首先,对文本数据进行预处理,包括文本清洗(如去除特殊字符、标点符号)、分词等操作。可以使用常见的NLP工具包(如NLTK或spaCy)来辅助进行预处理。
  2. 构建训练所需的dataset:构建Dataset类时,需要定义三个方法__init__,getitemlen,其中__init__方法完成类初始化,__getitem__要求返回返回内容和label,__len__方法返回数据长度
  3. 构造Dataloader:在其中完成对句子进行编码、填充、组装batch等动作:
  4. 定义预测模型利用预训练的BERT模型来解决文本二分类任务,我们将使用BERT模型编码中的[CLS]向量来完成二分类任务

[CLS]就是classification的意思,可以理解为用于下游的分类任务。

在这里插入图片描述
本任务的baseline如下:

#import 相关库
#导入前置依赖
import os
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
# 用于加载bert模型的分词器
from transformers import AutoTokenizer
# 用于加载bert模型
from transformers import BertModel
from pathlib import Pathbatch_size = 8
# 文本的最大长度
text_max_length = 128
# 总训练的epochs数,我只是随便定义了个数
epochs = 100
# 学习率
lr = 3e-5
# 取多少训练集的数据作为验证集
validation_ratio = 0.1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 每多少步,打印一次loss
log_per_step = 50# 数据集所在位置
dataset_dir = Path("")
os.makedirs(dataset_dir) if not os.path.exists(dataset_dir) else ''# 模型存储路径
model_dir = Path("./model/bert_checkpoints")
# 如果模型目录不存在,则创建一个
os.makedirs(model_dir) if not os.path.exists(model_dir) else ''print("Device:", device)# 读取数据集,进行数据处理pd_train_data = pd.read_csv('train.csv')
pd_train_data['title'] = pd_train_data['title'].fillna('')
pd_train_data['abstract'] = pd_train_data['abstract'].fillna('')test_data = pd.read_csv('testB.csv')
test_data['title'] = test_data['title'].fillna('')
test_data['abstract'] = test_data['abstract'].fillna('')
pd_train_data['text'] = pd_train_data['title'].fillna('') + ' ' +  pd_train_data['author'].fillna('') + ' ' + pd_train_data['abstract'].fillna('')+ ' ' + pd_train_data['Keywords'].fillna('')
test_data['text'] = test_data['title'].fillna('') + ' ' +  test_data['author'].fillna('') + ' ' + test_data['abstract'].fillna('')+ ' ' + pd_train_data['Keywords'].fillna('')
test_data['Keywords'] = test_data['title'].fillna('')# 从训练集中随机采样测试集
validation_data = pd_train_data.sample(frac=validation_ratio)
train_data = pd_train_data[~pd_train_data.index.isin(validation_data.index)]# 构建Dataset
class MyDataset(Dataset):def __init__(self, mode='train'):super(MyDataset, self).__init__()self.mode = mode# 拿到对应的数据if mode == 'train':self.dataset = train_dataelif mode == 'validation':self.dataset = validation_dataelif mode == 'test':# 如果是测试模式,则返回内容和uuid。拿uuid做target主要是方便后面写入结果。self.dataset = test_dataelse:raise Exception("Unknown mode {}".format(mode))def __getitem__(self, index):# 取第index条data = self.dataset.iloc[index]# 取其内容text = data['text']# 根据状态返回内容if self.mode == 'test':# 如果是test,将uuid做为targetlabel = data['uuid']else:label = data['label']# 返回内容和labelreturn text, labeldef __len__(self):return len(self.dataset)
train_dataset = MyDataset('train')
validation_dataset = MyDataset('validation')
train_dataset.__getitem__(0)#获取Bert预训练模型
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#接着构造我们的Dataloader。
#我们需要定义一下collate_fn,在其中完成对句子进行编码、填充、组装batch等动作:
def collate_fn(batch):"""将一个batch的文本句子转成tensor,并组成batch。:param batch: 一个batch的句子,例如: [('推文', target), ('推文', target), ...]:return: 处理后的结果,例如:src: {'input_ids': tensor([[ 101, ..., 102, 0, 0, ...], ...]), 'attention_mask': tensor([[1, ..., 1, 0, ...], ...])}target:[1, 1, 0, ...]"""text, label = zip(*batch)text, label = list(text), list(label)# src是要送给bert的,所以不需要特殊处理,直接用tokenizer的结果即可# padding='max_length' 不够长度的进行填充# truncation=True 长度过长的进行裁剪src = tokenizer(text, padding='max_length', max_length=text_max_length, return_tensors='pt', truncation=True)return src, torch.LongTensor(label)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
inputs, targets = next(iter(train_loader))
print("inputs:", inputs)
print("targets:", targets)#定义预测模型,该模型由bert模型加上最后的预测层组成
class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()# 加载bert模型self.bert = BertModel.from_pretrained('bert-base-uncased', mirror='tuna')# 最后的预测层self.predictor = nn.Sequential(nn.Linear(768, 256),nn.ReLU(),nn.Linear(256, 1),nn.Sigmoid())def forward(self, src):""":param src: 分词后的推文数据"""# 将src直接序列解包传入bert,因为bert和tokenizer是一套的,所以可以这么做。# 得到encoder的输出,用最前面[CLS]的输出作为最终线性层的输入outputs = self.bert(**src).last_hidden_state[:, 0, :]# 使用线性层来做最终的预测return self.predictor(outputs)
model = MyModel()
model = model.to(device)#定义出损失函数和优化器。这里使用Binary Cross Entropy:
criteria = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)# 由于inputs是字典类型的,定义一个辅助函数帮助to(device)
def to_device(dict_tensors):result_tensors = {}for key, value in dict_tensors.items():result_tensors[key] = value.to(device)return result_tensors#定义一个验证方法,获取到验证集的精准率和loss
def validate():model.eval()total_loss = 0.total_correct = 0for inputs, targets in validation_loader:inputs, targets = to_device(inputs), targets.to(device)outputs = model(inputs)loss = criteria(outputs.view(-1), targets.float())total_loss += float(loss)correct_num = (((outputs >= 0.5).float() * 1).flatten() == targets).sum()total_correct += correct_numreturn total_correct / len(validation_dataset), total_loss / len(validation_dataset)# 首先将模型调成训练模式
model.train()# 清空一下cuda缓存
if torch.cuda.is_available():torch.cuda.empty_cache()# 定义几个变量,帮助打印loss
total_loss = 0.
# 记录步数
step = 0# 记录在验证集上最好的准确率
best_accuracy = 0# 开始训练
for epoch in range(epochs):model.train()for i, (inputs, targets) in enumerate(train_loader):# 从batch中拿到训练数据inputs, targets = to_device(inputs), targets.to(device)# 传入模型进行前向传递outputs = model(inputs)# 计算损失loss = criteria(outputs.view(-1), targets.float())loss.backward()optimizer.step()optimizer.zero_grad()total_loss += float(loss)step += 1if step % log_per_step == 0:print("Epoch {}/{}, Step: {}/{}, total loss:{:.4f}".format(epoch+1, epochs, i, len(train_loader), total_loss))total_loss = 0del inputs, targets# 一个epoch后,使用过验证集进行验证accuracy, validation_loss = validate()print("Epoch {}, accuracy: {:.4f}, validation loss: {:.4f}".format(epoch+1, accuracy, validation_loss))torch.save(model, model_dir / f"model_{epoch}.pt")# 保存最好的模型if accuracy > best_accuracy:torch.save(model, model_dir / f"model_best.pt")best_accuracy = accuracy#加载最好的模型,然后进行测试集的预测
model = torch.load(model_dir / f"model_best.pt")
model = model.eval()test_dataset = MyDataset('test')
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)results = []
for inputs, ids in test_loader:outputs = model(inputs.to(device))outputs = (outputs >= 0.5).int().flatten().tolist()ids = ids.tolist()results = results + [(id, result) for result, id in zip(outputs, ids)]
test_label = [pair[1] for pair in results]
test_data['label'] = test_label
test_data['Keywords'] = test_data['title'].fillna('')
test_data[['uuid', 'Keywords', 'label']].to_csv('submit_task4.csv', index=None)

优化方向

  1. 换模型:不同的模型的效果是不同的,可以多尝试不同的模型,然后再选择一个最优的。
  2. 调参优化:如果模型效果不理想,可以尝试调整超参数以获得更好的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_160317.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android SDK 上手指南||第四章 应用程序结构

第四章 应用程序结构 本教程将主要以探索与了解为主要目的,但后续的系列文章则将进一步带大家深入学习如何创建用户界面、响应用户交互操作以及利用Java编排应用逻辑。我们将专注于大家刚刚开始接触Android开发时最常遇到的项目内容,但也会同时涉及一部…

二级评论列表功能

一:需求场景 我的个人网站留言列表在开发时,因为本着先有功能的原则。留言列表只有一级,平铺的。 当涉及多人回复,或者两个人多次对话后, 留言逻辑看着非常混乱。如下图 于是,我就打算将平铺的列表&#…

关于模板的大致认识【C++】

文章目录 函数模板函数模板的原理函数模板的实例化模板参数的匹配原则 类模板类模板的定义格式类模板的实例化 非类型模板参数typename 与class模板的特化函数模板特化类模板特化全特化偏特化 模板的分离编译 函数模板 函数模板的原理 template <typename T> //模板参数…

考研C语言进阶题库——更新41-50题

目录 41.编写程序要求输出整数a和b若a和b的平方和大于100&#xff0c;则输出a和b的平方和&#xff0c;否则输出a和b的和 42.现代数学的著名证明之一是Georg Cantor证明了有理数是可枚举的。他是用下面这一张表来证明这一命题的&#xff1a;第一项是1/1&#xff0c;第二项是是…

【Unity细节】Unity制作汽车时,为什么汽车会被弹飞?为什么汽车会一直抖动?

&#x1f468;‍&#x1f4bb;个人主页&#xff1a;元宇宙-秩沅 hallo 欢迎 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! 本文由 秩沅 原创 &#x1f636;‍&#x1f32b;️收录于专栏&#xff1a;unity细节和bug &#x1f636;‍&#x1f32b;️优质专栏 ⭐【…

0基础入门代码审计-2 Fortify初探

0x01 序言 目前又加入一位新童鞋了&#xff0c;最近将会再加入cs相关的专栏&#xff0c;都是以基础为主&#xff0c;毕竟太复杂的东西&#xff0c;能看懂的人太少。 0x02 准备工具 1、Fortify 2、需要审计的源码 0x03 Fortify的简单使用 1、 1、在开始菜单栏中找到Audit Wo…

Matlab绘制灰度直方图

直方图是根据灰图像绘制的&#xff0c;而不是彩色图像通。查看图像直方图时候&#xff0c;需要先确定图片是否为灰度图&#xff0c;使用MATLAB2019查看图片是否是灰度图片&#xff0c;在读取图片后在MATLAB界面的工作区会显示读取的图像矩阵&#xff0c;如果是&#xff0c;那么…

代码随想录算法训练营(回溯总结篇)

回溯也可以说是暴力搜索&#xff08;最多剪枝一下&#xff09;。回溯是递归的副产品&#xff0c;只要有递归就会有回溯。 一.分类 1.组合问题 &#xff08;1&#xff09;按组合元素的个数 &#xff08;2&#xff09;按组合元素的总和 有重复元素 同一元素可以重复选&#x…

软考高级系统架构设计师(一)计算机硬件

【原文链接】软考高级系统架构设计师&#xff08;一&#xff09;计算机硬件 1.1 计算机硬件组成 1.1.1 计算机的基本硬件组成 运算器控制器存储器输入设备输出设备 1.1.2 中央处理单元&#xff08;CPU&#xff09; 中央处理单元&#xff08;CPU&#xff09;的组成 运算器…

基于AVR128单片机世界电子时钟的设计

一、系统方案 上电初始化完成系统初始化&#xff0c;液晶滚动显示北京、莫斯科、东京、伦敦、巴黎、纽约等六个城市的标准时间&#xff0c;显示的内容包括地区名及相应地区的年、月、日、星期、时、分、秒。 使用K1按键控制滚动显示或稳定显示某个地区的时间。 使用K3、K4、K5按…

Fastadmin框架 聚合数字生活抵扣卡系统v2.8.6

【2.8.6更新公告】 1.【优化】优化已知问题。 2.【新增 】新增区县影院。

【ESP系列】ESP01S官方MQTT案例实验

前言 偶然发现安信可官网有ESP01S和STM32连接TCP和MQTT的案例。弄了一两天&#xff0c;把我使用的流程在这里记录下。MQTT的固件一定要烧录进去&#xff0c;默认固件是没有MQTT相关的AT指令的。 环境 Keli5&#xff0c;STM32F103C8T6 官方Keil工程链接&#xff1a;ESP8266的S…

leetcode 309. 买卖股票的最佳时机含冷冻期

2023.8.22 本题是买卖股票系列 冷冻期。 由于引入了冷冻期&#xff0c;并且这个冷冻期是在卖出股票才会出现&#xff0c;因此我dp数组设置了四种状态&#xff1a; 状态一&#xff1a;持有股票。状态二&#xff1a;不持有股票&#xff1a; 之前就卖了&#xff0c;所以今天不处…

使用rook搭建Ceph集群

宿主机&#xff1a; MacBook Pro&#xff08;Apple M2 Max&#xff09; VMware Fusion Player 版本 13.0.2 VM软硬件&#xff1a; ubuntu 22.04.2 4核 CPU&#xff0c;5G 内存&#xff0c;40G硬盘 *每台机器分配硬件资源很重要&#xff0c;可以适当超过宿主机的资源量&am…

机器学习在大数据分析中的应用

文章目录 机器学习在大数据分析中的原理机器学习在大数据分析中的应用示例预测销售趋势客户细分和个性化营销 机器学习在大数据分析中的前景和挑战前景挑战 总结 &#x1f389;欢迎来到AIGC人工智能专栏~探索机器学习在大数据分析中的应用 ☆* o(≧▽≦)o *☆嗨~我是IT陈寒&…

Vim学习(四)——命令使用技巧

命令模式 打开文本默认模式&#xff0c;按**【ESC】**重新进入 【/关键字】&#xff1a;搜索匹配关键字 G&#xff1a;最后一行 gg&#xff1a;第一行 hjkl:左下右上 yy: 复制一行 dd&#xff1a;删除一行 p:粘贴 u: 撤销插入模式 按**【i / a / o】**键均可进入文本编辑模式…

如何搭建关键字驱动自动化测试框架?

前言 那么这篇文章我们将了解关键字驱动测试又是如何驱动自动化测试完成整个测试过程的。关键字驱动框架是一种功能自动化测试框架&#xff0c;它也被称为表格驱动测试或者基于动作字的测试。关键字驱动的框架的基本工作是将测试用例分成四个不同的部分。首先是测试步骤&#…

在mac下,使用Docker安装达梦数据库

前言&#xff1a;因为业务需要安装达梦数据库 获取官网下载tar包&#xff08;达梦官网的下载页面https://www.dameng.com/list_103.html&#xff09;&#xff0c;或者通过命令 一、下载tar包 命令下载&#xff1a;wget -O dm8_docker.tar -c https://download.dameng.com/eco/…

【无标题】 欢迎使用Markdown编辑器

这里写自定义目录标题 欢迎使用Markdown编辑器新的改变功能快捷键合理的创建标题&#xff0c;有助于目录的生成如何改变文本的样式插入链接与图片如何插入一段漂亮的代码片生成一个适合你的列表创建一个表格设定内容居中、居左、居右SmartyPants 创建一个自定义列表如何创建一个…

回归预测 | MATLAB实现WOA-RF鲸鱼优化算法优化随机森林算法多输入单输出回归预测(多指标,多图)

回归预测 | MATLAB实现WOA-RF鲸鱼优化算法优化随机森林算法多输入单输出回归预测&#xff08;多指标&#xff0c;多图&#xff09; 目录 回归预测 | MATLAB实现WOA-RF鲸鱼优化算法优化随机森林算法多输入单输出回归预测&#xff08;多指标&#xff0c;多图&#xff09;效果一览…