【Pytorch】进阶学习:基于矩阵乘法torch.matmul()实现全连接层

news/2024/5/25 9:07:20/文章来源:https://blog.csdn.net/qq_41813454/article/details/136556212

【Pytorch】进阶学习:基于矩阵乘法torch.matmul()实现全连接层

在这里插入图片描述

🌈 个人主页:高斯小哥
🔥 高质量专栏:Matplotlib之旅:零基础精通数据可视化、Python基础【高质量合集】、PyTorch零基础入门教程👈 希望得到您的订阅和支持~
💡 创作高质量博文(平均质量分92+),分享更多关于深度学习、PyTorch、Python领域的优质内容!(希望得到您的关注~)


🌵文章目录🌵

  • 🚀一、引言
  • 🔍二、全连接层的基本原理
  • 🔩三、使用torch.matmul()实现全连接层
  • 🎛️四、使用PyTorch的nn.Linear模块实现全连接层
  • 🔎五、小结与注意事项
  • 🤝六、实战演练:构建简单的神经网络
  • 📚七、进阶学习:深度神经网络与全连接层
  • 🤝八、期待与你共同进步

🚀一、引言

  在深度学习的世界里,全连接层(Fully Connected Layer)是构建神经网络的基础组件之一。它实际上执行的就是矩阵乘法操作,将输入数据映射到输出空间。在PyTorch中,我们可以使用torch.matmul()函数来实现这一操作。本文将详细解释如何使用torch.matmul()实现全连接层,并通过实例展示其应用。

🔍二、全连接层的基本原理

  全连接层,也称为密集连接层或仿射层,其核心操作就是矩阵乘法。假设输入数据的形状为(batch_size, input_features),全连接层的权重矩阵形状为(output_features, input_features),偏置项的形状为(output_features,)。全连接层的输出可以通过以下公式计算得到:

output = input @ weight.t() + bias

这里,@ 表示矩阵乘法,.t() 表示转置操作。注意,权重矩阵的列数必须与输入数据的特征数相匹配,以便进行矩阵乘法。偏置项则是一个可选的加法操作,用于增加模型的灵活性。

🔩三、使用torch.matmul()实现全连接层

在PyTorch中,我们可以使用torch.matmul()函数来执行矩阵乘法操作,从而实现全连接层。下面是一个简单的示例代码:

import torch
import torch.nn as nn
import torch.nn.functional as F# 定义全连接层的输入和输出特征数
input_features = 10
output_features = 5# 创建一个随机的输入张量,形状为(batch_size, input_features)
batch_size = 32
input_tensor = torch.randn(batch_size, input_features)# 初始化全连接层的权重和偏置项
weight = torch.randn(output_features, input_features)
bias = torch.randn(output_features)# 使用torch.matmul()实现全连接层的计算
output_tensor = torch.matmul(input_tensor, weight.t()) + bias# 查看输出张量的形状,应为(batch_size, output_features)
print(output_tensor.shape)  # 输出应为torch.Size([32, 5])

  在上面的代码中,我们首先定义了全连接层的输入和输出特征数。然后,我们创建了一个随机的输入张量input_tensor,其形状为(batch_size, input_features)。接下来,我们初始化了全连接层的权重weight和偏置项bias。最后,我们使用torch.matmul()函数执行矩阵乘法操作,并将结果加上偏置项,得到输出张量output_tensor。通过打印输出张量的形状,我们可以验证其是否符合预期。

🎛️四、使用PyTorch的nn.Linear模块实现全连接层

  虽然我们可以使用torch.matmul()手动实现全连接层,但在实际开发中,更常见的是使用PyTorch提供的nn.Linear模块来创建全连接层。这个模块封装了权重和偏置项的初始化、矩阵乘法以及偏置项的加法操作,使得全连接层的实现更加简洁和方便。

下面是一个使用nn.Linear模块实现全连接层的示例代码:

import torch
import torch.nn as nn
import torch.nn.functional as F# 定义全连接层的输入和输出特征数
input_features = 10
output_features = 5# 创建一个随机的输入张量,形状为(batch_size, input_features)
batch_size = 32
input_tensor = torch.randn(batch_size, input_features)# 使用nn.Linear模块创建全连接层
linear_layer = nn.Linear(input_features, output_features)# 将输入张量传递给全连接层进行计算
output_tensor = linear_layer(input_tensor)# 查看输出张量的形状
print(output_tensor.shape)  # 输出应为torch.Size([32, 5])

  在上面的代码中,我们直接使用nn.Linear(input_features, output_features)创建了一个全连接层对象linear_layer。然后,我们将输入张量input_tensor传递给这个全连接层对象,即可得到输出张量output_tensor。这种方式比手动使用torch.matmul()更加简洁,同时也提供了更多的功能和灵活性,例如权重和偏置项的初始化方法、是否包含偏置项等。

🔎五、小结与注意事项

  通过本文的介绍,我们了解了全连接层的基本原理,并学习了如何使用torch.matmul()函数以及nn.Linear模块来实现全连接层。在实际应用中,我们可以根据具体需求选择合适的方式来实现全连接层。需要注意的是,在使用torch.matmul()时,要确保输入张量和权重矩阵的形状匹配,以避免出错。

🤝六、实战演练:构建简单的神经网络

  理解了全连接层的工作原理和如何使用torch.matmul()后,我们可以进一步构建一个简单的神经网络来加深理解。以下是一个使用PyTorch构建和训练简单神经网络的示例:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset# 定义全连接层的输入和输出特征数
input_features = 10
output_features = 1batch_size = 32# 假设的输入和输出数据
X_train = torch.randn(100, input_features)
y_train = torch.randint(0, 2, (100,))  # 假设是二分类问题# 将数据包装成TensorDataset和DataLoader
dataset = TensorDataset(X_train, y_train)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# 定义简单的神经网络模型
class SimpleNN(nn.Module):def __init__(self, input_dim, output_dim):super(SimpleNN, self).__init__()self.fc = nn.Linear(input_dim, output_dim)self.sigmoid = nn.Sigmoid()def forward(self, x):x = self.fc(x)x = self.sigmoid(x)return x# 初始化模型、损失函数和优化器
model = SimpleNN(input_features, output_features)
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)# 训练模型
num_epochs = 10
for epoch in range(num_epochs):for inputs, targets in dataloader:# 前向传播outputs = model(inputs)# 计算损失loss = criterion(outputs.squeeze(), targets.float())# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')# 测试模型
with torch.no_grad():test_data = torch.randn(5, input_features)predictions = model(test_data)print(predictions)

  在上面的代码中,我们首先定义了一个简单的神经网络模型SimpleNN,它只包含一个全连接层和一个Sigmoid激活函数。然后,我们初始化了模型、损失函数(二分类交叉熵损失)和优化器(随机梯度下降)。接着,我们进行了模型的训练过程,包括前向传播、损失计算、反向传播和参数更新。最后,我们对模型进行了测试,输入了一些随机生成的数据并得到了预测结果。

📚七、进阶学习:深度神经网络与全连接层

  全连接层在深度神经网络中扮演着重要的角色。随着网络深度的增加,全连接层可以帮助模型捕获更复杂的特征和模式。然而,在实际应用中,我们还需要注意一些问题,如过拟合、计算效率等。为了解决这些问题,我们可以采用一些技巧和方法,如添加正则化项、使用Dropout层、优化网络结构等。

  此外,随着深度学习技术的不断发展,越来越多的新型网络结构被提出,如卷积神经网络(CNN)、循环神经网络(RNN)等。这些网络结构在处理图像、语音、文本等不同类型的数据时具有独特的优势。因此,我们可以进一步学习这些网络结构,并结合全连接层来构建更强大的深度学习模型。

🤝八、期待与你共同进步

  🌱 亲爱的读者,非常感谢你每一次的停留和阅读!你的支持是我们前行的最大动力!🙏

  🌐 在这茫茫网海中,有你的关注,我们深感荣幸。你的每一次点赞👍、收藏🌟、评论💬和关注💖,都像是明灯一样照亮我们前行的道路,给予我们无比的鼓舞和力量。🌟

  📚 我们会继续努力,为你呈现更多精彩和有深度的内容。同时,我们非常欢迎你在评论区留下你的宝贵意见和建议,让我们共同进步,共同成长!💬

  💪 无论你在编程的道路上遇到什么困难,都希望你能坚持下去,因为每一次的挫折都是通往成功的必经之路。我们期待与你一起书写编程的精彩篇章! 🎉

  🌈 最后,再次感谢你的厚爱与支持!愿你在编程的道路上越走越远,收获满满的成就和喜悦!祝你编程愉快!🎉

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_1007164.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

19、设计模式之中介者模式(Mediator)

一、什么是中介者模式 中介者模式是一种行为型设计模式,它用于减少对象之间互相通信的复杂性。中介者模式通过创建一个中介者对象,将对象之间的通信集中交给该对象来处理,而不是直接相互交流,是符合迪米特原则的典型应用。 迪米特…

物理机win10怎么与虚拟机win10共享文件

打开win10虚拟机点击虚拟机选项安装vmTools 安装完成后系统会重启重启后关机 点击编辑虚拟机设置 选项、共享文件夹、总是启用 接下来点击添加选择你要共享的文件点击确定 打开虚拟机点击此电脑 就会看到共享的文件夹啦

Milvus的相似度指标

官网:https://milvus.io/docs/metric.md版本: v2.3.x 在 Milvus 中,相似度度量用于衡量向量之间的相似度。选择良好的距离度量有助于显着提高分类和聚类性能。下表展示了这些广泛使用的相似性指标如何与各种输入数据形式和 Milvus 索引相匹配。 一、浮…

【C#图解教程】笔记

文章目录 1. C#和.NET框架.NET框架的组成.NET框架的特点CLRCLICLI的重要组成部分各种缩写 2. C#编程概括标识符命名规则: 多重标记和值格式化数字字符串对齐说明符格式字段标准数字格式说明符标准数字格式说明符 表 3. 类型、存储和变量数据成员和函数成员预定义类型…

第二门课:改善深层神经网络<超参数调试、正则化及优化>-深度学习的实用层面

文章目录 1 训练集、验证集以及测试集2 偏差与方差3 机器学习基础4 正则化5 为什么正则化可以减少过拟合&#xff1f;6 Dropout<随机失活>正则化7 理解Dropout8 其他正则化方法9 归一化输入10 梯度消失和梯度爆炸11 神经网络的权重初始化12 梯度的数值逼近13 梯度检验14 …

IDEA自定义Maven仓库

Maven 是一款广泛应用于 Java 开发的工具&#xff0c;其作用类似于一个全自动的 JAR 包管理器&#xff0c;能够方便地导入开发所需的相关 JAR 包。在使用 Maven 进行 Java 程序开发时&#xff0c;开发者能够极大地提高开发效率。以下是关于如何安装 Maven 以及在 IDEA 中配置自…

iOS——【自动引用计数】ARC规则及实现

1.3.3所有权修饰符 所有权修饰符一共有四种&#xff1a; __strong 修饰符__weak 修饰符__undafe_unretained 修饰符__autoreleasing 修饰符 __strong修饰符 _strong修饰符表示对对象的强引用&#xff0c;持有强引用的变量在超出其作用域的时候会被废弃&#xff0c;随着强引…

③【Docker】Docker部署Nginx

个人简介&#xff1a;Java领域新星创作者&#xff1b;阿里云技术博主、星级博主、专家博主&#xff1b;正在Java学习的路上摸爬滚打&#xff0c;记录学习的过程~ 个人主页&#xff1a;.29.的博客 学习社区&#xff1a;进去逛一逛~ ③【Docker】Docker部署Nginx docker拉取nginx…

二、应用层

二、应用层 2.1 应用层协议原理 可能用的应用架构&#xff1a; 1.C/S模式&#xff1a;用户增加&#xff0c;性能断崖式下降 2.P2P体系结构 3.混合体 进程通信&#xff1a; 进程—在主机上运行的应用程序 在同一个主机内&#xff0c;使用进程间通信机制通信&#xff08;操作…

Kubernetes弃用Dockershim,转向Containerd:影响及如何应对

Kubernetes1.24版本发布时&#xff0c;正式宣布弃用Dockershim&#xff0c;转向Containerd作为默认的容器运行环境。Kubernetes以CRI(Container Runtime Interface)容器运行时接口制定接入准则&#xff0c;用户可以使用Containerd、CRI-O、CRI- Dockerd及其他容器运行时作为Kub…

Solidity 智能合约开发 - 基础:基础语法 基础数据类型、以及用法和示例

苏泽 大家好 这里是苏泽 一个钟爱区块链技术的后端开发者 本篇专栏 ←持续记录本人自学两年走过无数弯路的智能合约学习笔记和经验总结 如果喜欢拜托三连支持~ 本篇主要是做一个知识的整理和规划 作为一个类似文档的作用 更为简要和明了 具体的实现案例和用法 后续会陆续给出…

netty草图笔记

学一遍根本记不住&#xff0c;那就再学一遍 public static void test_nettyFuture() {NioEventLoopGroup group new NioEventLoopGroup();log.info("开始提交任务");Future<String> future group.next().submit(() -> {log.info("执行异步任…

webmagic面试准备

1.什么是webmagic WebMagic是一款开源的Java爬虫框架&#xff0c;旨在简化网络爬虫的开发过程&#xff0c;使开发者更加高效便捷的构建网络爬虫程序。它采用了模块化的设计思想&#xff0c;将爬虫的整个生命周期划分为了四个核心组件&#xff1a;Downloader、PageProcessor、Sc…

git提交代码到仓库

git提交代码到仓库 当代码写到一半想提交到新仓库时 平常在练习时&#xff0c;写了一半的代码要提交仓库怎么做 创建一个新仓库&#xff0c;到下面图片时&#xff0c;注意红框内的代码 这种情况是已有仓库的&#xff0c;在执行git命令前 在代码中一次执行 git initgit add…

算法空间复杂度计算

目录 空间复杂度定义 影响空间复杂度的因素 算法在运行过程中临时占用的存储空间讲解 例子 斐波那契数列递归算法的性能分析 二分法&#xff08;递归实现&#xff09;的性能分析 空间复杂度定义 空间复杂度(Space Complexity)是对一个算法在运行过程中临时占用存储空间大…

电脑干货:6款免费的实用工具,值得收藏

目录 1、HelloWindows 2、Memory Helper 3、MindNode 4、B站视频下载工具 5、wallhaven壁纸 1、HelloWindows HelloWindows是一个纯净Windows系统下载网站&#xff0c;它可以下载到所有Windows系统源文件&#xff0c;比如Windows11、Windows10、win7、XP等&#xff0c;也可…

0基础安装Burpsuit专业版

首先先安装java环境,安装jdk 11的版本 文件中2023版的可以直接点开使用不需要复杂的操作的步骤 资源获取链接&#xff1a; 链接&#xff1a;百度网盘 请输入提取码 提取码&#xff1a;k2qq 其中&#xff1a;1号文件是bp的英文版激活包&#xff0c;-2号是中文版汉化版的激活包…

基于FPGA加速的bird-oid object算法实现

导语 今天继续康奈尔大学FPGA 课程ECE 5760的典型案例分享——基于FPGA加速的bird-oid object算法实现。 &#xff08;更多其他案例请参考网站&#xff1a; Final Projects ECE 5760&#xff09; 1. 项目概述 项目网址 ECE 5760 Final Project 模型说明 Bird-oid object …

Tree Shaking:优化前端项目的利器

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…

汽车IVI中控开发入门及进阶(十四):功能安全

前言: 是时候需要来说一下功能安全了,有没有发现现在很多主机厂、Tier1对芯片等BOM物料有些是有功能安全需求的,那么什么是功能安全呢? 车辆中电子元件数量的增加增加了更多故障的可能性,对驾驶员和乘客的风险更高。这种风险的增加导致汽车行业将功能安全标准作为汽车设计…