未归一化导致Dead ReLU的悲剧

news/2024/4/29 14:36:47/文章来源:https://blog.csdn.net/Powerful_Green/article/details/126941883

问题描述

笔者在参考http://zh.gluon.ai/chapter_deep-learning-basics/mlp-scratch.html
实现多层感知机的时候,遇到了一个问题
那就是,如果使用ReLU作为激活函数,模型的准确率非常低(只有0.1)
但是如果把那个网站上的代码下载下来运行,准确率能达到80%
这就很奇怪了,我们使用的训练方法都是随机梯度下降,学习率,网络参数也是一样的,结果却相差很大

问题排查

经过断点调试,笔者发现,网站上的代码执行时,它的损失函数值很小(只有50多),而且下降地很快,而我的代码损失函数有200多,而且一直降不下去
在确认代码没写错之后,自然就怀疑是数据处理的问题
网站上的数据集获取是通过网站作者自己实现的d2l.load_data_fashion_mnist(batch_size)方法,而我是直接调用gdata.vision.FashionMNIST获取的数据集
经过查看源代码,发现网站作者在加载数据集时,用了一个transformer,也就是gdata.vision.transforms.ToTensor(),通过查资料,发现这是一个归一化的操作
所以大致可以确定,这个问题是由于数据没进行归一化

实验验证

在计算前,把features除以255也可以达到归一化效果,把代码修改为

    for X,y in data_iter:#X=X/255with autograd.record():X=X/255y_hat=net(X)l=loss(y_hat,y).sum()l.backward()

之后,准确率正常了

理论分析

经过查阅资料,笔者形成了两种猜想
1.发生了梯度消失的现象
2.发生了Dead ReLU现象

其中梯度消失先暂时排除,因为网络层数太少了,主要考虑是Dead ReLU
关于Dead ReLU,可以参考一下这篇文章
大致原理就是输入的负数太多,导致它们全都被ReLU函数变成了0,从而影响了学习
具体成因可能是输入的数据太大,并且学习率也比较大,导致进入ReLU函数时负数过多,具体是否是这个成因,需要通过实验验证

实验验证

这里的主要思想是,记录每次ReLU函数中被变成0的元素的占比,并绘图观察
完整代码会放在文章最末
这里就只展示最终结果了
(由于篇幅限制,本来有5轮训练的,这里就每种情况只放第一轮训练的数据图了,剩余轮次的结果差不多)

未做归一化,学习率为0.1的情况,其中relu_rate表示的是被ReLU函数变成0的元素的占比
在这里插入图片描述
可以看见,被很快,几乎所有的元素都被ReLU函数变成0了,所以Dead ReLU的猜想可以认为是正确的了

接下来再来看看改进后的图像
首先是学习率调低至0.001时的情况
在这里插入图片描述
这是第一轮,后面几轮也都是在0.8上下震荡,这是一个比较正常的情况了

再看一下进行了归一化,学习率为0.5的图像
在这里插入图片描述
情况和第二种情况类似,也是比较正常

结论

数据如果过多而且过大,且不做归一化,使用ReLU作为激活函数时就可能出现Dead ReLU的情况,而解决办法有降低学习率和进行归一化两种

下期预告

在研究这一问题的时候,笔者还发现几个有趣的现象
在不进行归一化的时候
1.把激活函数换成sigmoid之后,准确率会大幅提升(大概能到50%)
2.把激活函数换成tanh之后,准确率依然很低(大概17%),但是比ReLU要高
下一篇文章会更多的从理论上去分析这两个问题的成因

写在最后

作为刚入门深度学习还不到一个月的萌新,遇到这类问题是挺无语的
路漫漫其修远兮,吾将上下而求索,能从错误中总结经验也是件好事
以上就是个人对这个问题的见解,如果其中有错误,欢迎各位大佬批评指正

完整代码

from mxnet import nd,autograd
from mxnet.gluon import data as gdatafrom mxnet.gluon import loss as glossimport numpy as npfrom matplotlib import pyplot as pltbatch_size = 256def get_data(train=True):train_data = gdata.vision.FashionMNIST(train=train)features = train_data[:][0]features = features.astype("float32")labels = train_data[:][1]return features,labelsdef get_data_iter(train=True,to_tensor=False):features,labels=get_data(train)dataset = gdata.ArrayDataset(features, labels)if(to_tensor):return gdata.DataLoader(dataset.transform_first(gdata.vision.transforms.Compose([gdata.vision.transforms.ToTensor()])),batch_size,shuffle=True)else:return gdata.DataLoader(dataset, batch_size, shuffle=True)def test():test_features,test_labels=get_data(False)y_predict=net(test_features).argmax(axis=1)accuracy=(nd.array(test_labels) == nd.array(y_predict,dtype="float32")).mean().asscalar()print("Accuracy: "+str(accuracy))hidden_num=300
output_num=10w1=nd.random.normal(scale=0.01,shape=(784,hidden_num))
b1=nd.zeros(hidden_num)
w2=nd.random.normal(scale=0.01,shape=(hidden_num,output_num))
b2=nd.zeros(output_num)params=[w1,b1,w2,b2]for param in params:param.attach_grad()relu_rate=0def relu(x):global relu_raterelu_rate=(x<0).mean().asscalar()return nd.sigmoid(x)def net(x):x=x.reshape((-1,784))result=relu(nd.dot(x,w1)+b1)return nd.dot(result,w2)+b2loss=gloss.SoftmaxCrossEntropyLoss()lr=0.5
data_iter=get_data_iter(to_tensor=False)for i in range(5):record_count=(int)(60000/batch_size)graph_x = np.arange(0, record_count, 1)graph_ys = np.zeros((4, record_count))relu_record=np.zeros((record_count,))graph_index = 0for X,y in data_iter:#X=X/255with autograd.record():X=X/255y_hat=net(X)l=loss(y_hat,y).sum()l.backward()j=0for p in params:grad=p.gradp[:]=p-lr*grad/batch_size#record gradient# record relu_rateif(graph_index<record_count):graph_ys[j][graph_index]=grad.sum().asscalar()relu_record[graph_index] = relu_ratej+=1graph_index+=1print("Loss: "+str(l.sum()))test()plt.subplot(2, 2, 1)plt.title("w1")plt.plot(graph_x,graph_ys[0])plt.subplot(2,2,2)plt.title("b1")plt.plot(graph_x,graph_ys[1])plt.subplot(2, 2, 3)plt.title("w2")plt.plot(graph_x,graph_ys[2])plt.subplot(2, 2, 4)plt.title("b2")plt.plot(graph_x,graph_ys[3])#plt.show()plt.savefig("gradient_"+str(i)+".png")plt.clf()plt.title("relu_rate")plt.plot(graph_x, relu_record)#plt.show()plt.savefig("relu_rate_" + str(i) + ".png")plt.clf()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_10454.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器学习入门四

Octave相关资源官网地址下载地址相关语法运算符变量函数系统命令数据操作数据加载数据保存元素操作元素计算绘图和可视化工具绘图实例常用函数控制语句Octave相关资源 官网地址 官方地址 下载地址 下载地址 相关语法 运算符 %&#xff1a;注释~&#xff1a;表示不等于xo…

自学Python 62 使用urllib 包并获取百度搜索关键词中得到链接

Python 使用urllib 包 文章目录Python 使用urllib 包一、urllib 包介绍二、使用urllib.request模块三、使用urllib.parse模块在计算机网络模型中&#xff0c;Socket套接字编程属于底层网络协议开发的内容。虽然说编写网络程序需要从底层开始构建&#xff0c;但是自行处理相关协…

【图像分类】基于HOG特征结合SVM实现图像分类识别附matlab代码

1 内容介绍 ​为了满足人工智能在目标识别方法中的应用需求,需要具备对海量数据进行智能分类、识别、判读的能力.进一步挖掘了目标特性数据库数据,并将基于HOGSVM的目标识别算法应用于红外目标识别过程中.选择采集到的汽车、直升机、飞机、舰船、无人机等目标,并结合HOG算子与…

【Vite 实践】Vite 库模式能满足你吗?或许你需要统一构建

2022 年本人投入了 Vite 的怀抱&#xff0c;开始参与到 Vite 社区中&#xff0c;陆续开发了一些插件。 Vite 秉承了开箱即用&#xff0c;简化配置的思路&#xff0c;确实显著提升了前端开发体验。 但是在类库模式的构建上却有所欠缺&#xff0c;只能处理单个输入和单输入出的…

个人笔记--数据库理论 01 关系模型介绍——基于《数据库系统概念》第七版

关系模式 关系的例子 关系模型是目前广泛应用的数据模型由表的集合构成 例如 IDnamedpt_namesalary11111JAMCS12345 元组 tuple&#xff1a;表中的一行&#xff0c;元素无所谓属性 attribute : 原子的&#xff0c;不可再分的&#xff0c;要有属性域&#xff0c;如上表的nam…

云原生爱好者周刊:延迟加载任意 OCI 镜像 | 2022-09-13

开源项目推荐 SOCI Snapshotter SOCI Snapshotter 是一个 Containerd Snapshotter 插件&#xff0c;可以延迟加载任意 OCI 镜像&#xff0c;不需要 Stargz Snapshotter 一样构建特殊格式的镜像才能延迟加载。 Authentication Proxy 这个项目使用 YARP (Yet Another Reverse…

Git的认识和使用

目录 一、前置准备 二、git简介 三、gitee.com的基本使用 1.创建仓库(私库和公库) 2.创建文件及文件夹 新建文件夹两种方式 ①​ ② 3.删除 删除文件 删除仓库 四、组长组员的git使用 git clone 查看文件 git status git add git commit git push ## 命令行配置 多个…

葡聚糖-MAL/NHS/N3/Alkyne/SH/Biotin/CHO/OPSS/OH

产品名称&#xff1a; 葡聚糖-马来酰亚胺&#xff0c;葡聚糖-MAL&#xff0c;马来酰亚胺功能化葡聚糖 英文名称&#xff1a;Dextran-MAL PEG分子量可选&#xff1a;350,550,750,1k&#xff0c;2k&#xff0c;3.4k&#xff0c;5k&#xff0c;10k&#xff0c;20k&#xff08;可…

[仅需1步]企业微信群机器人[0基础接入][java]

[仅需1步]企业微信群机器人[0基础接入][java]背景介绍使用测试项目背景 公司需要把日常的服务器错误抛到企业微信群中,我正好记录下使用企业微信群机器人… 介绍 企业微信群机器人 应用介绍 企业微信是腾讯微信团队打造的企业通讯与办公工具&#xff0c;具有与微信一致的沟…

医院检验LIS系统源码

医院lis源码 实验室信息管理系统源码 .net检验系统源码 医院系统源码 了解更多源码内容&#xff0c;可私信我。 开发环境&#xff1a;.NET4.0 WPF VS2017或VS2019SQL2016 实验室信息管理系统以条码标本为主线&#xff0c;实现从采集、检测、报告、归档的全程跟踪管理。 支持…

DevOps自动化测试的原则和实践

DevOps是为了在保证高质量的前提下缩短系统变更从提交到部署至生产环境的时间。在对系统进行变更时&#xff0c;质量很重要。高质量才能让业务价值传递到系统干系人。『自动化测试既是提高质量的一种重要手段&#xff0c;也是实施持续测试必需的能力&#xff0c;因此它是DevOps…

修改WebBrowser控件的内核解决方案

首先说一下原理 当下很大浏览器他们都是用了IE的core, 这个core只提供HTML/JS的执行和渲染,并没有给出关于界面和一些特性上的事,所以开发自己浏览器如果基于IE core需要自己完成这些内容。 一张图很好的说明了这个情况,IE浏览器的架构:http://msdn.microsoft.com/en-us/li…

nginx - 负载均衡配置-负载均衡策略

目录 知识点1&#xff1a;网站流量分析指标 什么是pv&#xff1f; 什么是uv&#xff1f; 什么是IP&#xff1f; 知识点2&#xff1a;正向代理和反向代理 知识点3&#xff1a;负载均衡实验 IP地址规划&#xff1a; 实验拓扑图 知识点4&#xff1a;负载均衡策略 1、请求…

Spring5.3学习——from 官网 day1-1

Spring5.3学习——from 官网day1-1Spring5.3学习——from 官网day1-1前言概述Spring的设计理念Spring核心&#xff1a;IOC什么是IOC解释IOC容器的包什么是BeanBeanFactory接口简述ApplicationContext接口简述BeanFactory源码描述以下是Bean工厂创建和销毁bean的完整生命周期流程…

Matlab论文插图绘制模板第48期—平行坐标图(Parallelplot)

​上一期文章中&#xff0c;分享了Matlab帕累托图的绘制模板&#xff1a; 这一次&#xff0c;再来分享一种特殊的线图&#xff1a;平行坐标图。 ‘平行坐标图是一种通常的可视化方法&#xff0c;用于对高维几何和多元数据的可视化……为了克服传统的笛卡尔直角坐标系容易耗尽空…

好心情精神心理科:80%双相情感障碍被误诊,千万注意鉴别

双相情感障碍又称躁郁症&#xff0c;其表现复杂&#xff0c;容易与其他精神疾病&#xff08;包括边缘型人格障碍&#xff09;相混淆&#xff0c;超过80%的患者未能得到正确诊断。 具体如何区分双相情感障碍与边缘型人格障碍&#xff1f;在回答这个问题之前&#xff0c;好心情精…

从规模走向规模经济,锅圈食汇回归餐饮初心

预制菜源自美国&#xff0c;在日本因冷链技术发展而普及。后疫情时代&#xff0c;预制菜在中国餐饮市场加速渗透&#xff0c;成为行业的新风向。 9月&#xff0c;第一财经与CBNData发布“Growth502022中国新消费品牌年度增长力榜单”&#xff0c;预制菜品牌锅圈食汇入选。 锅…

设计模式学习笔记--责任链模式

责任链模式 责任链模式是一种对象的行为模式。在责任链模式里&#xff0c;很多对象由每一个对象对其下家的引用而连接起来形成一条链。请求在这个链上传递&#xff0c;直到链上的某一个对象决定处理此请求。发出这个请求的客户端并不知道链上的哪一个对象最终处理这个请求&…

Tuxera NTFS21Mac苹果电脑读取硬盘磁盘软件

我们经常会使用移动硬盘或 U 盘进行大体积文件的分享、携带。但有时候别人提供的NTFS移动硬盘或者U 盘在 Mac 电脑中只能读取&#xff0c;无法将文件导入到其中。这是因为常见的 NTFS 硬盘格式在 Mac 中不能兼容。 当你从 Windows 转到了 Mac 平台&#xff0c;可能会发现之前用…

RocketMQ-流程图-概念

文章目录RocketMq的角色消息发送的流程RocketMq的角色 Producer&#xff1a;消息的发送者&#xff0c;生产者&#xff1b;举例&#xff1a;发件人Consumer&#xff1a;消息接收者&#xff0c;消费者&#xff1b;举例&#xff1a;收件人Broker&#xff1a;暂存和传输消息的通道…