混淆矩阵Confusion Matrix(resnet34 基于 CIFAR10)

news/2024/5/14 18:55:23/文章来源:https://blog.csdn.net/qq_44886601/article/details/129952744

目录

1. Confusion Matrix

2. 其他的性能指标

3. example

4. 代码实现混淆矩阵

5.  测试,计算混淆矩阵

6. show

7. 代码


1. Confusion Matrix

混淆矩阵可以将真实标签和预测标签的结果以矩阵的形式表示出来,相比于之前计算的正确率acc更加的直观。

如下,是花分类的混淆矩阵:

之前计算的acc = 预测正确的个数 / 总个数 = 对角线的和 / 矩阵的总和

 

2. 其他的性能指标

除了准确率之外,还有别的指标可能更加方便的知道每一个类别的预测情况。

在介绍下面的内容之前,需要了解一些名词

其中,T都是True预测正确的,F都是False预测错误的。P是正确的label,N是错误的label

TP和TN都是是预测正确的类别。两者说明网络都可以正常分类,TP是真实值比如是猫,预测也是猫。TN是真实值为非猫,预测的结果也是非猫

FP和FN都是预测错误的。两者说明网络都不能正常分类,FN是说,真实值是猫,预测为非猫,FP是说真实值为非猫,预测为猫

方便的记法,T就是网络正确预测,P就是正确的类别。

例如:

TP,就是网络预测是对的,标签也是对的(猫)。

FP就是网络预测错的,标签是对的类别(也就是label是猫,网络预测是非猫,因为F代表错误的)。

FN就是,预测是错误的,N代表不是真正的标签,所以预测出来的是错误的正样本

TN就是,预测是对的,N代表不是正确的类别,所有预测出来也不是正确的类别

常见的有下面几种性能指标:除了准确率,其余的都是针对特定的类别计算的

 

3. example

比如,下面的为三分类的混淆矩阵

准确率 = 预测正确的 / 样本的总数 = (TP + TN) / (TP+TN+FP+FN) = (10+15+20)/66=0.68

下面都是针对于猫的其三个指标:

精确率 = TP / (TP+FP) = 10 / (10+1+2) = 0.77

精确度也叫查准率Precision,也就是预测为正样本中,真正正样本的比率

召回率 = TP/ (TP + FN) = 10 / (10 +3+5) = 0.56

召回率是说真正正样本中,预测为正样本的比率

特异度 = TN / (TN+FP) = (15+4+20+6) / (15+4+20+6+1+2) = 0.94

4. 代码实现混淆矩阵

首先,实现一个混淆矩阵类

 

然后更新混淆矩阵的值,传入预测和真正的标签,横坐标是真实值,纵坐标是预测值

p代表矩阵的行,也就是预测,t代表矩阵的列,就是真实

 

各项指标的计算

 

接着打印混淆矩阵

 

5.  测试,计算混淆矩阵

这里用的是之前的resnet34的迁移学习模型,数据是CIFAR10数据集

首先创建混淆矩阵类,上面注释的是手动编写的类别,下面是json文件提取的

注意这里混淆矩阵类,传入的第一个参数是混淆矩阵的size,也就是分类的个数。labels是一个list列表,存放不同的类名

 

更新打印混淆矩阵

 

6. show

混淆矩阵:

 

输出控制台:

观察可以发现召回率recall,就是对应对角线的值 / 1000

不难理解,因为recall = TP / (TP+FN),而分母就是label的个数,CIFAR10的测试集有1W张图像,共有10个类别,刚好每个是1k张图像,所有recall的分母都是1k

召回率,真正正样本中预测为正样本的个数

 将混淆矩阵输出的图关闭后,会打印性能指标

 

7. 代码

混淆矩阵放在utils中,utils代码:

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'import matplotlib.pyplot as plt
import numpy as np
from prettytable import PrettyTable# 计算混淆矩阵
class ConfusionMatrix(object):def __init__(self, num_classes: int, labels: list):self.matrix = np.zeros((num_classes, num_classes))  # 初始化混淆矩阵self.num_classes = num_classesself.labels = labelsdef update(self, preds, labels):    # 计算混淆矩阵的值for p, t in zip(preds, labels):self.matrix[p, t] += 1def summary(self):          # 计算各项指标# calculate accuracysum_TP = 0for i in range(self.num_classes):sum_TP += self.matrix[i, i]        # 对角线的和acc = sum_TP / np.sum(self.matrix)     # 混淆矩阵的和print("the model accuracy is ", acc)# precision, recall, specificitytable = PrettyTable()table.field_names = ["", "Precision", "Recall", "Specificity"]  # 表格的tittlefor i in range(self.num_classes):TP = self.matrix[i, i]                      # label为真,预测为真FP = np.sum(self.matrix[i, :]) - TP         # label为假,预测为真FN = np.sum(self.matrix[:, i]) - TP         # label为假,预测为真TN = np.sum(self.matrix) - TP - FP - FN     # label为假,预测为假Precision = round(TP / (TP + FP), 3) if TP + FP != 0 else 0.Recall = round(TP / (TP + FN), 3) if TP + FN != 0 else 0.Specificity = round(TN / (TN + FP), 3) if TN + FP != 0 else 0.table.add_row([self.labels[i], Precision, Recall, Specificity])print(table)def plot(self):matrix = self.matrixprint(matrix)plt.imshow(matrix, cmap=plt.cm.Blues)plt.xticks(range(self.num_classes), self.labels, rotation=45)       # 设置x轴坐标labelplt.yticks(range(self.num_classes), self.labels)        # 设置y轴坐标labelplt.colorbar()      # 显示 colorbarplt.xlabel('True Labels')plt.ylabel('Predicted Labels')plt.title('Confusion matrix')thresh = matrix.max() / 2   # 在图中标注数量/概率信息for x in range(self.num_classes):for y in range(self.num_classes):# 注意这里的matrix[y, x]不是matrix[x, y]info = int(matrix[y, x])plt.text(x, y, info,verticalalignment='center',horizontalalignment='center',color="white" if info > thresh else "black")plt.tight_layout()plt.show()

网络model:这里是resnet的代码

import torch
import torch.nn as nn# residual block
class BasicBlock(nn.Module):expansion = 1def __init__(self,in_channel,out_channel,stride=1,downsample=None):super(BasicBlock,self).__init__()self.conv1 = nn.Conv2d(in_channel,out_channel,kernel_size=3,stride=stride,padding=1,bias=False) # 第一层的话,可能会缩小size,这时候 stride = 2self.bn1 = nn.BatchNorm2d(out_channel)self.relu = nn.ReLU()self.conv2 = nn.Conv2d(out_channel,out_channel,kernel_size=3,stride=1,padding=1,bias=False)self.bn2 = nn.BatchNorm2d(out_channel)self.downsample = downsampledef forward(self,x):identity = xif self.downsample is not None:     # 有下采样,意味着需要1*1进行降维,同时channel翻倍,residual block虚线部分identity = self.downsample(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out += identityout = self.relu(out)return out# bottleneck
class Bottleneck(nn.Module):expansion = 4       # 卷积核的变化def __init__(self,in_channel,out_channel,stride=1,downsample=None):super(Bottleneck,self).__init__()# 1*1 降维度 --------> padding默认为 0,size不变,channel被降低self.conv1 = nn.Conv2d(in_channel,out_channel,kernel_size=1,stride=1,bias=False)self.bn1 = nn.BatchNorm2d(out_channel)# 3*3 卷积self.conv2 = nn.Conv2d(out_channel,out_channel,kernel_size=3,stride=stride,bias=False)self.bn2 = nn.BatchNorm2d(out_channel)# 1*1 还原维度 --------> padding默认为 0,size不变,channel被还原self.conv3 = nn.Conv2d(out_channel,out_channel*self.expansion,kernel_size=1,stride=1,bias=False)self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)# otherself.relu = nn.ReLU(inplace=True)self.downsample =downsampledef forward(self,x):identity = xif self.downsample is not None:identity = self.downsample(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += identityout = self.relu(out)return out# resnet
class ResNet(nn.Module):def __init__(self,block,block_num,num_classes=1000,include_top=True):super(ResNet, self).__init__()self.include_top = include_topself.in_channel = 64        # max pool 之后的 depth# 网络最开始的部分,输入是RGB图像,经过卷积,图像size减半,通道变为64self.conv1 = nn.Conv2d(3,self.in_channel,kernel_size=7,stride=2,padding=3,bias=False)self.bn1 = nn.BatchNorm2d(self.in_channel)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1)   # size减半,padding = 1self.layer1 = self.__make_layer(block,64,block_num[0])                # conv2_xself.layer2 = self.__make_layer(block,128,block_num[1],stride=2)      # conv3_xself.layer3 = self.__make_layer(block,256,block_num[2],stride=2)      # conv4_Xself.layer4 = self.__make_layer(block,512,block_num[3],stride=2)      # conv5_xif self.include_top:    # 分类部分self.avgpool = nn.AdaptiveAvgPool2d((1,1))      # out_size = 1*1self.fc = nn.Linear(512*block.expansion,num_classes)def __make_layer(self,block,channel,block_num,stride=1):downsample =Noneif stride != 1 or self.in_channel != channel*block.expansion:     # shortcut 部分,1*1 进行升维downsample=nn.Sequential(nn.Conv2d(self.in_channel,channel*block.expansion,kernel_size=1,stride=stride,bias=False),nn.BatchNorm2d(channel*block.expansion))layers =[]layers.append(block(self.in_channel, channel, downsample =downsample, stride=stride))self.in_channel = channel * block.expansionfor _ in range(1,block_num):    # residual 实线的部分layers.append(block(self.in_channel,channel))return nn.Sequential(*layers)def forward(self,x):# resnet 前面的卷积部分x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)# residual 特征提取层x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)# 分类if self.include_top:x = self.avgpool(x)x = torch.flatten(x,start_dim=1)x = self.fc(x)return x# 定义网络
def resnet34(num_classes=1000,include_top=True):return ResNet(BasicBlock,[3,4,6,3],num_classes=num_classes,include_top=include_top)def resnet101(num_classes=1000,include_top=True):return ResNet(Bottleneck,[3,4,23,3],num_classes=num_classes,include_top=include_top)

主函数main:

import torch
from torchvision import transforms, datasets
from tqdm import tqdm
from model import resnet34
from utils import ConfusionMatrix
import jsonif __name__ == '__main__':device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(device)data_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# 加载数据validate_dataset = datasets.CIFAR10(root='./data',train=False,transform=data_transform)validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=16, shuffle=True)# 加载网络net = resnet34(num_classes=10)model_weight_path = "./resnet.pth"net.load_state_dict(torch.load(model_weight_path, map_location=device))net.to(device)# 类别# classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# labels = [label for label in classes]# confusion = ConfusionMatrix(num_classes=10, labels=labels)# 类别json_label_path = './class_indices.json'json_file = open(json_label_path, 'r')class_indict = json.load(json_file)labels = [label for _, label in class_indict.items()]confusion = ConfusionMatrix(num_classes=10, labels=labels)net.eval()with torch.no_grad():for val_data in tqdm(validate_loader):val_images, val_labels = val_dataoutputs = net(val_images.to(device))outputs = torch.softmax(outputs, dim=1)outputs = torch.argmax(outputs, dim=1)confusion.update(outputs.to("cpu").numpy(), val_labels.to("cpu").numpy())   # 更新混淆矩阵的值confusion.plot()         # 绘制混淆矩阵confusion.summary()      # 计算指标

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_283201.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

jenkins打包发布前端项目

1.配置前端nodejs打包环境 1.1安装nodejs插件 1.2配置jenkins nodejs环境 2.下载git插件(使用此插件配置通过gitlab标签拉取项目) 3.创建一个自由风格的发布项目 4.配置项目构建流程 4.1添加钉钉告警 4.2配置参数化构建 4.3配置源码管理为git拉取项目 4.4配置构建环境 4.5配置…

transform属性

CSS transform属性允许对某一个元素进行某些形变, 包括旋转,缩放,倾斜或平移等。 注意事项,并非所有的盒子都可以进行transform的转换,transform对于行内级非替换元素是无效的,比如对span、a元素等 常见的函数transform function有&#xf…

算法笔记:匈牙利算法

1 二部图(二分图) 二分图(Bipartite graph)是一类特殊的图,它可以被划分为两个部分,每个部分内的点互不相连。 匈牙利算法主要用来解决两个问题:求二分图的最大匹配数和最小点覆盖数。 2 最大匹…

[C++笔记]初步了解STL,string,迭代器

STL简介 STL(standard template libaray-标准模板库): 是C标准库的重要组成部分,不仅是一个可复用的组件库,而且是一个包含数据结构与算法的软件框架。 是一套功能强大的 C 模板类,提供了通用的模板类和函数,这些模板…

STM32开发(十二)STM32F103 功能应用 —— NTC 温度采集

文章目录一、基础知识点二、开发环境三、STM32CubeMX相关配置四、Vscode代码讲解(过程中相关问题点在第五点中做解释说明)五、知识点补充六、结果演示一、基础知识点 了解STM32 片内资源ADC。本实验是基于STM32F103开发 实现 NTC温度采集。 NTC温度采集…

3年外包离奇被裁,痛定思痛24K上岸字节跳动....

三年前,我刚刚从大学毕业,来到了一家外包公司工作。这份工作对于我来说是个好的起点,因为它让我接触到了真正的企业项目和实际的开发流程。但是,随着时间的流逝,我发现这份工作并没有给我带来足够的成长和挑战。 三年…

文心一言平替版ChatGLM本地部署(无需账号)!

今天用了一个超级好用的Chatgpt模型——ChatGLM,可以很方便的本地部署,而且效果嘎嘎好,经测试,效果基本可以平替内测版的文心一言。 目录 一、什么是ChatGLM? 二、本地部署 2.1 模型下载 2.2 模型部署 2.3 模型运…

数据结构系列13——排序(归并排序)

目录 1. 递归实现归并排序 1.1 思路 1.2 代码实现 1.3 时间复杂度和空间复杂度 2. 非递归实现归并排序 2.1 思路 2.2 代码实现 2.3 时间复杂度和空间复杂度 1. 递归实现归并排序 1.1 思路 将已有序的子序列合并,得到完全有序的序列;即先使每个子序列…

Excel 文件 - 比如 .csv文件编码问题 转为 UTF-8 编码 方法,解决中文乱码问题 - 解决科学计数显示问题

解决 excel 文件编码问题 1、方法一: 有点点击 excel 文件,然后选择打开方式,选择使用 Excel 2016 软件打开 选择 工具 ——> Web 选项 这里选择 UTF-8 编码 2、方法二 wps 导出为 .csv 文件,然后修改 csv 文件的后缀…

Linux修改密码报错Authentication token manipulation error的终极解决方法

文章目录报错说明解决思路流程排查特殊权限有没有上锁查看根目录和关闭selinux/etc/pam.d/passwd文件/etc/pam.d/system-auth文件终极办法,手动定义密码passwd: Have exhausted maximum number of retries for servic、ssh用普通用户登录输入密码正确但是登录时却提…

元宇宙是什么,元宇宙虚拟会议改变会议体验

随着人类社会的发展和科技的进步,传统的会议方式已经无法满足现代社会的需求。为了更好地满足社会的需求,VR全景元宇宙虚拟会议是近年来快速崛起的新兴领域,其融合了虚拟现实技术和通信技术,为人们提供了一种全新的交流、协作和学…

【探花交友】day02—完善个人信息

目录 1、完善用户信息 1.1、阿里云OSS 1.2、百度人脸识别 1.3、保存用户信息 1.4、上传用户头像 2、用户信息管理 2.1、查询用户资料 2.2、更新用户资料 3、统一token处理 3.1、代码存在的问题 3.2、解决方案 3.3、代码实现 4、统一异常处理 4.1、解决方案 4.2、…

本地部署Stable Diffusion教程,亲测可以安装成功

系列文章目录 之后补充 文章目录系列文章目录前言一、Stable Diffusion是什么?二、安装前的准备1.检查自己的电脑配置是否符合要求2.下载安装Git3.下载安装Python三、下载stable-diffusion-webui仓库四、运行webui-user.bat总结前言 近期,智能AI绘画以其…

AndroidStudio第一步安装和配置环境

AndroidStudio第一步安装和配置环境 文章目录AndroidStudio第一步安装和配置环境1.环境变量2.PATH编辑3.cmd测试版本4.android studio设置4.1 保留压缩包4.2解压缩包4.3 设置本地4.4 Dependencies5.生成apk5.15.2 需要添加才能被手机安装6.Android studio安装包和gradle下载地址…

数据仓库工具箱-第6章-订单管理

文章目录重要专业名词含义一、订单管理总线矩阵二、订单事务2.1 事实表规范化2.2 日期维度(维度角色扮演)2.2.1 角色扮演与总线矩阵2.3 产品维度2.3.1 产品维度共同特征2.3.2 维度的层次结构2.3.3 规范化与反规范化2.4 客户维度2.4.1 单一维度表与多维度…

Maven核心概念

一、Maven基础知识 Apache Maven是一个项目管理和构建工具,它基于项目对象模型(POM)的概念,通过一小段描述信息来管理项目的构建、报告和文档。 1、Maven模型 2、仓库分类 本地仓库:自己计算机上的一个目录中央仓库&a…

AR”将会成为“更加日常化的移动设备应用的一部分”吗

目录 1:AR是什么 2:AR给人类带来的贡献 3:人们在生活中可以遇到许多 AR 技术应用 4:AR 技术的未来发展的趋势: 大学主攻VR,从大一就对VR的知识,设备,已经所涉及的知识伴随我的整…

政务服务一网通办建设方案(ppt)

政务审批 – 设计思路 “互联网政务服务”平台主要由互联网政务服务门户、政务服务管理平台、业务办理系统、政务服务数据共享平台及硬件支撑平台五部分构成。平台建设以硬件支撑平台为基础,其他各平台之间的业务流、信息流通过数据共享平台实现数据互联互通。政务审…

冒泡排序(Java)

文章汇总归纳整理于:算法竞赛学习之路[Java版] 冒泡排序是交换排序中的一种所谓交换,是指根据序列中两个元素关键字的比较结果来对换这两个记录在序列中的位置。 默认排序后的数据,从小到大进行排列 冒泡排序的基本思想 从后往前&#xff08…

4年经验来面试20K的测试岗,连基础都不会,还不如招应届生

公司前段缺人,也面了不少测试,结果竟然没有一个合适的。一开始瞄准的就是中级的水准,也没指望来大牛,提供的薪资在10-20k,面试的人很多,但平均水平很让人失望。看简历很多都是4年工作经验,但面试…