PyTorch: 池化-线性-激活函数层

news/2024/4/29 17:28:31/文章来源:https://blog.csdn.net/m0_52316372/article/details/131735747

文章和代码已经归档至【Github仓库:https://github.com/timerring/dive-into-AI 】或者公众号【AIShareLab】回复 pytorch教程 也可获取。

文章目录

  • nn网络层-池化-线性-激活函数层
    • 池化层
      • 最大池化:nn.MaxPool2d()
      • nn.AvgPool2d()
      • nn.MaxUnpool2d()
      • 线性层
      • 激活函数层
        • nn.Sigmoid
        • nn.tanh
        • nn.ReLU(修正线性单元)
        • nn.LeakyReLU
        • nn.PReLU
        • nn.RReLU

nn网络层-池化-线性-激活函数层

池化层

池化的作用则体现在降采样:保留显著特征、降低特征维度,增大 kernel 的感受面。 另外一点值得注意:pooling 也可以提供一些旋转不变性。 池化层可对提取到的特征信息进行降维,一方面使特征图变小,简化网络计算复杂度并在一定程度上避免过拟合的出现;一方面进行特征压缩,提取主要特征。

池化可以实现一个冗余信息的剔除,以及减少后面的计算量。

最大池化:nn.MaxPool2d()

nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)

这个函数的功能是进行 2 维的最大池化,主要参数如下:

  • kernel_size:池化核尺寸
  • stride:步长,通常与 kernel_size 一致
  • padding:填充宽度,主要是为了调整输出的特征图大小,一般把 padding 设置合适的值后,保持输入和输出的图像尺寸不变。
  • dilation:池化间隔大小,默认为 1。常用于图像分割任务中,主要是为了提升感受野
  • ceil_mode:默认为 False,尺寸向下取整。为 True 时,尺寸向上取整
  • return_indices:为 True 时,返回最大池化所使用的像素的索引,这些记录的索引通常在反最大池化时使用,把小的特征图反池化到大的特征图时,每一个像素放在哪个位置。

下图 (a) 表示反池化,(b) 表示上采样,© 表示反卷积。

平均池化与最大池化的差距一般体现在图像的整体亮度上。由于最大池化取得是最大值,因此在亮度上一般是大于平均池化结果的。

下面是最大池化的代码:

import os
import torch
import torch.nn as nn
from torchvision import transforms
from matplotlib import pyplot as plt
from PIL import Image
from common_tools import transform_invert, set_seedset_seed(1)  # 设置随机种子# ================================= load img ==================================
path_img = os.path.join(os.path.dirname(os.path.abspath(__file__)), "imgs/lena.png")
img = Image.open(path_img).convert('RGB')  # 0~255# convert to tensor
img_transform = transforms.Compose([transforms.ToTensor()])
img_tensor = img_transform(img)
img_tensor.unsqueeze_(dim=0)    # C*H*W to B*C*H*W# ================================= create convolution layer ==================================# ================ maxpool
flag = 1
# flag = 0
if flag:maxpool_layer = nn.MaxPool2d((2, 2), stride=(2, 2))   # input:(i, o, size) weights:(o, i , h, w)img_pool = maxpool_layer(img_tensor)print("池化前尺寸:{}\n池化后尺寸:{}".format(img_tensor.shape, img_pool.shape))
img_pool = transform_invert(img_pool[0, 0:3, ...], img_transform)
img_raw = transform_invert(img_tensor.squeeze(), img_transform)
plt.subplot(122).imshow(img_pool)
plt.subplot(121).imshow(img_raw)
plt.show()

结果和展示的图片如下:

池化前尺寸:torch.Size([1, 3, 512, 512])
池化后尺寸:torch.Size([1, 3, 256, 256])

nn.AvgPool2d()

torch.nn.AvgPool2d(kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None)

这个函数的功能是进行 2 维的平均池化,主要参数如下:

  • kernel_size:池化核尺寸
  • stride:步长,通常与 kernel_size 一致
  • padding:填充宽度,主要是为了调整输出的特征图大小,一般把 padding 设置合适的值后,保持输入和输出的图像尺寸不变。
  • dilation:池化间隔大小,默认为 1。常用于图像分割任务中,主要是为了提升感受野
  • ceil_mode:默认为 False,尺寸向下取整。为 True 时,尺寸向上取整
  • count_include_pad:在计算平均值时,是否把填充值考虑在内计算
  • divisor_override:除法因子。在计算平均值时,分子是像素值的总和,分母默认是像素值的个数。如果设置了 divisor_override,把分母改为 divisor_override。
img_tensor = torch.ones((1, 1, 4, 4))
avgpool_layer = nn.AvgPool2d((2, 2), stride=(2, 2))
img_pool = avgpool_layer(img_tensor)
print("raw_img:\n{}\npooling_img:\n{}".format(img_tensor, img_pool))

输出如下:

raw_img:
tensor([[[[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.]]]])
pooling_img:
tensor([[[[1., 1.],[1., 1.]]]])

加上divisor_override=3后,输出如下:

raw_img:
tensor([[[[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.]]]])
pooling_img:
tensor([[[[1.3333, 1.3333],[1.3333, 1.3333]]]])

nn.MaxUnpool2d()

nn.MaxUnpool2d(kernel_size, stride=None, padding=0)

功能是对二维信号(图像)进行最大值反池化,主要参数如下:

  • kernel_size:池化核尺寸
  • stride:步长,通常与 kernel_size 一致
  • padding:填充宽度

代码如下:

# pooling
img_tensor = torch.randint(high=5, size=(1, 1, 4, 4), dtype=torch.float)
maxpool_layer = nn.MaxPool2d((2, 2), stride=(2, 2), return_indices=True)
# 注意这里是保存了最大值所在的索引
img_pool, indices = maxpool_layer(img_tensor)# unpooling
img_reconstruct = torch.randn_like(img_pool, dtype=torch.float)
maxunpool_layer = nn.MaxUnpool2d((2, 2), stride=(2, 2))
img_unpool = maxunpool_layer(img_reconstruct, indices)print("raw_img:\n{}\nimg_pool:\n{}".format(img_tensor, img_pool))
print("img_reconstruct:\n{}\nimg_unpool:\n{}".format(img_reconstruct, img_unpool))

输出如下:

# pooling
img_tensor = torch.randint(high=5, size=(1, 1, 4, 4), dtype=torch.float)
maxpool_layer = nn.MaxPool2d((2, 2), stride=(2, 2), return_indices=True)
img_pool, indices = maxpool_layer(img_tensor)# unpooling
img_reconstruct = torch.randn_like(img_pool, dtype=torch.float)
maxunpool_layer = nn.MaxUnpool2d((2, 2), stride=(2, 2))
img_unpool = maxunpool_layer(img_reconstruct, indices)print("raw_img:\n{}\nimg_pool:\n{}".format(img_tensor, img_pool))
print("img_reconstruct:\n{}\nimg_unpool:\n{}".format(img_reconstruct, img_unpool))

线性层

线性层又称为全连接层,其每个神经元与上一个层所有神经元相连,实现对前一层的线性组合或线性变换。

代码如下:

inputs = torch.tensor([[1., 2, 3]])
linear_layer = nn.Linear(3, 4)
linear_layer.weight.data = torch.tensor([[1., 1., 1.],
[2., 2., 2.],
[3., 3., 3.],
[4., 4., 4.]])linear_layer.bias.data.fill_(0.5)
output = linear_layer(inputs)
print(inputs, inputs.shape)
print(linear_layer.weight.data, linear_layer.weight.data.shape)
print(output, output.shape)

输出为:

tensor([[1., 2., 3.]]) torch.Size([1, 3])
tensor([[1., 1., 1.],[2., 2., 2.],[3., 3., 3.],[4., 4., 4.]]) torch.Size([4, 3])
tensor([[ 6.5000, 12.5000, 18.5000, 24.5000]], grad_fn=<AddmmBackward>) torch.Size([1, 4])

激活函数层

假设第一个隐藏层为: H 1 = X × W 1 H_{1}=X \times W_{1} H1=X×W1,第二个隐藏层为: H 2 = H 1 × W 2 H_{2}=H_{1} \times W_{2} H2=H1×W2,输出层为:

Output  = H 2 ∗ W 3 = H 1 ∗ W 2 ∗ W 3 = X ∗ ( W 1 ∗ W 2 ∗ W 3 ) = X ∗ W \begin{aligned} \text { Output } &=\boldsymbol{H}_{\mathbf{2}} * \boldsymbol{W}_{\mathbf{3}} \\ &=\boldsymbol{H}_{1} * \boldsymbol{W}_{\mathbf{2}} * \boldsymbol{W}_{\mathbf{3}} \\ &=\boldsymbol{X} *\left(\boldsymbol{W}_{1} * \boldsymbol{W}_{\mathbf{2}} * \boldsymbol{W}_{3}\right) \\ &=\boldsymbol{X} * \boldsymbol{W} \end{aligned}  Output =H2W3=H1W2W3=X(W1W2W3)=XW

如果没有非线性变换,由于矩阵乘法的结合性,多个线性层的组合等价于一个线性层。

激活函数对特征进行非线性变换,赋予了多层神经网络具有深度的意义。下面介绍一些激活函数层。

nn.Sigmoid

  • 计算公式: y = 1 1 + e − x y=\frac{1}{1+e^{-x}} y=1+ex1
  • 梯度公式: y ′ = y ∗ ( 1 − y ) y^{\prime}=y *(1-y) y=y(1y)
  • 特性:
    • 输出值在(0,1),符合概率
    • 导数范围是 [0, 0.25],容易导致梯度消失
    • 输出为非 0 均值,破坏数据分布

nn.tanh

  • 计算公式: y = sin ⁡ x cos ⁡ x = e x − e − x e − + e − x = 2 1 + e − 2 x + 1 y=\frac{\sin x}{\cos x}=\frac{e{x}-e{-x}}{e{-}+e{-x}}=\frac{2}{1+e^{-2 x}}+1 y=cosxsinx=e+exexex=1+e2x2+1
  • 梯度公式: y ′ = 1 − y 2 y{\prime}=1-y{2} y=1y2
  • 特性:
    • 输出值在(-1, 1),数据符合 0 均值
    • 导数范围是 (0,1),容易导致梯度消失

nn.ReLU(修正线性单元)

  • 计算公式: y = m a x ( 0 , x ) y=max(0, x) y=max(0,x)
  • 梯度公式: y ′ = { 1 , x > 0 undefined,  x = 0 0 , x < 0 y^{\prime}=\left\{\begin{array}{ll} 1, & x>0 \\ \text { undefined, } & x=0 \\ 0, & x<0 \end{array}\right. y= 1, undefined, 0,x>0x=0x<0
  • 特性:
    • 输出值均为正数,负半轴的导数为 0,容易导致死神经元
    • 导数是 1,缓解梯度消失,但容易引发梯度爆炸

针对 RuLU 会导致死神经元的缺点,出现了下面 3 种改进的激活函数。

nn.LeakyReLU

  • 有一个参数negative_slope:设置负半轴斜率

nn.PReLU

  • 有一个参数init:设置初始斜率,这个斜率是可学习的

nn.RReLU

R 是 random 的意思,负半轴每次斜率都是随机取 [lower, upper] 之间的一个数

  • lower:均匀分布下限
  • upper:均匀分布上限

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_331263.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linux 下如何安装 tar.gz包

linux 下如何安装 tar.gz包 解压缩进入解压后的文件目录下 解压缩 tar -zxvf pycharm-community-2023.1.3.tar.gz进入解压后的文件目录下 ./pycharm.sh可执行Pycharm 建议将目录转移到其他位置 我习惯使用2020版本的 下载地址

源码阅读: echo 回显程序

文章目录 1. 目的2. 原始代码3. 化简和跨平台支持4. 修改后代码的代码分析5. References 1. 目的 阅读 netbsd 9.3 的 echo.c, 练习 C 语言源码阅读的技能。 2. 原始代码 https://github.com/NetBSD/src/blob/trunk/bin/echo/echo.c /* $NetBSD: echo.c,v 1.23 2021/11/16 …

2023年Java最新面试题

由【后端面试题宝典】提供 和 equals 的区别是什么&#xff1f; 对于基本类型&#xff0c;比较的是值&#xff1b;对于引用类型&#xff0c;比较的是地址&#xff1b;equals不能用于基本类型的比较&#xff1b;如果没有重写equals&#xff0c;equals就相当于&#xff1b;如果重…

基于JavaSwing+Mysql的仓库销售管理系统

点击以下链接获取源码&#xff1a; https://download.csdn.net/download/qq_64505944/88049275 JDK1.8 MySQL5.7 功能&#xff1a;管理员与员工两个角色登录&#xff0c;基础数据查找&#xff0c;仓库查找&#xff0c;增删改查仓库信息、商品等 源码数据库文件配置文件课程设…

5分钟构建电商API接口服务 | python小知识

1. 什么是API 我们经常会使用一些API接口来完成特定的功能&#xff0c;比如查询天气的数据&#xff0c;下载股票的数据&#xff0c;亦或是调用ChatGPT模型的结构等等。 API全称是Application Programming Interface&#xff0c;即应用程序接口&#xff0c;它通常提供了一个功…

Mysql单表多表查询练习

题目要求&#xff1a; 1.查询student表的所有记录 2.查询student表的第2到4条记录 3.从student表查询所有的学生的学号&#xff08;id&#xff09;&#xff0c;姓名&#xff08;name&#xff09;&#xff0c;和院系&#xff08;department&#xff09;的信息 4.从student表…

SpringAMQP - 消息传输时,如何提高性能?解决 SQL 注入问题?

目录 一、问题背景 二、从消息转化器根源解决问题 1.引入依赖 2.在服务生产者和消费者中都重新定义一个 MessageConverter&#xff0c;注入到 Spring 容器中 一、问题背景 在SpringAMQP的发送方法中&#xff0c;接收消息的类型是Object&#xff0c;也就是说我们可以发送任意…

用 GPU 并行环境 Isaac Gym + 强化学习库 ElegantRL:训练机器人Ant,3小时6000分,最高12000分

前排提醒,目前我们能 “用 ppo 四分钟训练 ant 到 6000 分”,比本文的 3 小时快了很多很多,有空会更新代码 https://blog.csdn.net/sinat_39620217/article/details/131724602 介绍了 Isaac Gym 库 如何使用 GPU 做大规模并行仿真,对环境模块提速。这篇帖子,我们使用 1 …

class文件反编译成Java文件

下载jad158g.win压缩包解压到指定的磁盘下&#xff08;下载目录&#xff1a;https://varaneckas.com/jad/&#xff09; 操作命令&#xff1a;jad -o -r -s java -d src classes/**/*.class 此处src为生成Java文件存放的文件夹&#xff0c;classes则为需要转化成Java的class文件…

docker 里面各种 command not found 总结

一、ip&#xff1a;command not found 执行命令&#xff1a; apt-get update & apt-get install -y iproute2 二、yum&#xff1a;command not found 执行命令&#xff1a; apt-get update & apt-get install -y yum 三、ping&#xff1a;command not found 执行命…

Openlayers实战:多地图底图切换

在实际的地图项目中,不管是我们看到的百度地图还是高德地图等,都会有地图切换这一项。 在Openlayers实战中,我们用三种地图做demo,分别是谷歌地图。Openstreetmap,stamen地图。 切换的主要原则是设置三个底图层,设定其显示状态,用到了visible这一个属性。 效果图 源代码…

OpenCV for Python 入坑第一天:图像的基础操作

我们都知道&#xff0c;OpenCV能够帮助我们处理视频和图像&#xff0c;咱们在图像处理中&#xff0c;除了Pillow库之外&#xff0c;最经常用到的也是它了。那么现在咱们就正式入坑OpenCV for Python&#xff0c;一起来感受一下OpenCV的魅力吧&#xff01; 文章目录 读取图像 im…

02LINGO基本操作

某公司新购置了某种设备 6 台&#xff0c;欲分配给下属的4 个企业&#xff0c;已知各企业获得这种设备后年创利润如表 1.1 所示&#xff0c;单位为千万元。问应如何分配这些设备能使年创总利润最大&#xff0c;最大利润是多少? 甲乙丙丁1423426455376764788657986671086 甲公…

国产芯片——单片机32位mcu的应用

随着物联网与人工智能和智能制造的发展&#xff0c;单片机作为嵌入式系统的核心控制器&#xff0c;在各类行业应用中占据重要位置。其中32位MCU在芯片设计、制造工艺、封装技术上等取得显著突破&#xff0c;以高性能的技术条件被广泛应用在智能物联等行业的方案开发中。今天我们…

txt文本筛选—python操作

需求&#xff1a;若文档中某行最后一列内容为0&#xff0c;则删除该行&#xff0c;否则保留该行内容&#xff0c;并将筛选后的内容保存到新的文本文档中。 # 读取原始txt文件 with open(depth_values.txt, r) as file:lines file.readlines()# 过滤掉第三列内容为0的行 filter…

Ubuntu22.04密码忘记怎么办 Ubuntu重置root密码方法

在Ubuntu 22.04 或其他更高版本上不小心忘记root或其他账户的密码怎么办&#xff1f; 首先uname -r查看当前系统正在使用的内核版本&#xff0c;记下来 前提&#xff1a;是你的本地电脑&#xff0c;有物理访问权限。其他如远程登录的不适用这套改密方法。 通过以下步骤&#…

解决打开excel时报错 “不能使用对象链接和嵌入”

问题截图 打开excel文件或者插入对象时&#xff0c;直接弹出不能使用对象链接和嵌入报错信息。 解决方法 按 winr 组合快捷键&#xff0c;打开运行&#xff0c;输入 dcomcnfg.exe 按回车确定 此时进入到组件服务管理界面&#xff0c;依次选择 组件服务-计算机-我的电脑-DOCM…

LaTex 1【字体、符号、表格】

​&#xff08;上一期已经安装完软件&#xff0c;但是突然出现了一点子问题不会解决&#xff0c;先用overleaf来学习吧&#xff0c;网站还是很给力的&#xff09; 关键字部分&#xff1a; \quad:代表空格【无论你打多少个空格都不是空格&#xff0c;要输入“\quad”】 字体部分…

ECMAScript 6 之二

目录 2.6 Symbol 2.7 Map 和 Set 2.8 迭代器和生成器 2.9 Promise对象 2.10 Proxy对象 2.11 async的用法 2.22 类class 2.23 模块化实现 2.6 Symbol 原始数据类型&#xff0c;它表示是独一无二的值。它属于 JavaScript 语言的原生数据类型之一&#xff0c;其他数据类型…

6.2.8 网络基本服务----万维网(www)

6.2.8 网络基本服务----万维网&#xff08;www&#xff09; 万维网即www&#xff08;World Wide Web&#xff09;是开源的信息空间&#xff0c;使用URL也就是统一资源标识符标识文档和Web资源&#xff0c;使用超文本链接互相连接资源&#xff0c;万维网并非某种特殊的计算机网…