自编码器简单介绍—使用PyTorch库实现一个简单的自编码器,并使用MNIST数据集进行训练和测试

news/2024/5/4 13:35:22/文章来源:https://blog.csdn.net/qq_36693723/article/details/130316479

文章目录

    • 自编码器简单介绍
    • 什么是自编码器?
    • 自动编码器和卷积神经网络的区别?
    • 如何构建一个自编码器?
    • 如何训练自编码器?
    • 如何使用自编码器进行图像压缩?
    • 总结
    • 使用PyTorch构建简单的自动编码器
      • 第一步:导入库和数据集
      • 第二步:建立编码器和解码器
      • 第三步:定义损失函数和优化器
      • 第四步:训练自编码器模型
      • 第五步:测试自编码器模型
      • 第六部:对比重构结果

自编码器简单介绍

自编码器是一种无监督学习算法,用于学习数据中的特征,并将这些特征用于重构与输入相似的新数据。自编码器由编码器和解码器两部分组成,编码器用于将输入数据压缩到一个低维度的表示形式,解码器将该表示形式还原回输入数据的形式。自编码器可以应用于多种领域,例如图像处理、语音识别和自然语言处理等。

什么是自编码器?

自编码器是一种无监督学习算法,用于学习数据的压缩表示。它由两个部分组成:编码器和解码器。编码器将输入数据压缩成低维表示,解码器将这个低维表示重构成与原始输入尽可能接近的输出。
简单结构如下:

自动编码器和卷积神经网络的区别?

自动编码器和卷积神经网络(CNN)都是深度学习中常用的模型,但它们的目的和结构略有不同。

自动编码器是一种无监督学习模型,其目的是学习一个对输入数据进行压缩和解压缩的函数。自动编码器通常由两个部分组成:编码器和解码器。编码器将输入数据转换为潜在表示,解码器将潜在表示转换回原始数据。自动编码器的目标是最小化重构误差,即在解码器输出的数据与原始数据之间的差异。

与之不同,卷积神经网络是一种用于图像分类、目标检测、语音识别等任务的有监督学习模型。卷积神经网络通常由卷积层、池化层、全连接层等组成。卷积层可以捕捉输入数据的局部特征,池化层可以减少特征图的大小,全连接层可以将特征图转换为分类结果。

虽然自动编码器和卷积神经网络的目的和结构不同,但它们都可以用于特征提取。自动编码器可以学习输入数据的低维表示,卷积神经网络可以提取输入数据的局部特征。在某些任务中,这些特征可以作为输入传递到其他模型中,以提高任务的性能。

除了特征提取,自动编码器和卷积神经网络还有其他方面的不同之处。

自动编码器可以用于数据的降维和去噪。通过学习输入数据的低维表示,自动编码器可以将高维数据降低到更低的维度,从而简化数据。此外,自动编码器还可以在输入数据中去除噪声,因为它们的目标是在重构时最小化重构误差,这有助于滤除输入数据中的噪声。

卷积神经网络则更适合处理图像、语音、视频等具有空间结构的数据。卷积层可以捕捉输入数据的局部特征,这对于图像分类、目标检测等任务非常重要。此外,卷积神经网络还可以通过使用池化层来减少特征图的大小,从而在处理大型图像时降低计算成本。

总的来说,自动编码器和卷积神经网络是两种不同的深度学习模型,它们的应用场景和目标略有不同。但是,它们都是非常有用的工具,可以在各种任务中发挥重要作用。

如何构建一个自编码器?

首先,需要确定编码器和解码器的结构。编码器可以是多层感知器(MLP)或卷积神经网络(CNN),其中每一层都包含多个神经元,并通过非线性函数进行激活。解码器通常与编码器结构相对称,也可以是一个MLP或CNN。在训练自编码器时,将输入数据输入编码器并计算编码器输出,然后将其输入解码器并计算解码器输出。最终的目标是最小化解码器输出和原始输入之间的差异。

如何训练自编码器?

训练自编码器的关键是确定损失函数。通常使用均方误差(MSE)来计算解码器输出和原始输入之间的差异。MSE的计算方式如下:

$ 在公式中, y i y_i yi 是实际值, y i ^ \hat{y_i} yi^ 是预测值, N N N 是数据集中的样本数。
M S E = 1 N ∑ i = 1 N ( y i − y ^ i ) 2 \mathrm MSE=\frac{1}{N}\sum_{i=1}^N(y_i-\hat y_i)^2 MSE=N1i=1N(yiy^i)2
其中,N是样本数量, y i y_i yi是实际值, y ^ i \hat y_i y^i是解码器输出,即预测值。通过反向传播算法,可以计算编码器和解码器中所有参数的梯度,并使用梯度下降算法对参数进行更新。

如何使用自编码器进行图像压缩?

在用自编码器压缩图像时,将图像输入自编码器的编码器部分,得到低维表示,该表示可以看作是图像的压缩版本。可以通过解码器部分将该低维表示解码成原始图像。由于自编码器的解码器部分重构图像,因此可以使用自编码器进行图像压缩。

总结

以上是自编码器的简单入门教程。自编码器是一种无监督学习算法,可以用于学习数据的压缩表示。在训练自编码器时需要确定编码器和解码器的结构以及损失函数,并使用梯度下降算法对参数进行更新。自编码器可以用于图像压缩等任务。

使用PyTorch构建简单的自动编码器

在本教程中,我们将使用PyTorch库实现一个简单的自编码器,并使用MNIST数据集进行训练和测试。

第一步:导入库和数据集

首先,我们需要导入必要的库并加载MNIST数据集。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms# 检查是否安装了CUDA,并且CUDA是否适用于你的GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 超参数设置
batch_size = 64
learning_rate = 1e-3
num_epochs = 10# 加载MNIST手写数字数据集
train_dataset = datasets.MNIST(root='./data/', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./data/', train=False, transform=transforms.ToTensor())# 创建数据加载器
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

第二步:建立编码器和解码器

接下来,我们需要建立编码器和解码器模型。

# 定义一个自编码器的类
class Autoencoder(nn.Module):def __init__(self):super(Autoencoder, self).__init__()# 编码器部分self.encoder = nn.Sequential(nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),  # 输入1通道,输出16通道,3x3卷积,步长为2,padding为1nn.ReLU(),nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 输入16通道,输出32通道,3x3卷积,步长为2,padding为1nn.ReLU(),nn.Conv2d(32, 64, kernel_size=7),                      # 输入32通道,输出64通道,7x7卷积)# 解码器部分self.decoder = nn.Sequential(nn.ConvTranspose2d(64, 32, kernel_size=7),             # 输入64通道,输出32通道,7x7卷积转置nn.ReLU(),nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1), # 输入32通道,输出16通道,3x3卷积转置,步长为2,padding为1,输出padding为1nn.ReLU(),nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),  # 输入16通道,输出1通道,3x3卷积转置,步长为2,padding为1,输出padding为1nn.Sigmoid())def forward(self, x):x = self.encoder(x)x = self.decoder(x)return x# 创建自编码器对象并将模型移动到GPU
model = Autoencoder().to(device)

第三步:定义损失函数和优化器

定义损失函数和优化器如下。

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

第四步:训练自编码器模型

现在可以使用MNIST数据集训练自编码器模型。

# 训练自编码器
for epoch in range(num_epochs):for data in train_loader:img, _ = dataimg = img.to(device)# 前向传播output = model(img)loss = criterion(output, img)# 反向传播和优化器优化optimizer.zero_grad()loss.backward()optimizer.step()print("Epoch[{}/{}], loss:{:.4f}".format(epoch+1, num_epochs, loss.data))

第五步:测试自编码器模型

# 测试自编码器
model.eval()
total_loss = 0
with torch.no_grad():for data in test_loader:img, _ = dataimg = img.to(device)output = model(img)loss = criterion(output, img)total_loss += loss.item()print("Test average loss:{:.4f}".format(total_loss / len(test_loader)))

第六部:对比重构结果

在训练完成后,我们可以使用自编码器模型重构一些测试数据,并将重构的结果与原始数据进行比较。

# 迭代测试数据集,生成迭代器
dataiter = iter(test_loader)# 从迭代器中获取下一个批次的图像和标签
images, labels = next(dataiter)# 使用模型进行推断,处理获取的图像数据,并将结果保存在output变量中
output = model(images.to(device))# 创建子图和轴对象,其中第一行显示原始图像,第二行显示重构后的图像
fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(25,4))# 循环遍历前10个图像,绘制原始图像和重构图像并添加标题
for i in range(10):# 显示原始图像axes[0,i].imshow(images[i].squeeze().numpy(), cmap='gray')axes[0,i].set_title("Original")axes[0,i].get_xaxis().set_visible(False)axes[0,i].get_yaxis().set_visible(False)# 显示重构后的图像axes[1,i].imshow(output[i].squeeze().cpu().detach().numpy(), cmap='gray')axes[1,i].set_title("Reconstructed")axes[1,i].get_xaxis().set_visible(False)axes[1,i].get_yaxis().set_visible(False)# 显示生成的子图
plt.show()

运行完整的程序后,我们应该会看到原始图像和重构图像的对比结果。自编码器可以在其中一些图像上产生较好的结果,但在其他图像上可能会出现一些失真或模糊。

自编码器在实际应用中有许多变体和扩展,例如稀疏自编码器和卷积自编码器。这里仅仅是一个简单的入门教程,可以让您了解自编码器的基本原理和应用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_102099.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【JavaEE】社区版IDEA(2021.X版本及之前)创建SpringBoot项目

目录 下载Spring Boot Helper 创建项目 下载相关依赖 判断成功 删除多余文件 项目建好后添加依赖 输出Hello World SpringBoot的优点 下载Spring Boot Helper 创建项目 下载相关依赖 如果没有配置过国内源,参考【JavaEE】Spring项目的创建与使用_p_fly的博…

M_Map工具箱简介及地理图形绘制

M_Map工具箱简介及地理图形绘制 1 M_Map简介1.1 具体代码说明 2 地理图形绘制案例2.1 M_Map给定案例2.1.1 M_Map Logo2.1.2 Lambert Conformal Conic projection of North American Topography2.1.3 Stereographic projection of North Polar regions2.1.4 Colourmaps 2.2 案例…

【社区图书馆】 Go佬—Go程序开发实战宝典书评

文章目录 前言内容介绍文章大致划分总结 前言 《Go 程序开发实战宝典》是一本非常实用的 Go 语言开发工具书,本书由深入浅出的案例讲解、详细的技术实现、贴近实际的应用开发等组成,非常适合 Go 语言开发爱好者、从事相关行业的工程师、技术负责人以及深…

Spring依赖注入的三种方式使用及优缺点

初学Spring的时候,我们从Spring容器中获取Bean对象都是通过bean标签先将Bean对象注册到Spring容器中,然后通过上下文对象congtext的getBean方法进行获取,显然这种方法较为麻烦,所以有了更简单的存方法:五大类注解;取方…

Elasticsearch:了解和解决文档更新后 Elasticsearch 分数的变化

问题 问卷中有如下这样的文档,开发者想通过 match query 搜索这些文档来使用分数。 POST sample-index-test/_doc/1 {"first_name": "James","last_name" : "Osaka" } 以下是对上述文档的示例查询: GET sam…

贾其萃 : 笃行实践 筑梦扬帆 | 提升之路系列(二)

导读 为了发挥清华大学多学科优势,搭建跨学科交叉融合平台,创新跨学科交叉培养模式,培养具有大数据思维和应用创新的“π”型人才,由清华大学研究生院、清华大学大数据研究中心及相关院系共同设计组织的“清华大学大数据能力提升项…

LeetCode特训 --- Week2 (主打滑动窗口 + 字符串匹配题目)

目录 滑动窗口原理 真懂了滑动窗口? 滑动 字符串细节 开干切题 滑动窗口原理 滑动窗口:维护一前一后两根指针, 或者说一左一右两个指针。更主要的是维护左右指针中的区间. 同时不断的向前滑动,直到整个序列滑动结束,前指针走到序列末尾…

有什么好用的远程工具吗

沟通在任何类型的工作中都扮演着重要的角色。但当谈到远程工作时,这一点就更为重要。因此,您的组织必须找到可以让您的团队保持一致的工具。 在某些方面,项目管理扮演着类似的角色。 您会注意到,下面的大多数工具都会直接影响您的…

Java核心技术 卷1-总结-10

Java核心技术 卷1-总结-10 通配符类型通配符概念通配符的超类型限定无限定通配符通配符捕获 通配符类型 通配符概念 通配符类型中&#xff0c;允许类型参数变化。 例如&#xff0c;通配符类型Pair<? extends Employee>表示任何泛型Pair类型&#xff0c;它的类型参数是…

【Linux】uptime命令详解平均负载

命令 ➜ ~ uptime 22:37 up 90 days, 21:45, 2 users, load averages: 2.91 3.46 3.81 具体含义 22:37&#xff1a;代表的是当前的系统时间&#xff0c;也即晚上10点37分。 up 90 days, 21:45&#xff1a;代表系统运行时间 2 users &#xff1a;当前两个用户 load averages: 2…

ChatGPT实战100例 - (06) 10倍速可视化组织架构与人员协作流程

文章目录 ChatGPT实战100例 - (06) 10倍速可视化组织架构与人员协作流程一、需求与思路二、 组织架构二、 人员协作四、 总结 ChatGPT实战100例 - (06) 10倍速可视化组织架构与人员协作流程 一、需求与思路 管理研发团队的过程中&#xff0c;组织架构与人员协作流程的可视化是…

知识变现海哥|你为什么知识却不富有,是你不懂这个道理

要有价值观念&#xff0c;要有交换思维。商业的本质都是基于价值交换&#xff0c;你能为别人提供多少价值&#xff0c;你就能赚多少米&#xff0c;你帮助别人处理的问题越多你越有价值&#xff0c;你能成就多少人你就能被多少人成就。这是商业行为的底层逻辑。 你没赚到米 一是…

初识C++之C++11

目录 一、C11的概念 二、统一的列表初始化 1.{ }初始化 2.initializer_list 三、decltype 四、lambda表达式 1. lambda表达式的出现原因 2. lambda表达式的使用 2.1 捕捉列表 2.2 参数列表 2.3 mutable 2.4 返回值类型 2.5 函数体 2.6 使用方式 3. lambda表达式…

【Python】如何用pyth做游戏脚本(太简单了吧)

文章目录 前言一、开发前景二、开发流程3.1、获取窗口句柄&#xff0c;把窗口置顶3. 2、截取游戏界面&#xff0c;分割图标&#xff0c;图片比较 二、程序核心-图标连接算法&#xff08;路径寻找&#xff09;四、开发总结五、源码总结 前言 简述&#xff1a;本文将以4399小游戏…

基于OpenCV与深度神经网络——实现证件识别扫描并1比1还原证件到A4纸上

前言 1.用拍照的证件照片正反面&#xff0c;实现用证件去复印到A4纸上的效果&#xff0c;还有证件的格式化识别。 图1&#xff1a;把拍照的证件1比1还原证件到A4纸上 图2&#xff1a;证件OCR格式化识别 2.使用Yolo做目标识别,Enet做边缘检测&#xff0c;Paddle OCR做文字识别&…

云智慧助力MLOps加速落地

背景 随着数字化和计算能力的发展&#xff0c;机器学习&#xff08;Machine Learning&#xff09;技术在提高企业生产力方面所涌现的潜力越来越被大家所重视&#xff0c;然而很多机器学习的模型及应用在实际的生产环境并未达到预期&#xff0c;大量的ML项目被证明是失败的。从…

玩机搞机----root面具的安装 更新 隐藏root 德尔塔面具等等综合解析

目前的机型都是root面具&#xff0c;今天的帖子主要分析下面具的一些使用常识。一般面具如何使用一参考我前面的帖子。基本步骤都是解锁bl---修补boot---刷入boot----安装面具apk。但目前很多app会检测系统root&#xff0c;对于有些敏感类软件例如银行等等然后会检测当前系统ro…

“码”上反馈,自动留痕:二维码助力湖塘街道人居环境巡查高效化

绍兴市柯桥区湖塘街道将农村人居环境巡查同二维码技术相结合&#xff0c;具体应用到了“村民垃圾分类检查”、“公厕卫生检查”和“各村垃圾分类工作的督导记录”这三项检查工作中&#xff0c;做到了“码”上反馈、自动留痕&#xff0c;有效提升了巡检实效&#xff0c;在2020年…

【youcans 的 OpenCV 学习课】21. Haar 小波变换与 Haar 特征检测(上)

专栏地址&#xff1a;『youcans 的图像处理学习课』 文章目录&#xff1a;『youcans 的图像处理学习课 - 总目录』 【youcans 的 OpenCV 学习课】21. Haar 小波变换与 Haar 特征检测&#xff08;上&#xff09; 1. 小波变换1.1 小波变换基本概念例程 17_1&#xff1a;常用小波族…

selenium应用之抓取b站黑马视频目录建立学习计划Excel

需求故事&#xff1a; 最近时间一下子多了起来&#xff0c;用来学习Java是最合适不过了&#xff0c;但是去b站看视频难免会没有自制力&#xff0c;于是决定用selenium来抓取b站黑马Java视频的目录创建一个学习计划的Excel&#xff0c;便于进行学习进度的管理。 注&#xff1a;纯…