Pytorch深度学习笔记(十)多分类问题

news/2024/4/19 7:09:43/文章来源:https://blog.csdn.net/qq_45981086/article/details/130333116

课程推荐:09.多分类问题_哔哩哔哩_bilibili 

目录

1. 多分类模型

2. softmax函数模型

3. Loss损失函数

4.实战MNIST Dataset


之前,在逻辑斯蒂回归中我们提到了二分类任务,现在我们讨论多分类问题。

1. 多分类模型

与二分类不同的是多分类有多个输出概率,由softmax层完成。

softmax用于多分类过程中,它将多个神经元的输出,映射到(0,1)区间内,可以看成概率来理解,从而来进行多分类。

softmax层的前一层是线性层,也就是说softmax之前的一层不需要再做Sigmoid,Sigmoid在输入softmax之前已经做过了。线性层输出的值是一般值,还不是概率值。经过sotfmax层之后才能变成概率值

2. softmax函数模型

softmax函数的作用:softmax两个作用,如果在进行softmax前的input有负数,通过指数变换,得到正数。所有分类的概率求和为1。

softmax函数模型: p(y=i)=\frac{e^{z_{i}}}{\sum_{j=0}^{K-1}e^{z_{i}}},i\in\{0......i\}

softmax实现过程:Exponent为求指数,sum为求和,Divide为求除数。最后得到的概率值为\hat{y}

代码实现: 

import numpy as np
y = np.array([1, 0, 0])
z = np.array([0.2, 0.1, -0.1])
y_pred = np.exp(z) / np.exp(z).sum()
loss = (-y * np.log(y_pred)).sum()
print(loss)

3. Loss损失函数

Loss损失函数计算公式:Loss(\hat{Y},Y)=-Ylog\hat{Y}

交叉熵损失CrossEntropyLoss <==> LogSoftmax + NLLLoss。LogSoftmax用于得到预测概率值\hat{y},NLLLoss用于 \hat{y}和y损失值计算。NLLLoss可单独使用,可以思考一下CrossEntropyLoss与NLLLoss区别,方便日后灵活应用。

代码实现:

import torch
# 长整形的张量,LongTensor[0]是指索引为0 的(就是第一个元素)为1,其余为0
y = torch.LongTensor([0])
z = torch.Tensor([[0.2, 0.1, -0.1]])
# 损失函数
criterion = torch.nn.CrossEntropyLoss()
loss = criterion(z, y)
print(loss)

4.实战MNIST Dataset

将MNIST Dataset中的手写图像映射到一个28*28的矩阵中,颜色越深数值越大,矩阵中的数值在[0,1]之间。把像素值0-255转化为图像张量0-1

将PIL图像转化为张量:

transforms.ToTensor()将PIL图像由单通道转变为多通道再取值[1,0]之间,单变多28*28 =>1*28*28,W*H => C*W*H , C:channel通道,W:宽,H:高,C*W*H。

标椎化处理:

Normalize计算公式

Normalize函数将转化后的张量映射到[0,1]之间

# 把像素值0-255转化为图像张量0-1
transform = transforms.Compose([# transforms.ToTensor()转化张量,Normalize映射到[0,1]之间transforms.ToTensor(),# (均值,标准差)transforms.Normalize((0.1307, ), (0.381, ))
])

完整代码:

这是一种全连接的神经网络。

1.把像素值0-255转化为图像张量0-1

2.准备数据:view()函数,改变张量形状,参数为-1,根据后一个数,自动调整张量的形状和大小。 

3.模型结构:线性层(降维) => 激活层(relu激活) => …… => 线性层(维度为10)

view()参考博客

老四步:

1.数据准备

2.设计模型

3.构造损失函数和优化器

4.训练周期(前馈—>反馈—>更新)

import torch
# 用于图像映射到矩阵中
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optimbatch_size = 64# 把像素值0-255转化为图像张量0-1
transform = transforms.Compose([# transforms.ToTensor()转化张量,Normalize映射到[0,1]之间transforms.ToTensor(),# (均值,标准差)transforms.Normalize((0.1307, ), (0.381, ))
])# 训练集
train_dataset = datasets.MNIST(root="../dataset/mnist",train=True,download=True,transform=transform)train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True)test_dataset = datasets.MNIST(root="../dataset/mnist",train=False,download=True,transform=transform)test_loader = DataLoader(test_dataset,batch_size=batch_size,shuffle=False)#…2.设计模型………………………………………………………………………………………………………………………………………#
# 继承torch.nn.Module,定义自己的计算模块,neural network
class Net(torch.nn.Module):# 构造函数def __init__(self):# 调用父类构造super(Net, self).__init__()# 从784维降到10维self.l1 = torch.nn.Linear(784, 512)self.l2 = torch.nn.Linear(512, 256)self.l3 = torch.nn.Linear(256, 128)self.l4 = torch.nn.Linear(128, 64)self.l5 = torch.nn.Linear(64, 10)# 前馈函数def forward(self, x):# 改变张量形状,784表示确定的列数,自动调整行数x = x.view(-1, 784)# 激活x = F.relu(self.l1(x))x = F.relu(self.l2(x))x = F.relu(self.l3(x))x = F.relu(self.l4(x))# 最后传入一个线性层return self.l5(x)#……3.构造损失函数和优化器………………………………………………………………………………………………………#
model = Net()
# 实例化损失函数,返回损失值
criterion = torch.nn.CrossEntropyLoss()
# 优化器,momentum冲量
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)#……4.训练和测试……………………………………………………………………………………………………………………………#
def train(epoch):running_loss = 0.0for batch_idx, data in enumerate(train_loader, 0):# 1.准备数据inputs, labels = dataoptimizer.zero_grad()# 2.正向传播outputs = model(inputs)loss = criterion(outputs, labels)# 3.反向传播loss.backward()# 4.更新权重woptimizer.step()# 损失求和running_loss += loss.item()if batch_idx % 300 == 299:print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))running_loss = 0.0def test():correct = 0total = 0# with torch.no.grad():内部代码不会再计算梯度with torch.no_grad():for data in test_loader:images, labels = dataoutputs = model(images)# dim沿着第一个纬度(行)找最大值,返回(最大值,最大值下标)_, predicted = torch.max(outputs.data, dim=1)total += labels.size(0)# 预测值与标签对比,正确则相加correct += (predicted == labels).sum().item()# 输出精确率print('Accuracy on test set: %d %%' % (100 * correct / total))if __name__ == '__main__':for epoch in range(10):train(epoch)test()

_, predicted = torch.max(outputs.data, 1):返回与样本对比后,相似度最大的样本下标

_, predicted = torch.max(outputs.data, 1)的理解

训练结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_102880.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Hive基础和使用详解

文章目录 一、启动hive1. hive启动的前置条件2. 启动方式一: hive命令3. 方式二:使用jdbc连接hive 二、Hive常用交互命令1. hive -help 命令2. hive -e 命令3. hive -f 命令4. 退出hive窗口5. 在hive窗口中执行dfs -ls /&#xff1b; 三、Hive语法1.DDL语句1.1 创建数据库1.2 两…

Redis 数据存储原理

核心模块如图 1-10。 图1-10 图 1-10 Client 客户端&#xff0c;官方提供了 C 语言开发的客户端&#xff0c;可以发送命令&#xff0c;性能分析和测试等。 网络层事件驱动模型&#xff0c;基于 I/O 多路复用&#xff0c;封装了一个短小精悍的高性能 ae 库&#xff0c;全称是 …

【人工智能】遗传算法

人工智能算法---遗传算法&#xff08;基础篇&#xff09; 知识导图&#xff1a;遗传算法&#xff08;概念&#xff09;1.初始化种群二进制编码与解码 2.选择操作3.交叉操作4.评估操作5.终止操作 知识导图&#xff1a; 遗传算法&#xff08;概念&#xff09; 可以把遗传算法类比…

Docker 快速入门

1、Docker 简介 Docker是一个开源的容器引擎&#xff0c;它可以帮助我们更快地交付应用。Docker可将应用程序和基础设施层隔离&#xff0c;并且能将基础设施当作程序一样进行管理。使用Docker&#xff0c;可更快地打包、测试以及部署应用程序&#xff0c;并可减少从编写到部署…

python的智能换行函数(一堆烦乱的判断)

def zntxt(txt):line30 #设置单行长度js,e,s,rs,aa,nm,x,y{},[],txt,[],,[],0,0n 1 if ord(s[0]) > 127 else 0for i in range(len(s)):m1 if ord(s[i]) > 127 else 0if m!n:rs.append(aa)aas[i]elif ilen(s)-1:aas[i]rs.append(aa)else:aas[i]nmfor i in rs: for j in…

搜索引擎找外贸客户

说起搜索引擎&#xff0c;我们每个人都不陌生&#xff0c;也许第一时间就能想到平日经常使用的“百度一下”和凭借强大算法及丰富功能占据近85%市场份额的谷歌搜索&#xff08;Statista 2023年1月数据&#xff09;这些耳熟能详的搜索引擎。对于外贸人而言搜索引擎也是非常实用的…

一文谈谈文心一言对比ChatGPT4.0的差距

对于想体验文心一言的朋友&#xff0c;可以进行申请尝试&#xff0c;快速入口 如果想体验ChatGPT的朋友&#xff0c;可以自行fq注册&#xff1b;但是由于现在限制注册并且不稳定&#xff0c;对于不会用梯子不想注册的朋友可以使用这个进行访问&#xff0c;快速入口 关于ChatG…

PMP证书备考攻略+PMP知识点汇总

一&#xff0c;考PMP好处多 1.能力提升 大型项目&#xff0c;领导专业团队 2.升职加薪 晋升管理岗&#xff0c;优先升职加薪 3.招投标加分 具有PMP证书&#xff0c;企业招标有加分 4.转型利器 助力转型&#xff0c;拓宽职业发展 5.公司支持 企业鼓励学习&#xff0c;报销费用 6…

C++模板使用

感谢你的阅读&#xff01;&#xff01;&#xff01; 目录 感谢你的阅读&#xff01;&#xff01;&#xff01; 举个例子&#xff1a; template 有什么意义为什么要用模板 与typedef的区别 使用方法 模板&#xff1a;隐式实例化与显示实例化 和非模板函数以及多个模板类…

气传导耳机和骨传导耳机的区别是啥?气传导耳机有哪些优缺点?

本文主要讲解一下气传导耳机和骨传导耳机的区别、气传导耳机的优缺点&#xff0c;并推荐一些目前主流的气传导耳机款式&#xff0c;大家可以根据自身需求&#xff0c;选择自己感兴趣的部分观看。 气传导耳机和骨传导耳机不同点&#xff1a; 气传导耳机和骨传导耳机最大且最根…

什么是 MVVM?MVVM和 MVC 有什么区别?什么又是 MVP ?

目录标题 一、什么是MVVM&#xff1f;二、MVC是什么&#xff1f;三、MVVM和MVC的区别&#xff1f;四、什么是MVP&#xff1f; 一、什么是MVVM&#xff1f; MVVM是 Model-View-ViewModel的缩写&#xff0c;即模型-视图-视图模型。MVVM 是一种设计思想。 模型&#xff08;Model…

windows安装sqli-labs靶场,两种方式

1、安装phpstudy 官网打不开了&#xff0c;下载地址在这儿https://download.csdn.net/download/weixin_59679023/87711536 双击安装 点自定义安装&#xff0c;选择安装目录&#xff0c;注意目录不要有空格和中文 安装完成启动红框内的两个服务 2、安装sqli靶场 这个包支持ph…

信息收集(三)端口和目录信息收集

信息收集&#xff08;一&#xff09;域名信息收集 信息收集&#xff08;二&#xff09;IP信息收集 端口是什么 "端口"是英文port的意译&#xff0c;可以认为是设备与外界通讯交流的出口。端口可分为虚拟端口和物理端口&#xff0c;其中虚拟端口指计算机内部或交换机…

关于package.json中版本锁定的方法和问题解决

前置知识&#xff1a;先了解一下package.json和package-lock.json的关系和区别&#xff0c;请看这篇文章 然后我们来说一下改怎么锁定版本&#xff1f; 首先肯定是要把package.json中的 ^ 这个符号去掉&#xff0c;但是如果你只去掉package.json中的 ^那就太天真了&#xff0…

必应,百度,神马头条,搜狗专用站长seo推送工具大全

软件介绍&#xff1a; 百度开始打击滥用api问题&#xff0c;针对这个问题已经开发了拟人推送系列功能&#xff0c;放心使用。 五合一高效推送软件&#xff0c;目前支持百度&#xff0c;神马&#xff0c;必应&#xff0c;搜狗&#xff0c;头条&#xff0c;谷歌六大搜索引擎同步…

优秀简历的HR视角:怎样打造一份称心如意的简历?

简历的排版应该简洁工整&#xff0c;注重细节。需要注意对齐和标点符号的使用&#xff0c;因为在排版上的细节需要下很大功夫。除此之外&#xff0c;下面重点讲述几点简历内容需要注意的地方。 要点1&#xff1a;不相关的不要写。 尤其是与应聘岗位毫不相关的实习经历&#x…

服务提供者 Eureka + 服务消费者(Rest + Ribbon)实战

1、Ribbon背景介绍 Ribbon是Netflix发布的开源项目&#xff0c;主要功能是提供客户端的软件负载均衡算法&#xff0c;将Netflix的中间层服务连接在一起。Ribbon客户端组件提供一系列完善的配置项如连接超时&#xff0c;重试等。简单来说&#xff0c;就是在配置文件中列出Load B…

【手把手做ROS2机器人系统开发二】熟悉ROS2基本命令

【手把手做ROS2机器人系统开发二】熟悉ROS2基本命令 一、上讲回顾 在上一讲开发环境搭建中&#xff0c;我们讲解了如何搭建Ubuntu系统环境和ROS2开发运行环境。 1.Ubuntu系统安装 2.ROS2系统环境安装 二、ROS2核心命令讲解 1、daemon-各种守护进程相关的子命令 查看帮助&am…

【计算机网络】网络命令的使用

文章目录 一、实验目的二、实验工具三、实验要求四、实验过程01 ping 命令的使用应用1&#xff1a;验证本地计算机上是否正确安装了 TCP/IP 协议应用2&#xff1a;测试某个目的主机可达性应用3&#xff1a;键入 ping&#xff0c;查看 ping 的其他参数含义 02 netstat 命令的典型…

可能是最强的Python可视化神器,建议一试

数据分析离不开数据可视化&#xff0c;我们最常用的就是Pandas&#xff0c;Matplotlib&#xff0c;Pyecharts当然还有Tableau&#xff0c;看到一篇文章介绍Plotly制图后我也跃跃欲试&#xff0c;查看了相关资料开始尝试用它制图。 Plotly Plotly是一款用来做数据分析和可视化的…