【3】深度学习之Pytorch——如何使用张量处理表格数据集(葡萄酒数据集)

news/2024/4/28 19:37:06/文章来源:https://blog.csdn.net/weixin_47723732/article/details/128922814

在这里插入图片描述
张量是PyTorch中数据的基础。神经网络将张量输入并产生张量作为输出,实际上,神经网络内部和优化期间的所有操作都是张量之间的操作,而神经网络中的所有参数(例如权重和偏差)也都是张量。

怎样获取一条数据、一段视频或一段文本,并且用张量表示它们,然后用适合于训练深度学习模型的方式进行处理。这是我们需要解决的问题和学习的方向。

表格数据

机器学习工作中遇到的最简单的数据形式是位于电子表格、CSV(以逗号分隔值)文件或数据库中的。无论使用哪种介质,此数据都是一个表格,每个样本(或记录)包含一行,其中的列包含这个样本的一条信息。

首先,我们假设样本在表格中的显示顺序是没有意义的。这与时间序列不同,这里的表是独立样本的集合,而在时间序列中,样本是在时间维度上相关的。

列可以包含数值型数据(例如特定位置的温度)或标签(例如表示样品属性的字符串,比如“蓝色”)。因此,表格数据通常不是同质的(homogeneous),不同的列有不同的类型。你可能有一列显示苹果的重量,另一列则用标签编码其颜色。

然而,PyTorch张量是同质的。其他数据科学软件包,例如Pandas,具有dataframe的概念,dataframe即用异构(heterogenous)的列来表示数据的对象。相比之下,PyTorch中的信息被编码为数字,通常为浮点数(尽管也支持整数类型)。(PyTorch中的)数值编码是有意为之的,因为神经网络是将实数作为输入并通过连续应用矩阵乘法和非线性函数产生实数作为输出的数学实体。

看了上面的介绍,那么我们的第一步就是将异构的现实世界数据编码成浮点数张量以供神经网络使用。

举一个简单的例子,假设我们现在手里面有一份数据集是关于葡萄酒的数据

fixed acidity
volatile acidity
citric acid
residual sugar
chlorides
free sulfur dioxide
total sulfur dioxide
density
pH
sulphates
alcohol
quality

该文件包含用逗号分隔的值的集合,总共12列,第一行是包含列名称的标题行。前11列包含化学变量的值。最后一列包含从0(最差)到10(优秀)的感官质量得分。以下是列名在数据集中显示的顺序:针对此数据集可能的机器学习任务是通过化学表征来预测质量得分。

在这里插入图片描述
此图中,你将看到质量随着硫含量减少而提高。

Python提供了多个选项来快速加载CSV文件。三种常用的选择是

Python自带的csv模块
NumPy
Pandas

第三个选项是最省时和最省内存的方法,但是我们将避免仅仅是加载文件就将的额外的库引入学习曲线。因为我们已经介绍了NumPy,并且PyTorch具有出色的NumPy互操作性,所以将继续使用NumPy来加载文件并将生成的NumPy数组转换为PyTorch张量,如下面的代码所示。

import csv
import numpy as np
wine_path = "./winequality-white.csv"
wineq_numpy = np.loadtxt(wine_path, dtype=np.float32, delimiter=";",skiprows=1)
wineq_numpy

指定了二维数组的类型(32位浮点数)和用于分隔每一行各值的分隔符,并指出不应读取第一行,因为它包含列名。接下来,检查是否已读取所有数据

然后进一步将NumPy数组转成PyTorch张量:

import torch
wineq = torch.from_numpy(wineq_numpy)
wineq.shape, wineq.type()

输出

(torch.Size([4898, 12]), 'torch.FloatTensor')
data = wineq[:, :-1] # 除最后一列外所有列
data, data.shape

输出

(tensor([[ 7.0000,  0.2700,  0.3600,  ...,  3.0000,  0.4500,  8.8000],[ 6.3000,  0.3000,  0.3400,  ...,  3.3000,  0.4900,  9.5000],[ 8.1000,  0.2800,  0.4000,  ...,  3.2600,  0.4400, 10.1000],...,[ 6.5000,  0.2400,  0.1900,  ...,  2.9900,  0.4600,  9.4000],[ 5.5000,  0.2900,  0.3000,  ...,  3.3400,  0.3800, 12.8000],[ 6.0000,  0.2100,  0.3800,  ...,  3.2600,  0.3200, 11.8000]]),torch.Size([4898, 11]))

如果你想将target张量转换成标签张量,那么你有两个选择,具体取决于策略或使用分类数据的方式。第一种选择是将标签视为整数向量:

target = wineq[:, -1].long()
target
tensor([6, 6, 6,  ..., 6, 7, 6])

如果目标是字符串标签(例如颜色),则可以采用相同的方法为每个字符串分配一个整数。

另一种选择是构建独热(one-hot)编码,即将10个分数编码成10个向量,每个向量除了一个元素为1外其他所有元素都设置为0。此时,分数1可以映射到向量(1,0,0,0,0,0,0,0,0,0),分数5映射到(0,0,0,0,1,0,0,0,0,0),等等。分数值与非零元素的索引相对应的事实纯属偶然;你可以打乱上述分配,从分类的角度来看,什么都不会改变。

上述两种方法有明显的区别。将葡萄酒质量分数编码成分数的整数向量中会引入了分数的可排序性,在这个例子下可能是适当的,因为分数1低于分数4。这还会在分数之间产生一定的距离(例如1和3之间的距离与2和4之间的距离相同。)如果这符合你的定量关系,那就太好了。

否则,如果分数纯粹是定性的(例如颜色),则独热编码更适合,因为它不涉及隐含的顺序或距离关系。当整数之间的分数值(例如2.4)对应用没有意义时(即要么是这个值要么是那个值),独热编码才适用。

可以使用scatter_方法来实现独热编码,该方法将源张量中的值沿作为参数提供的索引进行填充。

target_onehot = torch.zeros(target.shape[0], 10)
target_onehot.scatter_(1, target.unsqueeze(1), 1.0)
tensor([[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],...,[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 1., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.]])

说实话非常的方便,这里提供了一个比较好的方法,快速便捷的帮助我们进行了编码(适用于分类任务)

现在看一下scatter_的作用。

首先请注意,其名称下划线结尾。PyTorch中,此约定表示该方法不会返回新的张量,而是就地修改源张量。 scatter_的参数是

  • 指定后面两个参数所处理的维度
  • 列张量,指示要填充的索引
  • 包含填充元素的张量或者单个标量(上例中即1.0)

换句话说,前面的调用可以这样理解:“对于每一行,获取目标标签的索引(在本例中即葡萄酒质量分数),并将其用作列索引以设置值为1.0。结果就是得到了一个编码分类信息的张量。

scatter_的第二个参数,即索引张量,必须具有与待填充张量相同的维数。由于target_onehot是二维(4898x10)的,因此你需要使用unsqueeze为target添加一个额外的维:

target_unsqueezed = target.unsqueeze(1)
target_unsqueezed
tensor([[6],[6],[6],...,[6],[7],[6]])

调用unsqueeze增加了一个单例的维度,从包含4898个元素的一维张量到尺寸为(4898x1)的二维张量,其内容并未改变。没有添加新元素;你决定使用额外的索引来访问元素。也就是说,你用target[0]访问target的第一个元素,并用target_unsqueezed[0,0]访问其未压缩(unsqueezed)对象的第一个元素。

PyTorch允许你在训练神经网络时直接将类别索引用作目标。但是,如果要用作网络的分类输入,则必须将其转换为独热编码张量。

首先,获取每列的均值和标准差:

data_mean = torch.mean(data, dim=0)
data_mean
data_var = torch.var(data, dim=0)
data_var

dim = 0表示沿维数0进行计算。此时,你可以通过减去平均值并除以标准偏差来对数据进行归一化,这有助于学习过程。

“dim = 0” 表示在计算中沿维数为0的方向进行运算。例如,如果您有一个矩阵,则沿维数0计算可以是对矩阵中每一行求和,并将结果作为一个向量返回。

data_normalized = (data - data_mean) / torch.sqrt(data_var)
data_normalized

使用torch.le函数确定target中哪些行对应的分数小于或等于3:

bad_indexes = torch.le(target, 3)
bad_indexes.shape, bad_indexes.dtype, bad_indexes.sum()
torch.le是小于等于的意思,例如:x = torch.tensor([1, 2, 3, 4])
y = torch.tensor([2, 2, 2, 2])
print(torch.le(x, y))输出:
tensor([1, 1, 0, 0])torch.gt是大于的意思,例如:x = torch.tensor([1, 2, 3, 4])
y = torch.tensor([2, 2, 2, 2])
print(torch.gt(x, y))输出:
tensor([0, 0, 1, 1])torch.lt是小于的意思,例如:x = torch.tensor([1, 2, 3, 4])
y = torch.tensor([2, 2, 2, 2])
print(torch.lt(x, y))输出:
tensor([1, 0, 0, 0])torch.ge是大于等于的意思,例如:x = torch.tensor([1, 2, 3, 4])
y = torch.tensor([2, 2, 2, 2])
print(torch.ge(x, y))输出:
tensor([0, 1, 1, 1])

输出

(torch.Size([4898]), torch.bool, tensor(20))

bad_indexes中只有20个元素为1!通过利用PyTorch中称为高级索引(advanced indexing)的功能,可以使用0/1张量来索引数据张量。此张量本质上将数据筛选为仅与索引张量中的1对应的元素(或行)。bad_indexes张量具有与target相同的形状,其值是0或1,具体取决于阈值与原始target张量中每个元素之间比较结果:

bad_data = data[bad_indexes]
bad_data.shape

请注意,新的bad_data张量只有20行,这与bad_indexes张量1的个数相同。另外,bad_data保留所有11列。

现在,你可以开始获取被分为好、中、坏三类的葡萄酒的信息。对每列取.mean:

bad_data = data[torch.le(target, 3)]
# 对于numpy数组和PyTorch张量,&运算符执行逻辑和运算
mid_data = data[torch.gt(target, 3) & torch.lt(target, 7)]
good_data = data[torch.ge(target, 7)]bad_mean = torch.mean(bad_data, dim=0)
mid_mean = torch.mean(mid_data, dim=0)
good_mean = torch.mean(good_data, dim=0)for i, args in enumerate(zip(col_list, bad_mean, mid_mean, good_mean)):print('{:2} {:20} {:6.2f} {:6.2f} {:6.2f}'.format(i, *args))
 0 fixed acidity          7.60   6.89   6.731 volatile acidity       0.33   0.28   0.272 citric acid            0.34   0.34   0.333 residual sugar         6.39   6.71   5.264 chlorides              0.05   0.05   0.045 free sulfur dioxide   53.33  35.42  34.556 total sulfur dioxide 170.60 141.83 125.257 density                0.99   0.99   0.998 pH                     3.19   3.18   3.229 sulphates              0.47   0.49   0.50
10 alcohol               10.34  10.26  11.42

劣质葡萄酒似乎具有更高的二氧化硫总含量(total sulfur dioxide),另外还有其他差异。你可以使用二氧化硫总含量的阈值作为区分好酒和差酒的粗略标准。现在获取二氧化硫总含量列中低于你刚刚计算的中值的索引,如下所示:

total_sulfur_threshold = 141.83
total_sulfur_data = data[:,6]
predicted_indexes = torch.lt(total_sulfur_data, total_sulfur_threshold)
predicted_indexes.shape, predicted_indexes.dtype, predicted_indexes.sum()
(torch.Size([4898]), torch.bool, tensor(2727))

上面的阈值预测略高于一半的葡萄酒是高品质的。

接下来,你需要获取(实际)优质葡萄酒的索引

actual_indexes = torch.gt(target, 5)
actual_indexes.shape, actual_indexes.dtype, actual_indexes.sum()
(torch.Size([4898]), torch.bool, tensor(3258))

由于实际的优质葡萄酒比阈值预测的多约500例,这证明该阈值并不完美。

现在,你需要查看预测与实际的吻合程度。在预测索引和实际索引之间执行逻辑与运算(请记住,每个索引都是0/1数组)得到交集,用这个交集来确定预测表现如何:

n_matches = torch.sum(actual_indexes & predicted_indexes).item()
n_predicted = torch.sum(predicted_indexes).item()
n_actual = torch.sum(actual_indexes).item()
n_matches, n_matches / n_predicted, n_matches / n_actual

这里就不具体进行接下来的预测了,主要是通过实际的案例数据知道如何将数据集转换为张量并通过具体的一些运算操作来实现这些模型所需的特点。

每文一语

当你无法改变现实的时候,那就先从自己开始改变吧!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_255484.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Springboot + RabbitMq 消息队列

前言 一、RabbitMq简介 1、RabbitMq场景应用,RabbitMq特点 场景应用 以订单系统为例,用户下单之后的业务逻辑可能包括:生成订单、扣减库存、使用优惠券、增加积分、通知商家用户下单、发短信通知等等。在业务发展初期这些逻辑可能放在一起…

openGL学习之GLFW和GLAD的下载和编译

背景:为什么使用GLFW和GLADOPenGL环境 目前主流的桌面平台是GLFW和GLAD之前使用的GLUT和Free GLUT已经基本淘汰了,所以记录一下如何下载GLFW和GLAD并且编译.GLFW下载:An OpenGL library | GLFW复制到你想存放的位置,我这里就存放到C盘Libaray文件夹下了,这里是我存放…

中国区注册使用ChatGPT指南(OpenAI‘s services are not available in your country)

ChatGPT又火了,各大平台热搜提到手软。暴增的访问量,即使强如ChatGPT,也表示顶不住了。Openai表示服务器已满负荷,ChatGPT暂无法提供服务由于目前ChatGPT未在中国开放,所以国内目前是无法注册使用ChatGPT。但我经过一番…

『 MySQL篇 』:MySQL表的聚合与联合查询

基础篇 MySQL系列专栏(持续更新中 …)1『 MySQL篇 』:库操作、数据类型2『 MySQL篇 』:MySQL表的CURD操作3『 MySQL篇 』:MySQL表的相关约束4『 MySQL篇 』:MySQL表的聚合与联合查询目录一. 聚合查询1.1 聚合函数1.2 GROUP BY子句…

Python将字典转换为csv

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。喜欢通过博客创作的方式对所学的知识进行总结与归纳,不仅形成深入且独到的理…

MySQL篇02-三大范式,多表查询

数据入库时,由于数据设计不合理,会存在数据重复、更新插入异常等情况, 故数据库中表的设计遵循的设计规范:三大范式1.第一范式(1NF)要求数据库的每一列都是不可分割的原子数据项,即原子性。强调的是列的原子性,即数据库中每一列的…

攀升MaxBook P2电脑U盘重装系统方法教学

攀升MaxBook P2电脑U盘重装系统方法教学。攀升MaxBook P2电脑是一款性价比非常高的笔记本。有用户购买了这款电脑后,想要将系统进行重装。今天和大家分享一个U盘重装系统的方法,学会这个方法后以后就可以自己轻松去重装电脑系统了。接下来一起看看具体的…

相机坐标系的正向投影和反向投影

1 、正向投影: 世界坐标系到像素坐标系 世界3D坐标系(x, y, z) 到图像像素坐标(u,v)的映射过程 (1)世界坐标系到相机坐标系的映射。 两个坐标系的转换比较简单,就是旋转矩阵 平移矩阵,旋转矩阵则是绕X, Y&#xff…

Thread 类及常见方法

Thread 类是 JVM 用来管理线程的一个类,换句话说,每个线程都有一个唯一的 Thread 对象与之关联。用我们上面的例子来看,每个执行流,也需要有一个对象来描述,类似下图所示,而 Thread 类的对象就是用来描述一…

分享111个JS焦点图代码,总有一款适合您

分享111个JS焦点图代码,总有一款适合您 111个JS焦点图代码下载链接:https://pan.baidu.com/s/1GxjW5m9DNOPEQd-Qf_gGSA?pwd4aci 提取码:4aci Python采集代码下载链接:https://wwgn.lanzoul.com/iKGwb0kye3wj jQuery宽屏左右…

没有人能比我快,用Python写一个自动填写答案的脚本

前言 不是标题党,真的就是没有人比我快,今天用Python写了个自动填写答案的脚本,快就算了,准确率还是百分之百 话不多说 咱先看代码 后看效果 不想看全文的 点击文末名片 领取源码 环境使用 Python 3.8Pycharm 模块使用 imp…

Request Method: OPTIONS

节选自https://blog.csdn.net/Amnesiac666/article/details/121105088版权归原作者所有,如有侵权请联系删除Request Method: OPTIONS一些接口在请求时,会自动发送一个的请求,我查了一遍代码,不是代码中写明的。 网上给出的解释涉及…

Pinecone:一款专为红队研究人员设计的WLAN网络安全审计框架

关于Pinecone Pinecone是一款专为红队研究人员设计的WLAN网络安全审计框架,该工具基于模块化开发,允许广大研究人员根据任务需求进行自定义功能扩展。Pinecone设计之初专用于树莓派,可以将树莓派打造为便携式无线网络安全审计工具&#xff0…

java面试题(十六)springBoot

1.1 说说你对Spring Boot的理解 参考答案 从本质上来说,Spring Boot就是Spring,它做了那些没有它你自己也会去做的Spring Bean配置。Spring Boot使用“习惯优于配置”的理念让你的项目快速地运行起来,使用Spring Boot很容易创建一个能独立运…

广告深度学习计算:向量召回索引的演进以及工程实现

问题定义召回操作通常作为搜索/推荐/广告的第一个阶段,主要目的是从巨大的候选集中选出相对合适的条目,交由后链路的排序等流程进行进一步的择优。因此,召回问题本质上就是一个大规模的最值的搜索问题:对于评分 和候选集 &#x…

最短路之Dijkstra(15张图解)

🌼多年后再见你 - 乔洋/周林枫 - 单曲 - 网易云音乐 闲来无事听听歌 Dijkstra可解决“单源最短路径”问题 四种最短路算法 Floyd算法 时间复杂度高,但实现容易(5行核心代码),可解决负权边,适用于数据范围…

方法的定义与使用详解

我们常用创建一个class修饰的就是一个类 其中有一个main方法,是主要的启动方法 这里写目录标题我们正常修饰的方法是由返回值的,但是用void修饰的没有static的使用方法中形参和实参的使用值传递引用传递类跟对象的关系this构造器--构造方法这个&#xf…

uniapp+java/springboot实现微信小程序APIV3支付功能

微信小程序的支付跟H5的支付和APP支付流程不一样,本次只描述下小程序支付流程。 一.账号准备 1.微信小程序账号 文档:小程序申请 小程序支付需要先认证,如果你有已认证的公众号,也可以通过公众号免费注册认证小程序。 一般30…

自定义input[type=file]上传按钮样式的四种方案,你知道几种?

目录前言方案1 opacity: 0;实现方案2 display:none样式元素选择 :label样式元素选择:其他元素::file-selector-button兼容性用法🧨🧨🧨 大家好,我是搞前端的半夏 🧑,一个热爱写文的前…

二、Linux文件 - Open函数讲解实战

目录 1.Open函数讲解 2.open函数实战 2.1 man 1 ls 查询Shell命令 2.2 man 2 open 查看系统调用函数 2.3项目实战 1.Open函数讲解 高频使用的Linux系统调用:open write read close Linux自带的工具:man手册: man 1是普通的shell…