Pytorch 中Label Smoothing CrossEntropyLoss实现

news/2024/5/20 0:50:44/文章来源:https://blog.csdn.net/flyingluohaipeng/article/details/128060040

一. 前言

一般情况下我们都是直接调用Pytorch自带的交叉熵损失函数计算loss,但涉及到魔改以及优化时,我们需要自己动手实现loss function,在这个过程中如果能对交叉熵损失的代码实现有一定的了解会帮助我们写出更优美的代码。

其次是标签平滑这个trick通常简单有效,只需要改改损失函数既可带来性能上的提升,通常与交叉熵配合食用。

因此,本文基于这两个出发点,介绍基于Pytorch框架下的交叉熵损失实现以及标签平滑的实现。

二. CrossEntropyLoss

相信大家对于如何计算交叉熵已经非常熟悉,常规步骤是①计算softmax得到各类别置信度;②计算交叉熵损失。但其实从Pytorch的官方文档可以看出,还有更一步到位的方法,如下:
CE
这避免了softmax的计算。

三. 代码实现

class CELoss(nn.Module):''' Cross Entropy Loss'''def __init__(self):super().__init__()def forward(self, pred, target):''' Args:pred: prediction of model output    [N, M]target: ground truth of sampler [N]'''eps = 1e-12# standard cross entropy lossloss = -1.*pred.gather(1, target.unsqueeze(-1)).reshape(-1,1) + torch.log(torch.exp(pred+eps).sum(dim=1)).reshape(-1,1)return loss.mean()

具体细节参考我前面的文章 Pytorch中CrossEntropyLoss()详解。

四. Label Smoothing

Label Smoothing也称之为标签平滑,其实是一种防止过拟合的正则化方法。传统的分类loss采用softmax loss,先对全连接层的输出计算softmax,视为各类别的置信度概率,再利用交叉熵计算损失。
Label Smooth
Label Smooth

在这个过程中尽可能使得各样本在正确类别上的输出概率为1,这要使得对应的z值为+∞,这拉大了其与其他类别间的距离

现在假设一个多分类任务标签是[1,0,0],如果它本身的label的出现了问题,这对模型的伤害是非常大的,因为在训练的过程中强行学习一个非本类的样本,并且让其概率非常高,这会影响对后验概率的估计。并且有时候类与类之间的并不是毫无关联,如果鼓励输出的概率间相差过大,这会导致一定程度上的过拟合

因此Label Smoothing的想法是让目标不再是one-hot标签,而是变为如下形式:
Label Smooth
其中ε为一个较小的常数,这使得softmax损失中的概率优目标不再为1和0,同时z值的最优解也不再是正无穷大,而是一个具体的数值。这在一定程度上避免了过拟合,也缓解了错误标签带来的影响。

五. Label Smoothing CrossEntropyLoss实现

基于上一节的交叉熵实现增加标签平滑功能,代码如下:

class CELoss(nn.Module):''' Cross Entropy Loss with label smoothing '''def __init__(self, label_smooth=None, class_num=137):super().__init__()self.label_smooth = label_smoothself.class_num = class_numdef forward(self, pred, target):''' Args:pred: prediction of model output    [N, M]target: ground truth of sampler [N]'''eps = 1e-12if self.label_smooth is not None:# cross entropy loss with label smoothinglogprobs = F.log_softmax(pred, dim=1)	# softmax + logtarget = F.one_hot(target, self.class_num)	# 转换成one-hot# label smoothing# 实现 1# target = (1.0-self.label_smooth)*target + self.label_smooth/self.class_num 	# 实现 2# implement 2target = torch.clamp(target.float(), min=self.label_smooth/(self.class_num-1), max=1.0-self.label_smooth)loss = -1*torch.sum(target*logprobs, 1)else:# standard cross entropy lossloss = -1.*pred.gather(1, target.unsqueeze(-1)).reshape(-1,1) + torch.log(torch.exp(pred+eps).sum(dim=1)).reshape(-1,1)return loss.mean()

实现1采用了 (1.0-self.label_smooth)*target +self.label_smooth/self.class_num 实现,与原始公式不太一样
后续在了解到pytorch的clamp接口后,发现能够利用其能正确实现原公式,见实现2

六. 试验验证

① 交叉熵损失正确率,与标准的交叉熵比较:

	loss1 = nn.CrossEntropyLoss()loss2 = CELoss(label_smooth=None, class_num=3)x = torch.tensor([[1, 8, 1], [1, 1, 8]], dtype=torch.float)y = torch.tensor([1, 2])print(loss1(x, y), loss2(x, y))# tensor(0.0018) tensor(0.0018)

② 标签平滑结果展示:

	loss1 = nn.CrossEntropyLoss()loss2 = CELoss(label_smooth=0.05, class_num=3)x = torch.tensor([[1, 8, 1], [1, 1, 8]], dtype=torch.float)y = torch.tensor([1, 2])print(loss1(x, y), loss2(x, y))# tensor(0.0018) tensor(0.2352)

另一组结果:

	x = torch.tensor([[0.1, 8, 0.1], [0.1, 0.1, 8]], dtype=torch.float)y = torch.tensor([1, 2])print(loss1(x, y), loss2(x, y))# tensor(0.0007) tensor(0.2641)

分析:拉大模型输出数值间的差距后,原始的交叉熵会变小,而增加了标签平滑的反而变大。这也反映了标签平滑后,并不是概率越接近于1越好,而是接近某个小于1的值,这使得模型的输出不再是越高(+∞)越好。

七. 参考链接

Pytorch:交叉熵损失(CrossEntropyLoss)以及标签平滑(LabelSmoothing)的实现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_226295.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【架构设计】作为架构师你应该掌握的画图技术

1.前言 大家知道,架构的过程其实就是建模的过程,那自然离不开架构图。那么,我们先来看几个问题。 (1)什么是架构图? 架构图 架构 图,用图的形式把系统架构展示出来,配上简单的文…

基于C#的校园闲置物品共享系统的开发和实现(Asp.net+Web)

目 录 摘 要 I Abstract II 第1章 绪论 1 1.1选题背景 1 1.1.1校园闲置物品共享系统的开发背景 1 1.1.2学生闲置物品交易活动的现状 1 1.2 校园闲置物品共享系统的研究方向和内容 1 1.2.1研究方向 1 1.2.2研究内容 2 1.3 校园闲置物品共享系统的设计目标 2 1.4 校园闲置物品共…

多云加速云原生数仓生态,华为与 HashData 联合打造方案

多云的兴起,源于用户应用对于基础设施、云服务功能、安全性等的差异化需求,用户希望根据需求将应用、数据因“云”制宜,实现业务的高度灵活性和高效性。这也直接驱动着云原生数据仓库等一批云原生应用的流行,以及存储等基础设施加…

WR | 水源水耐药基因稳定赋存的关键:以致病菌为“源”,群落构建主导菌为“汇”...

第一作者:武冬通讯作者:David W.Graham、杨凯、谢冰通讯单位:华东师范大学生态与环境科学学院,英国纽卡斯尔大学工程学院文章链接:www.sciencedirect.com/science/article/pii/S0043135422013045- 成果简介 -近日&…

【食品加工技术】第五章 烘烤食品加工技术 笔记

【食品加工技术】第五章 烘烤食品加工技术 笔记5.1 焙烤食品概述烘烤食品的分类按发酵和膨化程度分类安装生产工艺分类烘烤食品的原料面粉糖蛋品乳及乳制品膨松剂烘烤设备常用设备恒温设备常用工具5.2 面包加工工艺和关键技术面包的分类面包的发酵原理面包的工艺流程一次发酵二…

HTML CSS个人网页设计与实现——人物介绍丁真(学生个人网站作业设计)

🎉精彩专栏推荐👇🏻👇🏻👇🏻 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 💂 作者主页: 【主页——🚀获取更多优质源码】 🎓 web前端期末大作业…

iwebsec靶场 SQL注入漏洞通关笔记6- 宽字节注入

系列文章目录 iwebsec靶场 SQL注入漏洞通关笔记1- 数字型注入_mooyuan的博客-CSDN博客 iwebsec靶场 SQL注入漏洞通关笔记2- 字符型注入(宽字节注入)_mooyuan的博客-CSDN博客 iwebsec靶场 SQL注入漏洞通关笔记3- bool注入(布尔型盲注&#…

【学习笔记38】JavaScript中的本地存储

一、localStorage 浏览器的本地存储(永久存储), 打开浏览器存储上之后, 关闭浏览器, 信息还在语法:window.localStorage.setItem(key, value)注意: value的值必须为字符串key的书写符合见名知意 window.localStorage.setItem(ceshi1, 1111111);window.localStorage.…

制霸GitHub热榜的Spring Cloud Alibaba源码笔记,果然是阿里传出的

7年前面试最常问的并且可以顺利拿到高薪的技能是 Dubbo 3年前面试,只要你简历上有Spring Cloud 项目的相关经验,肯定会打动面试官,现在呢?恐怕简历上有Dubbo和简单的Spring Cloud技术和经验是无法让面试官高看你的。 Spring Cloud Alibaba 近几年在受到国内不少开…

深度学习与总结JVM专辑(三):垃圾回收器—G1(图文+代码)

垃圾收集器G1前言概述停顿时间模型内存布局传统内存布局过时了G1实现的几个关键细节问题铺垫知识:跨代引用铺垫知识:记忆集,卡表,卡页铺垫知识:写屏障插眼往下看G1内存模型分区Region卡片Card堆Heap分代模型分代垃圾收…

TensorRT--学习笔记

官方文档是最权威的TensorRT是可以在NVIDIA各种GPU硬件平台下运行的一个C推理框架。利用Pytorch、TF或者其他框架训练好的模型,可以转化为TensorRT的格式,然后利用TensorRT推理引擎去运行我们这个模型,从而提升这个模型在英伟达GPU上运行的速…

这可能是最权威、最全面的Go语言编码风格规范了!

每种编程语言除了固定的语法之外,都会有属于自己的地道的(idiomatic)写法。其实,自然语言也不例外,你想,你用心想想是不是这样。语言的设计者们希望开发人员都能编写统一风格的地道的代码,这样不仅代码可读性好&#x…

Packet Tracer 实验 - 排除多区域 OSPFv3 故障

地址分配表 设备 接口 IPv6 全局单播地址 IPv6 本地链路地址 默认网关 ISP GigabitEthernet0/0 2001:DB8:C1:1::1/64 FE80::C1 不适用 ASBR GigabitEthernet0/0 2001:DB8:C1:1::2/64 FE80::7 不适用 Serial0/0/0 2001:DB8:A8EA:F0A::1 FE80::7 不适用 S…

JSP学习日记

JSP简述 Java Sever Pages----->Java服务器界面 用于前后端结合 jsp为什么淘汰? 由于JSP的前后端耦合性极高,编写代码非常臃肿。前后端的代码放在一起,所以JSP可以看成是已经被淘汰的技术。 为什么还要学jsp? 由于一些公司…

基于遗传算法的自主式水下潜器路径规划问题附Matlab代码

✅作者简介:热爱科研的Matlab仿真开发者,修心和技术同步精进,matlab项目合作可私信。 🍎个人主页:Matlab科研工作室 🍊个人信条:格物致知。 更多Matlab仿真内容点击👇 智能优化算法 …

MFC编辑框控件属性和用法

目录 一、编辑框的属性 1.want return 2.Multiline 3.滚动条 4.添加完效果 二、初始化编辑框内容 三、复制与退出 四、edit control的值类型 五、思维拓展 一、编辑框的属性 默认情况下编辑框edit control 是可以横向无限输入的 1.want return 支持换行,…

自动化项目倍加福测距仪QSM WCS RS485 与西门子S7 200通信

1、程序流程图 2、WCS位置数据处理流程 第一步:设置S7-200的RS485的通讯波特率19.2kbps,通讯格式(8,1,E); 第二步:PLC向WCS发送请求码: A0A1为0,表示读码器地…

《人月神话》(The Mythical Man-Month)1 看清问题的本质:如果我们想解决问题,就必须试图先去理解它...

第一章 焦油坑(The Tar Pit)史前史中,没有比巨兽在焦油坑中垂死挣扎的场面更令人震撼的了。上帝见证着恐龙、猛犸象、剑齿虎在焦油中挣扎。它们挣扎得越是猛烈,焦油纠缠得越紧,没有任何猛兽足够强壮或具有足够的技巧&a…

【C++数据结构】程序性能分析

程序性能分析 2.1 什么是程序性能 程序性能:所谓程序性能(performance of a program)是指运行这个程序所需要的内存和时间的多少。 性能分析:在性能分析(performance analysis)时,采用分析方…

基于粒子群算法的配电网重构研究matlab程序

基于粒子群算法的配电网重构研究matlab程序 参考文献:基于改进灰狼算法的含分布式电源配电网重构研究 (本文未考虑分布式电源) 摘要:使用基本环矩阵编码的智能优化算法在处理配电网重构问题中,通常使用无序的解空间&a…