【简单、高效、性能好】SetFit:无需Prompts的高效小样本学习

news/2024/5/7 13:46:49/文章来源:https://blog.csdn.net/u011239443/article/details/127852349

重磅推荐专栏: 《Transformers自然语言处理系列教程》
手把手带你深入实践Transformers,轻松构建属于自己的NLP智能应用!

1. 概要

使用预训练语言模型的小样本学习(处理只有少量标签或没有标签的数据)已成为比较普遍的解决方案。
SetFit:一种用于对 Sentence Transformers 进行少量微调的高效框架。SetFit 用很少的标记数据实现了高精度——例如,在客户评论 (CR) 情绪数据集上每个类只有 8 个标记样本,SetFit 在 3k 个样本的完整训练集上与微调 RoBERTa Large 相比,如图1-1所示,具有竞争力表现:
图1-1,与标准微调相比,SetFit 的样本效率和抗噪能力要高得多
与其他小样本学习方法相比,SetFit 有几个独特的特点:

  • 没有提示(prompts )或语言器(verbalisers):当前的小样本微调技术需要手工制作的提示(prompts )或语言器(verbalisers)将样本转换为适合底层语言模型的格式。SetFit 通过直接从少量带标签的文本示例生成丰富的embeddings 来完全免除prompts 。

  • 训练速度快:SetFit 不需要像 T0 或 GPT-3 这样的大型模型来实现高精度。因此,训练和运行推理的速度通常快一个数量级(或更多)。

  • 多语言支持:SetFit 可以与 Hub 上的任何 Sentence Transformer 一起使用,这意味着你可以通过简单地微调多语言checkpoint来对多种语言的文本进行分类。

  • 论文: https://arxiv.org/pdf/2209.11055.pdf

  • 代码:https://github.com/huggingface/setfit

2. 原理

SetFit原理比较简单,它设计考虑了效率和简单性。SetFit 首先在少量标记示例(通常每个类 8 或 16 个)上微调 Sentence Transformer 模型。接下来是在微调的 Sentence Transformer 生成的embeddings上训练分类器头。SetFit 利用 Sentence Transformers 的能力基于成对的句子生成密集embeddings 。如图2-1所示:
图2-1,SetFit的两阶段训练流程
图2-2,句子对生成伪代码

  • 在初始微调阶段,它通过对比训练利用有限的标记输入数据,其中正负对由类内和类外选择创建,如图2-2所示。然后,Sentence Transformer 模型对这些对(或三元组)进行训练,并为每个样本生成密集向量。
  • 在第二步中,分类头使用各自的类标签对编码embeddings进行训练。在推理时,未见过的样本通过微调的 Sentence Transformer,生成一个embedding ,当将其送到分类头时,输出一个类标签预测结果。

只需将基本的 Sentence Transformer 模型切换为多语言模型,SetFit 就可以在多语言环境中无缝运行。

3. 实验

3.1 效果表现

虽然基于比现有的少样本方法小得多的模型,但 SetFit 在各种基准测试中的表现与sota的少样本方法相当或更好。如图3-1所示,在RAFT(一个 few-shot 分类基准)上,具有 3.55 亿个参数的 SetFit Roberta 优于 PET 和 GPT-3。它仅仅在人类平均表现和 110 亿个参数 T-few(这个模型的大小是 SetFit Roberta 的 30 倍) 水平之下。SetFit 在 11 项 RAFT 任务中的 7 项上也优于人类基线
图3-1,RAFT 排行榜上的突出方法(截至 2022 年 9 月)
在其他数据集上,SetFit 在各种任务中表现出稳健性。如下图3-2所示,每个类只有 8 个示例,它基本上优于 PERFECT、ADAPET 和微调的 vanilla transformer。SetFit 也取得了与 T-Few 3B 相当的结果,尽管它无需提示且体积小 27 倍
图3-2,在 3 个分类数据集上将 Setfit 性能与其他方法进行比较

3.2 训练和推理速度

由于 SetFit 使用相对较小的模型实现了高精度,因此它的训练速度非常快,而且成本要低得多。例如,使用 8 个标记示例在 NVIDIA V100 上训练 SetFit 仅需 30 秒,成本为 0.025 美元。相比之下,训练 T-Few 3B 需要 NVIDIA A100,耗时 11 分钟,同一实验的成本约为 0.7 美元——高出 28 倍。事实上,SetFit 可以像 Google Colab 上的那样在单个 GPU 上运行,你甚至可以在几分钟内在 CPU 上训练 SetFit!如图3-3所示,SetFit带来了提速,模型性能却与T-Few 3B相当。预测和蒸馏 SetFit 模型也可以获得类似的收益,可以带来 123 倍的加速!
图3-3,比较 T-Few 3B 和 SetFit (MPNet) 的训练成本和平均性能,每个类有 8 个标记样本

4. 实践:零样本文本分类

SetFit还可以做零样本文本分类。我们需要做的第一件事是创建一个合成样本的虚拟数据集。我们可以通过将 add_templated_examples() 函数来完成此操作。此函数需要一些主要内容:

  • 用于分类的候选标签列表。 我们将在此处使用参考数据集中的标签。
  • 用于生成示例的模板。 默认情况下,它是“This sentence is {}”,其中{}将由候选标签名称之一填充
  • 样本量 N,这将为每个类创建 N 个合成示例。 作者发现 N=8 通常效果最好。
dataset_id = "emotion"
model_id = "sentence-transformers/paraphrase-mpnet-base-v2"
from datasets import load_dataset
reference_dataset = load_dataset(dataset_id)# 从“label”列中提取 ClassLabel 特征
label_features = reference_dataset["train"].features["label"]
# 用于分类的标签名称
candidate_labels = label_features.names
candidate_labels
['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
from datasets import Dataset
from setfit import add_templated_examples# 用合成样本填充的虚拟数据集
dummy_dataset = Dataset.from_dict({})
train_dataset = add_templated_examples(dummy_dataset, candidate_labels=candidate_labels, sample_size=8)
train_dataset

由于我们的数据集有 6 个类别,我们选择的样本大小为 8,因此我们的合成数据集包含 6×8=48 个样本。

Dataset({features: ['text', 'label'],num_rows: 48
})

我们看几个例子:

train_dataset.shuffle()[:3]
{'text': ['This sentence is love','This sentence is fear','This sentence is joy'],'label': [2, 4, 1]}

用这样虚拟数据集来微调模型,在预测看看效果:

from setfit import SetFitModelmodel = SetFitModel.from_pretrained(model_id)from setfit import SetFitTrainertrainer = SetFitTrainer(model=model,train_dataset=train_dataset,eval_dataset=reference_dataset["test"]
)trainer.train()
zeroshot_metrics = trainer.evaluate()
zeroshot_metrics
{'accuracy': 0.5345}

我们在尝试一下用 Hugging Face 的 zero-shot-classification:

from transformers import pipelinepipe = pipeline("zero-shot-classification", device=0)zeroshot_preds = pipe(reference_dataset["test"]["text"], batch_size=16, candidate_labels=candidate_labels)

zero-shot-classification pipeline 默认用的是 facebook/bart-large-mnli。注意,该方法 生成预测结果所需的时间比 SetFit 长将近 5 倍! 好的,那么它的性能如何?

preds = [label_features.str2int(pred["labels"][0]) for pred in zeroshot_preds]import evaluatemetric = evaluate.load("accuracy")
transformers_metrics = metric.compute(predictions=preds, references=reference_dataset["test"]["label"])
transformers_metrics

与 SetFit 相比,这种方法的性能要差得多:

{'accuracy': 0.3765}

看来 SetFit 真的是——即简单,又高效,还性能好 !666666666666666…

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_226999.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Zlibrary已死,找了一个替代品,找了一个替代品免费的电子书下载平台...

大家好,我是鸟哥。一个半路出家的程序员。 提到Zlibrary,想必大家都不陌生吧。全球最大的数字图书馆,截止被封前共收录了591万本书,7751万篇文章,并且还在不断的增加中,关键是可以免费下载。 反正我是很熟悉…

智能计量系统配套设备有哪些

智能计量系统配套设备 地磅区域安装配套设备包含:微波定位仪、视频监控、道闸、LED显示屏、车号识别、语音对讲、音响设备、红绿灯、刷卡机箱、雷达、补光灯。 硬件设备 1、微波定位仪:通过微波定位仪设备,可以判断车辆是否完全上磅。 2、…

C++11(一)

🧸🧸🧸各位大佬大家好,我是猪皮兄弟🧸🧸🧸 文章目录一、列表初始化initializer_list二、声明1.auto2.decltype3.nullptr三、C11 STL中的变化1.array2.forward_list3.STL其他变化四、C关键字新功…

【三维重建补充知识-0】视差、深度概念及其转换

一、基本概念 把手指放在眼前,分别闭上左、右眼,我们会发现手指与后边物体的相对位置是不同的,也即两眼所识别的两幅图像之间存在视觉差异,我们通过“视差”这一概念来表示这种差别。 该过程也可以通过两个处于同一平面的相机来模…

C++ 之 移动构造函数

1、左值和右值 C( 包括 C) 中所有的表达式和变量要么是左值,要么是右值。 通俗的左值的定义就是非临时对象,那些可以在多条语句中使用的对象,表达式结束后依然存在的持久化对象,所有的具名变量或者对象都是左值。右值是指临时的…

设置渐变边框色

如上图所示,需设置渐变边框色,左右边框颜色固定,上边框从左到右开始渐变,下边框从右到左开始渐变。 思考了很久,如果看作是一个div,则需要用到 border-image属性设置渐变色。也可以看作是两个div&#xff0…

CISAW信息安全保障人员认证考试难吗?

CISAW信息安全保障人员认证,作为信息安全行业相当热门的证书之一,其持证人数已超50%,在信息安全行业内占有一席之地,很多报考人都比较关心CISAW考试难不难?能通过吗?那接下来说一说CISAW证书考不好考&#…

常见的网络协议

目录 一、TCP/IP协议簇 二、网络设备与五层模型对应关系: 三、常用网络协议总结(TCP/IP协议簇) 四、应用层服务协议 五、传输层协议组 TCP_UDP 六、网络层协议 IP_ICMP_ARP 七、物理层协议 MAC子层协议 一、TCP/IP协议簇 OSI七层模型…

IBM MQ 故障诊断(一)

说明:本文主要是针对运维人员的手册。前面部分主要是应用三板斧的方式,后面的步骤可能会发散和具体深入一些。不过也不是严格的划分,读者就当看一遍杂文的方式来看待此文吧。 一,队列管理器的启停 QMGR的启停是故障诊断中遇到最…

[附源码]SSM计算机毕业设计线上图书销售管理系统JAVA

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

用huggingface.transformers在文本分类任务(单任务和多任务场景下)上微调预训练模型

诸神缄默不语-个人CSDN博文目录 transformers官方文档:https://huggingface.co/docs/transformers/index AutoModel文档:https://huggingface.co/docs/transformers/v4.23.1/en/model_doc/auto#transformers.AutoModel AutoTokenizer文档:ht…

Java#数据结构----2

目录 一.数据结构(树) 二.二叉树(任意节点的度<2) 二叉查找树又称为二叉排序树/二叉搜索树 平衡二叉树 平衡二叉树的旋转机制 三.红黑树 一.数据结构(树) 基本概念: 度: 每一个节点的子节点数量 树高: 树的总层数 根节点: 最顶层的节点 左子节点: 左下方的节点 右子节…

优维低代码:Redirect 路由重定向If 条件渲染

优维低代码技术专栏&#xff0c;是一个全新的、技术为主的专栏&#xff0c;由优维技术委员会成员执笔&#xff0c;基于优维7年低代码技术研发及运维成果&#xff0c;主要介绍低代码相关的技术原理及架构逻辑&#xff0c;目的是给广大运维人提供一个技术交流与学习的平台。 连载…

2022年超实用的推特营销策略

Twitter推广需知的13条基础知识&#xff1a; 1、Twitter日活用户达1亿 2、Twitter月活用户3.25亿 3、Twitter广告价格比其他渠道便宜33% 4、每天产生5亿条推文 5、Twitter推广能够提高29%的线下交易 6、37%的Twitter用户在18到29岁之间 7、86%的带链接推文会比普通推文效…

Cerebral Cortex:调节γ振荡可以促进大脑连接性而改善认知障碍

摘要 老年痴呆症造成了巨大的全球经济负担&#xff0c;但目前还缺乏有效的治疗方法。最近的研究表明&#xff0c;脑电活动的伽马波段波&#xff0c;特别是40赫兹振荡&#xff0c;与高阶认知功能密切相关&#xff0c;可以激活小胶质细胞清除淀粉样蛋白&#xff0d;β沉积。本研究…

Flowable 中的网关、流程变量以及历史流程

今天这篇文章&#xff0c;松哥和大家梳理一下 Flowable 中的网关、流程变量以及历史流程的玩法。 1. 三大网关 Flowable 中网关类型其实也不少&#xff0c;常见的主要有三种类型&#xff0c;分别是&#xff1a; 排他网关并行网关包容网关 这三个里边最常用的当然就是排他网关…

Cesium中的DataSource和Entity关系

本章主要探讨一下Cesium中的DataSource和Entity。 介绍 首先简单说一下Entity与Primitive。 Cesium为开发者提供了丰富的图形绘制和空间数据管理的API&#xff0c;可以分为两类&#xff0c;一类是面向图形开发人员的低层次API&#xff0c;通常被称为Primitive API&#xff0…

连续时间系统的时域分析

一.微分方程的求解 1.求微分方程的齐次解 &#xff08;1&#xff09;写出特征方程并求解 2.写出齐次解 2.求微分方程的特解 已知 &#xff08;1&#xff09;根据表2-2&#xff0c;写出特解函数 ​​​​​​​ &#xff08;2&#xff09;带入并求解 3.完全解 二.微分方…

小杨哥陷入打假风波,会变成下一个辛巴吗?

最近&#xff0c;网红疯狂小杨哥频繁登上热搜。最初的起因是他花了1亿元在合肥一家高科技公司购买了5万多平方米的房产&#xff0c;作为他名下公司的全球总部&#xff0c;由此带来了争议。 据了解&#xff0c;该物业总建筑面积为53874.33平方米&#xff0c;包括1个生产综合体、…

使用扩展有效对齐 SwiftUI 内容,创建自定义 SwiftUI 方法以快速对齐项目并使您的代码看起来简洁明了(教程含源码)

在开发 iOS 应用程序时,对齐内容可能是一个耗时的过程。如果应用程序有多个屏幕,则需要在不同的地方完成这件事,并可能导致看起来杂乱无章的视图。 作为一个始终致力于让我的代码看起来简单和流线型的人,实现目标所需的大量Spacer()元素常常让我恼火,这就是为什么当我发…