Pytorch的CNN,RNNLSTM

news/2024/3/29 20:37:46/文章来源:https://blog.csdn.net/qq_36102055/article/details/130348772

CNN

拿二维卷积举例,我们先来看参数
在这里插入图片描述
卷积的基本原理,默认你已经知道了,然后我们来解释pytorch的各个参数,以及其背后的计算过程。
首先我们先来看卷积过后图片的形状的计算:
参数:

kernel_size :卷积核的大小,可以是一个元组,也就是(行大小,列大小)
stride : 移动步长,同样的可以是一个元组
padding:填充,同样的可以是一个元组。注意,填充是两边都填,如果原本宽是x,填充1后,宽会变成x+2,因为左右都填充了1
dilation : 空洞卷积大小

空洞卷积:

这里解释一下dilation,你可以认为他就是把卷积核之间加了一些洞,在不改变参数量的情况下增加了感受野。
下面两张图片引用自下面的链接,想详细了解空洞卷积的,可以看这篇文章。
https://zhuanlan.zhihu.com/p/50369448
dilation=1,此时卷积核的每个参数都是相邻的
在这里插入图片描述
dilation=2,此时卷积核的每个参数和上下左右都有一个空隙,这个空隙是一种间隔并不是一种参数,所以从形式上来看,卷积核变大变稀疏了,能够捕获更大的范围,但参数量没有变化。
![在这里插入图片描述](https://img-blog.csdnimg.cn/3dd7a254c03c4c89afe7e6257164aa78.png
所以,假设原卷积核形状是(a, b)那么dilation=k时,行增加了(k - 1)(a - 1),列增加了(k - 1)(b - 1)。
于是卷积核变成了(a + (k - 1)(a - 1), b + (k - 1)(b - 1))

输出行的计算:

我们来推导行的计算,列的计算同理。
我们可以把行分成两个部分第一个部分是卷积核做的第一次卷积,第二个部分是卷积核向右移动做的卷积。
此时行的大小应该是1 + 向右移动的次数。向右移动的次数应该等于(行大小 - 卷积核宽度) / 步长 再向下取整
行大小 = 图片宽度 + pandding[0] * 2
卷积核宽度 = kernel_size[0] + (dilation[0] - 1)(kernel_size[0] - 1)
我们把卷积核宽度计算展开得到:
卷积核宽度 = dilation[0](kernel_size[0] - 1) + 1
于是把上述公式整合就得到pytorch里的计算式
在这里插入图片描述

分组卷积:

我们可以看到Pytorch的CNN里有一个参数叫groups,是为分组卷积做准备的参数。
分组卷积的思路是吧in_channel和out_channel都进行分组,每个组内分别进行卷积运算,最后再把结果进行concat。下面举个例子:
输入in_channel = 4, out_channel = 6, groups = 2。
此时in_channel分成了两组,每组的channel数都是in_channel/groups = 2,同理out_channel 也分成了两组每组的channel数都是out_channel/groups=3
此时输入通道变成了in_channel_groups = (sub_in_channel1=2, sub_in_channel1=2)
输出通道变成了out_channel_groups = (sub_out_channel1=3, sub_out_channel1=3)
我们让sub_in_channel1和sub_out_channel1做卷积,sub_in_channel2和sub_out_channel2做卷积。
最后我们再把两个组的卷积结果concat一下,得到最终结果,第一个组的输出channel数是3,第二个组的数也是3,所以concat之后,输出的通道数还是6.

那么,分组卷积的意义何在呢?我们看参数数量就知道了。由于输出通道要除以groups,所以输入通道的数目会减少,此时每一个卷积核要负责的输入通道也就减少了,也就是说,卷积核的参数数目会减少groups倍,这就是分组卷积的意义。

参数量的计算

weight的大小为(out_channel, i n _ c h a n n e l g r o u p s \frac{in\_channel}{groups} groupsin_channel, kernel_size[0], kernel_size[1]),从这里也可以看出groups减少了参数量。
bias的大小就是out_channel的大小了。

示例代码:

conv = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=(3, 3),\stride=(2, 2), padding=(0, 0), groups=3)
"""
上述代码创建了一个处理3通道图片的卷积层
其中输出通道为3,卷积核形状为(3, 3)没有padding,横向和垂直的步长都是2
输出通道和输入通道都会被平均的分为3组
"""

RNN

同理,这里也默认了你知道了RNN的基础知识,我们来看Pytorch中RNN的官方文档。
在这里插入图片描述
这里我们先忽略num_layers, bias, 等参数,先专注于hidden_size和input_size。
先创建一个RNN然后查看参数

rnn = torch.nn.RNN(input_size = 3, hidden_size=5, bias=False)
print(rnn.all_weights)
"""
[[Parameter containing:
tensor([[-0.2646, -0.4045, -0.1925],[-0.3035, -0.4026, -0.2005],[ 0.0181,  0.0157, -0.0804],[ 0.4191,  0.0750,  0.1659],[ 0.1848,  0.1085,  0.4351]], requires_grad=True), Parameter containing:
tensor([[ 0.4092,  0.1956, -0.1648,  0.0278, -0.3483],[ 0.3865,  0.3441,  0.1004, -0.4226, -0.2988],[ 0.2640,  0.3169, -0.2568, -0.3115,  0.3268],[-0.3311, -0.1856,  0.3827, -0.1265, -0.4149],[ 0.0930, -0.1986,  0.1813,  0.3944,  0.1576]], requires_grad=True)]]
"""

可以看到有两个参数,一个是wi=(5, 3)一个是wh=(5, 5),第一个参数wi用于乘以输入的数据,而第二个参数wh用于和隐藏层的向量相乘,在结合官方给出的RNN计算,我们就可以手动模拟以下RNN的推理过程了,RNN有两个输出,一个是out_put,一个是h_0
手动模拟:

data = torch.randn(size=(2, 4, 3), dtype=torch.float32) # 输入的数据
rnn = torch.nn.RNN(input_size = 3, hidden_size=5, bias=False)
wh = rnn.all_weights[0][0]
wi = rnn.all_weights[0][1]
h0 = torch.zeros(size=(1, 4, 5))up_out = [] # out_put的数据
right_out = None # 隐藏层最终输出
acfun = torch.nn.Tanh()for i in range(2): # 2代表时序tmp1 = data[i].matmul(wh.T) # 计算输入tmp2 = h0.matmul(wi.T) # 计算隐藏层res = acfun(tmp1 + tmp2) # 计算新的h_0up_out.append(res) # 本时刻的输出h0 = res # 更新隐藏层
print(torch.cat(up_out, dim=0), h0) # 把每一个时刻的输出都concat一下
print(rnn(data))

其余参数:

其余的参数就好理解了,num_layer是控制层数的,这会使得隐藏层增加。Bidirectional是控制是否开双向的最后输出会被叠加在out_put上,batch_first是为了方便输入数据的batch位于第一个维度。

LSTM

LSTM比RNN要复杂很多,同样的,通过模拟LSTM计算的过程,搞清楚pytorch的LSTM是怎么输出的。
首先先回忆一下LSTM的计算过程:
在这里插入图片描述
LSTM有三个门,分别是:输入门,输出门,遗忘门。每个门都需要对应一个权重。
除此之外对于输入也还需要做一个做一个变换,所以也需要一个参数。
对于Pytorch的LSTM来说,他还考虑上一次的输出,也就是说,Pytorch的LSTM计算时是这样的
在这里插入图片描述也可以看到,上一次的输出 h t h_t ht被接到了下一次输入中。那么我们针对于新增的 h t h_t ht也需要对应的参数。
那么Pytorch的LSTM一共有八个参数(除去bias)

观察参数:

首先,对于input_size = a, hidden_size=b的LSTM,针对输入数据的矩阵的形状应该是bxa的。针对上一次输出数据的形状应该是bxb的。
我们来看Pytorch的LSTM的参数

rnn = nn.LSTM(input_size = 3, hidden_size=5, bias=False)
print(rnn.all_weights[0][0].shape, rnn.all_weights[0][1].shape)
"""
torch.Size([20, 3]) torch.Size([20, 5])
"""

我们发现只有两组,其实这是八组。因为针对于输入数据,一共有4个5x3的数据。对于 h t h_t ht一共有4个5x5的数据。为了加快运算,Pytorch把这些数据concat到一起了。于是出现了两个20列的数据。

模拟LSTM运算

根据Pytorch官方给的计算公式和,LSTM的计算图解。我们来进行模拟
在这里插入图片描述
在这里插入图片描述
首先我们先拆分出来参数

wii = rnn.all_weights[0][0][:5, :]
wif = rnn.all_weights[0][0][5:10, :]
wig = rnn.all_weights[0][0][10:15, :]
wio = rnn.all_weights[0][0][15:20, :]whi = rnn.all_weights[0][1][:5, :]
whf = rnn.all_weights[0][1][5:10, :]
whg = rnn.all_weights[0][1][10:15, :]
who = rnn.all_weights[0][1][15:20, :]

然后我们把LSTM的三个输入给构造出来

data = torch.randn(size=(2, 4, 3), dtype=torch.float32)
h = torch.zeros(size=(1, 4, 5)) # 隐藏层,也就是上一次的输出
c = torch.zeros(size=(1, 4, 5)) # 记忆单元

然后我们按照上面的公式进行计算,就会得到最终的结果

import torch
from torch import nndata = torch.randn(size=(2, 4, 3), dtype=torch.float32)
rnn = nn.LSTM(input_size = 3, hidden_size=5, bias=False)wii = rnn.all_weights[0][0][:5, :]
wif = rnn.all_weights[0][0][5:10, :]
wig = rnn.all_weights[0][0][10:15, :]
wio = rnn.all_weights[0][0][15:20, :]whi = rnn.all_weights[0][1][:5, :]
whf = rnn.all_weights[0][1][5:10, :]
whg = rnn.all_weights[0][1][10:15, :]
who = rnn.all_weights[0][1][15:20, :]h = torch.zeros(size=(1, 4, 5))
c = torch.zeros(size=(1, 4, 5))
up_out = []tanh = nn.Tanh()
sigmoid = nn.Sigmoid()for k in range(2):i = sigmoid(data[k].matmul(wii.T) + h.matmul(whi.T))f = sigmoid(data[k].matmul(wif.T) + h.matmul(whf.T))g = tanh(data[k].matmul(wig.T) + h.matmul(whg.T))o = sigmoid(data[k].matmul(wio.T) + h.matmul(who.T))c = c * f + i * go = tanh(c) * oup_out.append(o) # 本次的输出加到output中h = o # 把本次的输出替换h
print(torch.cat(up_out, dim=0), (h, c))
print(rnn(data))

模拟完之后我们就搞懂了Pytorch的LSTM的输入和输出都是啥了:
输入:三个参数,data,隐藏层初始值,记忆单元初始值
输出:两个值,第二个值是一个元组。每一个时刻的输出,隐藏层输出(就是最后一个时刻的输出),记忆单元的值
其余的参数和RNN一致就不再说了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_103256.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android 动画—补间动画

帧动画是通过连续播放图片来模拟动画效果,而补间动画开发者只需指定动画开始,以及动画结束"关键帧",而动画变化的"中间帧"则由系统计算并补齐! 1.补间动画的分类和Interpolator Andoird所支持的补间动画效果…

electron+vue3全家桶+vite项目搭建【14】electron多窗口,多语言切换不同步更新问题

文章目录 引入问题演示补充逻辑注意封装缓存工具类补充状态管理调整多语言初始化调整多语言切换组件 解决方案思路整理渲染进程监听语言切换主进程创建多语言切换处理语言切换组件通知主进程语言切换 最终实现效果演示 引入 我们之前在这篇文章中集成了 多语言切换&#xff0c…

【数据挖掘与商务智能决策】第十三章 数据降维之PCA 主成分分析

13.1.2 PCA主成分分析代码实现 1.二维空间降维Python代码实现 import numpy as np X np.array([[1, 1], [2, 2], [3, 3]]) Xarray([[1, 1],[2, 2],[3, 3]])# 也可以通过pandas库来构造数据,效果一样 import pandas as pd X pd.DataFrame([[1, 1], [2, 2], [3, 3…

数字北京城,航行在联通2000M的“大运河”

前故宫博物院院长单霁翔,在《大运河漂来紫禁城》一书中提到过,紫禁城里的石材、木材,甚至每一块砖,都是通过大运河,跋山涉水来到北京的。某种程度上说,北京城的繁荣与这条纵跨南北的“中华大动脉”密不可分…

AntdesignVue 局部全屏后Message、Select 、Modal、Date等组件不显示问题解决方案(最终版)

1、对this.$message.....这种的消息提示组件解决方案如下 在main.js中全局配置消息提示 //单独引用需修改的元素 import { message } from ant-design-vue message.config({maxCount: 1,getContainer:() > document.getElementById(showBigModal) || document.body //父组件…

Android-实现一个登录页面(kotlin)

准备工作 首先,确保你已经安装了 Android Studio。如果还没有安装,请访问 Android Studio 官网 下载并安装。 前提条件 - 安装并配置好 Android Studio Android Studio Electric Eel | 2022.1.1 Patch 2 Build #AI-221.6008.13.2211.9619390, built …

C++(继承中)

目录: 1.基类和派生类对象赋值转换 2.派生类当中的6个默认成员函数 --------------------------------------------------------------------------------------------------------------------------- 派生类对象可以赋值给 基类的对象/基类的指针/基类的引用&am…

Java每日一练(20230425)

目录 1. 乘积最大子数组 🌟🌟 2. 插入区间 🌟🌟 3. 删除有序数组中的重复项 II 🌟🌟 🌟 每日一练刷题专栏 🌟 Golang每日一练 专栏 Python每日一练 专栏 C/C每日一练 专栏…

CSGO搬砖,每天1-2小时,23年最强副业非它莫属(内附操作流程)

自从我学会了CSGO搬运,我发现生活也有了不小的改变,多了一份收入,生活质量也就提高了一份。 其实刚接触CSGO,我压根就不相信这么能挣钱,因为在印象中,游戏供玩家娱乐竞技的,作为我这种技术渣渣…

直播系统开发中如何优化API接口的并发

概述 在直播系统中,API接口并发的优化是非常重要的,因为它可以提高系统的稳定性和性能。本文将介绍一些优化API接口并发的方法。 理解API接口并发 在直播系统中,API接口是用于处理客户端请求的关键组件。由于许多客户端同时连接到系统&…

HTTP1.1(十二)Cookie的格式与约束

一 Cookie的格式与约束 ① Cookies是什么 1) cookie是我们在前端编程中经常使用的概念2) 使用cookie利用浏览器帮助我们保存客户的相关状态信息,保存用户已经做了什么事情3) 重点和难点[1]、cookie的工作原理[2]、cookie的限制是什么[3]、session又是怎样与cookie关联起来 …

90年三本程序员,8年5跳,年薪4万变92万……

很多时候,虽然跳槽可能带来降薪的结果,但依然有很多人认为跳槽可以涨薪。近日,看到一则帖子。 发帖的楼主表示,自己8年5跳,年薪4万到92万,现在环沪上海各一套房,再干5年码农,就可以…

【Vue】学习笔记-初始化脚手架

初始化脚手架 初始化脚手架说明具体步骤脚手架文件结构 初始化脚手架 说明 Vue脚手架是vue官方提供的标准化开发工具(开发平台)最新版本是4.x文档Vue CLI 具体步骤 如果下载缓慢请配置npm淘宝镜像 npm config set registry http://registry.npm.taoba…

【移动端网页布局】流式布局案例 ② ( 实现顶部固定定位提示栏 | 布局元素百分比设置 | 列表样式设置 | 默认样式设置 )

文章目录 一、样式测量及核心要点1、样式测量2、高度设定3、列表项设置4、设置每个元素的百分比宽度5、设置图像宽度 二、核心代码编写1、HTML 标签结构2、CSS 样式3、展示效果 三、完整代码示例1、HTML 标签结构2、CSS 样式3、展示效果 一、样式测量及核心要点 1、样式测量 京…

【ChatGPT】如何让 ChatGPT 不再频繁报错,获取更加稳定的体验?

文章目录 一、问题描述二、方案1:使用 OpenAI API Key 来访问 ChatGPT三、方案2:安装 Chrome 插件3.1 介绍3.2 安装步骤3.2.1 插件 & 脚本安装3.2.2 解读功能 一、问题描述 最近一段时间,相信大家都发现了 ChatGPT 一个问题,…

Unity音量滑块沿弧形移动

一、音量滑块的移动 1、滑块在滑动的时候,其运动轨迹沿着大圆的弧边展开 2、滑块不能无限滑动,而是两端各有一个挡块,移动到挡块位置,则不能往下移动,但可以折回 3、鼠标悬停滑块时,给出音量值和操作提示 …

前端 百度地图绘制路线加上图片

使用百度官方示例的方法根据起终点经纬度查询驾车路线但是只是一个线路 <template><div class"transportInfo"><div id"mapcontainer" class"map">11</div><div class"collapse"><el-collapse v-mo…

克隆Linux系统(centos)

克隆前得保证你有一台Linux系统的虚拟机了。 如果没有&#xff0c;可以参考这篇文章&#xff1a; 安装VMware虚拟机、Linux系统&#xff08;CentOS7&#xff09;_何苏三月的博客-CSDN博客 按照示意图一步一步执行即可。 克隆前先关闭运行的虚拟机系统。 然后右键已安装的虚拟…

【图像抠图】【深度学习】Ubuntu18.04下GFM官方代码Pytorch实现

【图像抠图】【深度学习】Ubuntu18.04下GFM官方代码Pytorch实现 提示:最近开始在【图像抠图】方面进行研究,记录相关知识点,分享学习中遇到的问题已经解决的方法。 文章目录 【图像抠图】【深度学习】Ubuntu18.04下GFM官方代码Pytorch实现前言数据集说明1.AM-2k【自然动物】2.B…

Ubuntu更新软件下载更新与移除

目录 一、更新软件源 二、下载与安装软件 三、如何移除软件 四、Ubuntu商店下载软件 一、更新软件源 更新Ubuntu软件源的操作步骤&#xff0c;更新软件源的目的就是&#xff0c;将在Ubuntu官网的软件源更改到本地&#xff0c;也就是国内的软件源&#xff0c;这样的话下载安…