【PytorchLearning】Dropout | 内部机制及代码复现

news/2024/4/19 22:21:20/文章来源:https://blog.csdn.net/weixin_43427721/article/details/127677650

Dropout

1.CLASS torch.nn.Dropout(p=0.5, inplace=False)

训练过程中按照概率p随机地将输入张量中的元素置为0

evere channel will be zeroed out independently on every forward call.

Parameters:

  • p(float):每个元素置为0的概率,默认是0.5
  • inplace(bool):是否对原始张量进行替换

Shape

  • intput(*):any shape
  • out(*):the same shape as input

Examples

m = nn.Dropout(p=0.2)
input = torch.randn(20, 16)
output = m(input)

2.TORCH.NN.FUNCTIONAL.DROPOUT

torch.nn.functional.dropout(input, p=0.5, training=True, inplace=False),内部细节与Dropout相同

Parameters:

  • p (float) – probability of an element to be zeroed. Default: 0.5
  • training (bool) – apply dropout if is True. Default: True
  • inplace (bool) – If set to True, will do this operation in-place. Default: False

Retrun Type:

Tensor

3.Question

  1. Dropout中没有training这个参数,那么他怎么区分train和test?
  2. 为什么Dropout在训练的时候的推理的时候运算逻辑不一样?
  3. pytorch内部是如何实现Dropout的(训练、推理)?
  4. Dropout在训练和推理过程中有较大的区别,那么如何去改进?
  5. 为什么要用Dropout,Dropout在网络中的直观影像是什么?

1.Dropout继承自torch.nn.module,torch.nn.module内置的成员变量就包含training选项,其默认值为True。当我们训练模型时,不用指定就内置为True;当推理模型时,model.eval()会自动将training设置为False,从而不采用dropout进行推理(也不采用BN)

2.因为dropout是带有随机性的,如果 infer 也做的话,网络的输出就不稳定(同样一个样本,整体预测结果每次都可能变化)

3.主要是用c实现的,包括二项伯努利分布、mask操作

4.使用Inverted Dropout的方式进行改进,只在训练过程中对数据分布进行改动,即先dropout再rescale,保证总期望不变

5 使用dropout相当于在引入多个不同的模型,可以使网络具有更好的泛化性能从而避免过拟合

Furthermore, the outputs are scaled by a factor of
frescale=11−pf_{rescale}=\frac {1}{1-p} frescale=1p1
during training. This means that during evaluation the module simply computes an identity function.

首先,dropout是带有随机性的,如果 infer 也做的话,网络的输出就不稳定(同样一个样本,整体预测结果每次都可能变化)。在 infer 不做 dropout 的前提下,为了保证训练和预测过程的分布一致,需要对 infer 进行 rescale,也就是原始论文中将infer数据进行1-p倍缩小的做法,这种方式会导致预测过程依赖训练过程,模型推理的变动较大;于是Inverted Dropout提出只在训练过程中对数据分布进行修改,即先遮盖掉p的节点,然后再放大为1/(1-p)倍,这样在infer的过程中就不必对数据进行变动。即训练过程中随机扔掉了一些节点,但是rescale之后总期望又被拉回到了原来的水平。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-HdMVGnbM-1667478261933)(C:\Users\marlowe\Desktop\paperReading\神经网络基础知识\IMG\2-1.png)]

numpy实现

1.vanilla dropout

imoport numpy as npdef van_train(reate,x,w1,b1,w2,b2):layer1=np.maxinum(0,np.dot(w1,x)+b1)mask1=np.random.binomial(1,1-rate,layer1.shape)# random.binomial(n, p, size=None)layer1=layer1*mask1layer2=np.maxinum(0,np.dot(w2,layer1)+b2)mask2=np.random.binomial(1,1-rate,layer2.shape)# random.binomial(n, p, size=None)layer2=layer2*mask2return layer2def van_test(rate,x,w1,b1,w2,b2):layer1=np.maxinum(0,np.dot(w1,x)+b1)layer1=layer1*(1-rate)layer2=np.maximun(0,np.dot(w2,layer1)+b2)layer2=layer2*(1-rate)return layer2

2.inverted dropout

import numpy as npdef inv_train(rate,x,w1,b1,w2,b2):layer1=np.maxinum(0,np.dot(w1,x)+b1)mask1=np.random.binomial(1,1-rate,layer1.shape)# random.binomial(n, p, size=None)layer1=layer1*mask1layer1/=1-ratelayer2=np.maxinum(0,np.dot(w2,layer1)+b2)mask2=np.random.binomial(1,1-rate,layer2.shape)# random.binomial(n, p, size=None)layer2=layer2*mask2layer2/=1-ratereturn layer2def inv_test(x,w1,b1,w2,b2):# 不需要使用rate进行缩放layer1=np.maxinum(0,np.dot(w1,x)+b1)layer2=np.maximun(0,np.dot(w2,layer1)+b2)return layer2

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_411663.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

BGP BFD测试案例

一、BFD原理 1.1 BFD技术简介 一种全网统一、检测迅速、监控网络中链路或者IP路由的双向转发连通状况,并未上层应用提供服务的技术。 1.2 BFD会话建立方式和监测机制 ●BFD的标识符: (1)BFD建立会话存在标识符的概念&#xff…

中小企业数字化思考:数字化转型应该走自己的路

随着数字化的发展,以及数字中国概念的形成,和以前国央企宣布数字化转型时的不同,现在越来越多的企业开始寻求数字化转型,促使自身业务能够更好的发展。现在看过去,各行各业都有大量企业进行了数字化转型规划&#xff0…

【Mac】VSCode 更新1.73版本后JSTS代码跳转异常

前言 今天有小伙伴MacOS更新了VS Code版本后,说工程内的代码跳转全部异常了,没法正确跳转。搞了两三个小时没搞出来,找到了我,让我帮忙瞧瞧。排查下来发现这问题有点意思,故此记录一下。 问题 排查姿势 1. 提示没有定…

Skywalking9.2.0监控浏览器

Skywalking9.2.0监控浏览器 安装skywalking-client-js npm install skywalking-client-js --save在main.js添加信息 import ClientMonitor from skywalking-client-jsrouter.afterEach(() > {ClientMonitor.setPerformance({service: 服务名,serviceVersion: 版本号,pagePat…

基于模糊小波神经网络的空中目标威胁评估(Matlab代码实现)

目录 💥1 概述 📚2 运行结果 🎉3 参考文献 👨‍💻4 Matlab代码 💥1 概述 在现代战争中, 随着信息化和智能化的飞速发展, 以及作战环境的日益复杂, 实时而准确地评估目标威胁, 不仅为空战决策提供科学的…

程序人生:技术水平低,就这还敢写自动化项目实战经验丰富?

今年部门要招两个自动化测试,这几个月我面试了几十位候选人。发现一个很奇怪的现象,面试中一问到元素定位、框架api、脚本编写之类的,很多候选人都对答如流。但是一问到实际项目,比如 “如何从0开始搭建自动化体系”、“如果让你来…

资深大牛纯手写RabbitMQ 核心笔记,还有谁?

RabbitMQ简介 RabbitMQ是消息代理(Message Broker),它支持多种异步消息处理方式,最常见的有: Work Queue:将消息缓存到一个队列,默认情况下,多个worker按照Round Robin的方式处理队列中的消息。每个消息只…

CART回归树算法

【题目1】 表1为拖欠贷款人员训练样本数据集,使用CART算法基于该表数据构造决策树模型,并使用表2中测试样本集确定剪枝后的最优子树。 表1 拖欠贷款人员训练样本数据集编号 房产状况 婚姻情况 年收(千元) 拖欠贷款1 是 单身 125 否2 否 已婚 100 否3 否 单身 70 否4 是 已婚…

一本通1064;奥运奖牌计数

#include <iostream> using namespace std; int main() {int n, Jin, Yin, Tong;int JinSum 0, YinSum 0, TongSum 0, sum;cin >> n;for (int i 1; i < n; i) // 循环n次{cin >> Jin >> Yin >> Tong; // 输入一天获得的金银铜牌数JinSum …

IR信息检索前沿梳理

1. 检索预训练 1.1 PROP: Pre-training with Representative Words Prediction for Ad-hoc Retrieval three types of pre-training tasks have been proposed including: Inverse Cloze Task (ICT): The query is a sentence randomly drawn from the passage and the docu…

全志F1C芯片参数对比,供查阅

F1C600特性介绍 组合32M DDR1&#xff0c;QFN编解码模式&#xff0c;生产音频核心板&#xff08;CPUNORWIFI&#xff09;在WIFI站下播放的功率约0.5W组合I2S、SPDIF、CODEC等多功能接口支持全格式音频解码芯片 F1C600参数介绍 中央处理器 ARM926EJ-S 内存 SIP DDR1 SD2.0…

月入18000,0基础转行软件测试,实现薪资翻倍我只用了135天

在没做测试之前&#xff0c;我一直是个没自信的人&#xff0c;因为工作不稳定&#xff0c;收入也不高。 大学毕业做了2年酒店管理&#xff0c;月入4000提成&#xff0c;还经常上夜班&#xff0c;熬人又伤身体&#xff0c;于是不想再做服务行业&#xff0c;就转行做了电销。这之…

本地数据库IndexedDB - 学员管理系统之列表管理(二)

IndexedDB是浏览器提供的本地数据库&#xff0c;它可以被网页脚本创建和操作。IndexedDB允许存储大量数据&#xff0c;提供查找接口&#xff0c;还能建立索引。这些都是LocalStorage或Cookie不具备的。就数据库类型而言&#xff0c;IndexedDB不属于关系型数据库&#xff08;不支…

使用VMware16克隆功能快速准备CentOS 7.9操作系统集群

记录&#xff1a;305 场景&#xff1a;使用VMware16克隆功能快速准备CentOS 7.9操作系统集群&#xff0c;主要内容&#xff1a;VMware16克隆功能功能使用、CentOS 7.9操作系统常用指令使用、制作本地yum源、安装JDK、配置集群NTP时间同步等。 版本&#xff1a; 虚拟机工具&a…

数据结构-难点突破(C++/Java详解实现串匹配算法KMP,next数组求法,KMP算法优化nextval数组)

文章目录1. 暴力匹配算法BF2. KMP算法next数组求法Java代码&#xff1a;C代码&#xff1a;KMP算法优化nextval数组1. 暴力匹配算法BF 在了解KMP算法前&#xff0c;就必须介绍串的暴力匹配算法&#xff08;BF算法&#xff09; BF算法&#xff0c;即暴力(Brute Force)算法&…

大赛征集令|首届“万应杯”低代码应用开发大赛报名开启啦!

探索&#xff0c;寻觅低码边界。 创新&#xff0c;做成未曾有人做过的事。 首届“万应杯”低代码应用开发大赛 报名正式启动啦&#xff01; 万元现金奖杯/证书项目转售收益 丰厚奖励&#xff0c;邀你来战&#xff01; 大赛时间 低码掘金&#xff0c;就在此时&#xff01; …

MySQL高级SQL语句(一)

MySQL高级SQL语句&#xff08;一&#xff09;MySQL高级SQL语句&#xff08;一&#xff09;一、高级SQL语句&#xff08;进阶查询&#xff09;1.1 select1.2 distinct1.3 where1.4 and 、or1.5 in1.6 between1.7 通配符1.8 like1.9 order by二、函数2.1 数学函数2.2 聚合函数2.3…

MSDC 4.3 接口规范(26)

MSDC 4.3 接口规范&#xff08;26&#xff09;7.4 组呼业务管理7.4.1 服务状态7.4.2 启动组呼业务7.4.2.1 接口函数7.4.2.2 先决条件7.4.2.3 说明7.4.2.4 调用流程7.4.2.4.1 启动组呼业务7.4.2.4.2 无法启动服务7.4.3 停止组呼服务7.4.3.1 接口函数7.4.3.2 先决条件7.4.3.3 说明…

SH-SSS丨《端到端音视频说话人日志网络》论文线上分享

SH Symposium Series on Speech (SH SSS 2022) SH SSS 是由语音之家打造的AI语音技术相关的前沿论文成果分享平台。 来自AI语音技术领域的优秀论文作者、专家学者&#xff0c;用最精炼的表达来解读最新的高质量论文。 分享的论文成果来自国内外顶级会议收录的优秀文章、前沿…