pytorch 中RNN接口参数

news/2024/7/27 7:30:27/文章来源:https://blog.csdn.net/weixin_42924890/article/details/136565982

torch中RNN模块详细接口参数解析

rnn = torch.nn.RNN(input_size: int,hidden_size: int,num_layers: int = 1,nonlinearity: str = 'tanh',bias: bool = True,batch_first: bool = False,dropout: float = 0.0,bidirectional: bool = False,
)

input_size (int):输入序列中每个时间步的特征维度。

hidden_size (int):隐藏状态(记忆单元)的维度。

num_layers (int, 默认为1):RNN 层的堆叠数量。

nonlinearity (str, 默认为’tanh’):激活函数的选择,可以是 ‘tanh’ 或 ‘relu’。不过在标准 RNN 中通常使用 ‘tanh’。

bias (bool, 默认为True):是否在计算中包含偏置项。

batch_first (bool, 默认为False):如果设为 True,则输入和输出张量的第一个维度将被视为批次大小,而不是时间步长。即数据格式为 (batch_size, seq_len, input_size) 而不是 (seq_len, batch, input_size)。

dropout (float, 默认为0.0):应用于隐层到隐层之间的失活率,用于正则化以防止过拟合。只有当 num_layers > 1 时才会生效。

bidirectional (bool, 默认为False):若设置为 True,将会创建一个双向 RNN,这样模型可以同时处理过去和未来的上下文信息。
注意torch.nn.RNN 本身并不直接支持双向模式;要实现双向RNN,应使用 torch.nn.Bidirectional 包装器包裹一个单向RNN。

outputs, hn = rnn(...)

outputs: Tensor 如果batch_first=True,则为则为 (batch_size, seq_len, num_directions * hidden_size)。否则 (seq_len, batch_size, num_directions * hidden_size);
RNN 对输入序列每个时间步的输出。对于双向 RNN,num_directions 为2,输出是正向和反向隐藏状态的串联或拼接结果。

hn: Tensor (h_n 或 hidden):形状(num_layers * num_directions, batch_size, hidden_size)
最后一个时间步的隐藏状态(或者在双向情况下,正向和反向隐藏状态)。等价 output[:, -1, :]

实例化一个单向的RNN单元

import torch.nn as nn
import torchbatch_size = 2
seq_len = 7
input_size = 5
hidden_size = 3
num_layers = 1rnn = nn.RNN(input_size, hidden_size, num_layers, nonlinearity="tanh", batch_first=True)# (batch_size, seq_len, input_size) 来两个样本为一批,每个样本在时序上分7步,每一步的维度是5
input = torch.randn(batch_size, seq_len, input_size)# (num_layers, batch_size, hidden_size) torch源码默认全零,建议使用默认值
h0 = torch.randn(1, 2, 3)# output = (batch_size, seq_len, hidden_size)
# hn = (num_layers, batch_size, hidden_size)
# hn 每一个样本最后一步的信息,等价 output[:,-1,:]
# !!注意此处变量的维度大小都是基于本例计算的,并不是实际计算公式!!
output, hn = rnn(input, h0)
# print(output)
print(output.shape)  # torch.Size([2, 7, 3])
# print(hn)
print(hn.shape)  # torch.Size([1, 2, 3])
# print(output[:, -1, :])

为了方便复习现将源码参数说明附在这里

class RNN(RNNBase):
r"""Applies a multi-layer Elman RNN with :math:`\tanh` or :math:`\text{ReLU}` non-linearity to an
input sequence.
For each element in the input sequence, each layer computes the following
function:
.. math::h_t = \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1}W_{hh}^T + b_{hh})
where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is
the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the
previous layer at time `t-1` or the initial hidden state at time `0`.
If :attr:`nonlinearity` is ``'relu'``, then :math:`\text{ReLU}` is used instead of :math:`\tanh`.Args:input_size: The number of expected features in the input `x`hidden_size: The number of features in the hidden state `h`num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``would mean stacking two RNNs together to form a `stacked RNN`,with the second RNN taking in outputs of the first RNN andcomputing the final results. Default: 1nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.Default: ``True``batch_first: If ``True``, then the input and output tensors are providedas `(batch, seq, feature)` instead of `(seq, batch, feature)`.Note that this does not apply to hidden or cell states. See theInputs/Outputs sections below for details.  Default: ``False``dropout: If non-zero, introduces a `Dropout` layer on the outputs of eachRNN layer except the last layer, with dropout probability equal to:attr:`dropout`. Default: 0bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False``Inputs: input, h_0* **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,:math:`(L, N, H_{in})` when ``batch_first=False`` or:math:`(N, L, H_{in})` when ``batch_first=True`` containing the features ofthe input sequence.  The input can also be a packed variable length sequence.See :func:`torch.nn.utils.rnn.pack_padded_sequence` or:func:`torch.nn.utils.rnn.pack_sequence` for details.* **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or:math:`(D * \text{num\_layers}, N, H_{out})` containing the initial hiddenstate for the input sequence batch. Defaults to zeros if not provided.where:.. math::\begin{aligned}N ={} & \text{batch size} \\L ={} & \text{sequence length} \\D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\H_{in} ={} & \text{input\_size} \\H_{out} ={} & \text{hidden\_size}\end{aligned}Outputs: output, h_n* **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,:math:`(L, N, D * H_{out})` when ``batch_first=False`` or:math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features`(h_t)` from the last layer of the RNN, for each `t`. If a:class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the outputwill also be a packed sequence.* **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or:math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden statefor each element in the batch.

参考 torch.nn.RNN 源码接口文档。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_999215.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

鼠标右键没有git bash here,右键添加git bash here并增加图标

突然发现自己鼠标右键没有git bash here,或者安装之后就没有git bash here。后面那种情况多半是没有默认装在C盘。我们装在其他盘的时候就需要自己去配置。git gui目前用不上,这里只讲git bash here。网上一堆教程,说法不同大多不能用要么就很…

Docker部署SimpleMindMap结合内网穿透实现公网访问本地思维导图

文章目录 1. Docker一键部署思维导图2. 本地访问测试3. Linux安装Cpolar4. 配置公网地址5. 远程访问思维导图6. 固定Cpolar公网地址7. 固定地址访问 SimpleMindMap 是一个可私有部署的web思维导图工具。它提供了丰富的功能和特性,包含插件化架构、多种结构类型&…

1-Git-基础

版本控制 为什么要进行版本控制? **个人角度:**代码的修改繁杂,如果一点点修改代码,不利于开发。进行版本控制,可以在特定历史状态下进行修改。即可以进行回退和撤销操作 **团队角度:**每个人负责各自的…

垃圾收集器底层算法

垃圾收集器底层算法 三色标记 在并发标记的过程中,因为标记期间应用线程还在继续跑,对象间的引用可能发生变化,多标和漏标的情况就有可能发生,这里我们引入“三色标记”来给大家解释下把Gcroots可达性分析遍历对象过程中遇到对象…

灯塔:CSS笔记(2)

一 选择器进阶 后代选择器:空格 作用:根据HTML标签的嵌套关系,,选择父元素 后代中满足条件的元素 选择器语法:选择器1 选择器2{ css } 结果: *在选择器1所找到标签的后代(儿子 孙子 重孙子…

opencv dnn模块 示例(24) 目标检测 object_detection 之 yolov8-pose 和 yolov8-obb

前面博文【opencv dnn模块 示例(23) 目标检测 object_detection 之 yolov8】 已经已经详细介绍了yolov8网络和测试。本文继续说明使用yolov8 进行 人体姿态估计 pose 和 旋转目标检测 OBB 。 文章目录 1、Yolov8-pose 简单使用2、Yolov8-OBB2.1、python 命令行测试2.2、opencv…

iOS-系统弹窗调用

代码: UIAlertController *alertViewController [UIAlertController alertControllerWithTitle:"请选择方式" message:nil preferredStyle:UIAlertControllerStyleActionSheet];// style 为 sheet UIAlertAction *cancle [UIAlertAction actionWithTit…

Unity性能优化篇(七) UI优化注意事项以及使用Sprite Atlas打包精灵图集

UI优化注意事项 1.尽量避免使用IMGUI(OnGUI)来做游戏时的UI,因为IMGUI的开销比较大。 2.如果一个UGUI的控件不需要进行射线检测,则可以取消勾选Raycast Target 3.尽量避免使用完全透明的图片和UI控件。因为即使完全透明,我们看不见它&#xf…

C#,老鼠迷宫问题的回溯法求解(Rat in a Maze)算法与源代码

1 老鼠迷宫问题 迷宫中的老鼠,作为另一个可以使用回溯解决的示例问题。 迷宫以块的NN二进制矩阵给出,其中源块是最左上方的块,即迷宫[0][0],目标块是最右下方的块,即迷宫[N-1][N-1]。老鼠从源头开始,必须…

Docker安装主从数据库

我自己的主数据库名字 user_muster 密码是123456 从数据库 就是slave1 名字是root 密码是123456 首先开启docker后直接执行命令 docker run -d \ -p 3307:3306 \ -v /xk857/mysql/master/conf:/etc/mysql/conf.d \ -v /xk857/mysql/master/data:/var/lib/mysql \ -e MYSQL_R…

JavaWeb笔记 --- 一JDBC

一、JDBC JDBC就是Java操作关系型数据库的一种API DriverManager 注册驱动可以不写 Class.forName("com.mysql.jdbc.Driver"); Connection Statement ResultSet PrepareStatement 密码输入一个SQL脚本,直接登录 预编译开启在url中 数据库连接池

鸿蒙Harmony应用开发—ArkTS声明式开发(基础手势:ImageAnimator)

提供帧动画组件来实现逐帧播放图片的能力,可以配置需要播放的图片列表,每张图片可以配置时长。 说明: 该组件从API Version 7开始支持。后续版本如有新增内容,则采用上角标单独标记该内容的起始版本。 子组件 无 接口 ImageAni…

GO语言接入支付宝

GO语言接入支付宝 今天就go语言接入支付宝写一个教程 使用如下库,各种接口较为齐全 "github.com/smartwalle/alipay/v3"先简单介绍下加密: 试想,当用户向支付宝付款时,若不进行任何加密,那么黑客就可以任…

C++:模版进阶 | Priority_queue的模拟实现

创作不易,感谢三连支持 一、非类型模版参数 模板参数分类为类型形参与非类型形参。 类型形参即:出现在模板参数列表中,跟在class或者typename之类的参数类型名称。 非类型形参,就是用一个常量作为类(函数)模板的一个参数&…

智能音箱技术解析

目录 前言智能音箱执行步骤解析1.1 探测唤醒词或触发词1.2 语音识别1.3 意图识别1.4 执行指令 2 典型的智能音箱2.1 百度小度音响2.2 小米小爱同学2.3 苹果 HomePod 3 功能应用举例3.1 设置计时器3.2 播放音乐 结语 前言 智能音箱已经成为日常生活中不可或缺的一部分&#xff…

解决方案|珈和科技推出农业特色产业数字化服务平台

今年中央一号文件提出,鼓励各地因地制宜大力发展特色产业,支持打造乡土特色品牌。 然而,农业特色产业的生产、加工和销售仍然面临诸多挑战。产品优质不能优价,优质不能优用的现象屡见不鲜,产业化程度低、生产附加值不…

【SpringMVC】快速体验 SpringMVC接收数据 第一期

文章目录 一、SpringMVC 介绍1.1 主要作用1.2 核心组件和调用流程理解 二、快速体验三、SpringMVC接收数据3.1 访问路径设置3.1.1 精准路径匹配3.1.2 模糊路径匹配3.1.3 类和方法级别区别3.1.4 附带请求方式限制3.1.5 进阶注解 与 常见配置问题 3.2 接收参数(重点&a…

mxxWechatBot微信机器人说明

大家伙,我是雄雄,欢迎关注微信公众号:雄雄的小课堂。 免责声明:该工具仅供学习使用,禁止使用该工具从事违法活动,否则永久拉黑封禁账号!!!本人不对任何工具的使用负责&am…

粘包与拆包

优质博文:IT-BLOG-CN 一、粘包出现的原因 服务端与客户端没有约定好要使用的数据结构。Socket Client实际是将数据包发送到一个缓存buffer中,通过buffer刷到数据链路层。因服务端接收数据包时,不能断定数据包1何时结束,就有可能出…

吴恩达deeplearning.ai:机器学习的开发过程与优化方法

以下内容有任何不理解可以翻看我之前的博客哦:吴恩达deeplearning.ai专栏 我想在接下来分析下开发机器学习系统的过程,这样当你自己动手时,能够做出更加正确的判断。 机器学习开发的迭代 Iterative loop of ML development 决定模型架构 第…