语义分割混淆矩阵、 mIoU、mPA计算

news/2024/4/28 20:59:13/文章来源:https://blog.csdn.net/weixin_61235989/article/details/131724905

一、操作

需要会调试代码的人自己改,小白直接运行会出错

这是我从自己的大文件里摘取的一部分代码,可以运行,只是要改的文件地址path比较多,遇到双引号“”的地址注意一下,不然地址不对容易出错

 把 calculate.py和 utiles_metrics.py放在同一文件夹下,然后运行 calculate.py。

二、理解

test_mIou,test_mPA,test_miou,test_mpa=compute_mIoU(gt_dir, pred_dir, image_ids, num_classes, name_classes,weight_name)  # 执行计算mIoU的函数

gt_dir 真实标签文件夹

pred_dir 预测结果文件夹

主要是这两个变量设置,后面的可以选择性修改

image_ids 文件名称 dirList(pred_dir,path_list) saveList(path_list) 这两个函数得到

num_classes 类别数

name_classes 类别名称

weight_name 权重名称

hist为混淆矩阵,mIoU为交并比

三、代码 

 calculate.py

# -*- coding: utf-8 -*-
import torch
import osfrom time import time
# from PIL import Image
from utils_metrics import compute_mIoU
def saveList(pathName):for file_name in pathName:#f=open("C:/Users/Administrator/Desktop/DeepGlobe-Road-Extraction-link34-py3/dataset/real/gt.txt", "x")with open("./dataset/gt.txt", "a") as f:f.write(file_name.split(".")[0] + "\n")f.closedef dirList(gt_dir,path_list):for i in range(0, len(path_list)):path = os.path.join(gt_dir, path_list[i])if os.path.isdir(path):saveList(os.listdir(path))data_path  = './dataset/'f=open("./dataset/gt.txt", 'w')
gt_dir      = os.path.join(data_path, "real/")
pred_dir    = "./submits/log01_Dink101_five_100/test_iou/iou_60u/"
path_list = os.listdir(pred_dir)
path_list.sort()
dirList(pred_dir,path_list)
saveList(path_list)
num_classes=2
name_classes    = ["nontarget","target"]
weight_name='log01_Dink101_five_100'
image_ids   = open(os.path.join(data_path, "gt.txt"),'r').read().splitlines() test_mIou,test_mPA,test_miou,test_mpa=compute_mIoU(gt_dir, pred_dir, image_ids, num_classes, name_classes,weight_name)  # 执行计算mIoU的函数
print('  test_mIoU:  '+str(test_miou))

 utiles_metrics.py

from os.path import joinimport numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import os
import cv2# from matplotlib import pyplot as plt
import shutil
import numpy as np
# from matplotlib.pyplot import MultipleLocatordef f_score(inputs, target, beta=1, smooth = 1e-5, threhold = 0.5):n, c, h, w = inputs.size()nt, ht, wt, ct = target.size()if h != ht and w != wt:inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1)temp_target = target.view(n, -1, ct)#--------------------------------------------##   计算dice系数#--------------------------------------------#temp_inputs = torch.gt(temp_inputs, threhold).float()tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])fp = torch.sum(temp_inputs                       , axis=[0,1]) - tpfn = torch.sum(temp_target[...,:-1]              , axis=[0,1]) - tpscore = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)score = torch.mean(score)return score# 设标签宽W,长H
def fast_hist(a, b, n):#--------------------------------------------------------------------------------##   a是转化成一维数组的标签,形状(H×W,);b是转化成一维数组的预测结果,形状(H×W,)#--------------------------------------------------------------------------------#k = (a >= 0) & (a < n)#--------------------------------------------------------------------------------##   np.bincount计算了从0到n**2-1这n**2个数中每个数出现的次数,返回值形状(n, n)#   返回中,写对角线上的为分类正确的像素点#--------------------------------------------------------------------------------#return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)  def per_class_iu(hist):return np.diag(hist) / np.maximum((hist.sum(1) + hist.sum(0) - np.diag(hist)), 1) def per_class_PA(hist):return np.diag(hist) / np.maximum(hist.sum(1), 1) def compute_mIoU(gt_dir, pred_dir, png_name_list, num_classes, name_classes,weight_name):  # print('Num classes', num_classes)  #-----------------------------------------##   创建一个全是0的矩阵,是一个混淆矩阵#-----------------------------------------#hist = np.zeros((num_classes, num_classes))#------------------------------------------------##   获得验证集标签路径列表,方便直接读取#   获得验证集图像分割结果路径列表,方便直接读取#------------------------------------------------#gt_imgs     = [join(gt_dir, x + ".png") for x in png_name_list]  pred_imgs   = [join(pred_dir, x + ".png") for x in png_name_list]  # building_iou=[]# background_iou=[]m_iou=[]# building_pa=[]# background_pa=[]m_pa=[]#------------------------------------------------##   读取每一个(图片-标签)对#------------------------------------------------#for ind in range(len(gt_imgs)): #------------------------------------------------##   读取一张图像分割结果,转化成numpy数组#------------------------------------------------#pred = np.array(Image.open(pred_imgs[ind]))#------------------------------------------------##   读取一张对应的标签,转化成numpy数组#------------------------------------------------#label = np.array(Image.open(gt_imgs[ind]))  # 如果图像分割结果与标签的大小不一样,这张图片就不计算if len(label.flatten()) != len(pred.flatten()):  print('Skipping: len(gt) = {:d}, len(pred) = {:d}, {:s}, {:s}'.format(len(label.flatten()), len(pred.flatten()), gt_imgs[ind],pred_imgs[ind]))continue#------------------------------------------------##   对一张图片计算21×21的hist矩阵,并累加#------------------------------------------------#a=label.flatten()a//=254b=pred.flatten()b//=254hist += fast_hist(a, b,num_classes)  # # 每计算10张就输出一下目前已计算的图片中所有类别平均的mIoU值# mIoUs   = per_class_iu(hist)# mPA     = per_class_PA(hist)# m_iou.append(100 * np.nanmean(mIoUs[1]))# m_pa.append(100 * np.nanmean(mPA[1]))# # if ind > 0 and ind % 10 == 0:  # #     print('{:d} / {:d}: mIou-{:0.2f}; mPA-{:0.2f}'.format(ind, len(gt_imgs),# #                                             100 * np.nanmean(mIoUs[1]),# #                                             100 * np.nanmean(mPA[1])))mIoUs   = per_class_iu(hist)mPA     = per_class_PA(hist)print(mIoUs)# plt.figure()# x=np.arange(len(m_iou))# plt.plot(x,m_iou)# plt.plot(x,m_pa)# plt.grid(True)# y_major_locator=MultipleLocator(10)#把y轴的刻度间隔设置为10,并存在变量里# ax = plt.gca()# ax.yaxis.set_major_locator(y_major_locator)# ax.set_ylim(0,100)# plt.xlabel('Order')# plt.ylabel('mIOU & mPA')# plt.legend(['mIOU','mPA'],loc="upper right")# targ=os.path.join(pred_dir,os.path.pardir)# plt.savefig(os.path.join(targ, weight_name[:-3]+"_sin_miou.png"))return m_iou,m_pa,str(round(mIoUs[1] * 100, 2)),str(round(mPA[1] * 100, 2))

调试

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_331184.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

干货 | 一个漏洞利用工具仓库

0x00 Awesome-Exploit 一个漏洞证明/漏洞利用工具仓库 不定期更新 部分漏洞对应POC/EXP详情可参见以下仓库&#xff1a; https://github.com/Threekiii/Awesome-POC https://github.com/Threekiii/Vulhub-Reproduce 0x01 项目导航 ActiveMQ CVE-2015-5254 Apisix CVE-2…

Unreal Engine 与 Blender - 比较指南

虚幻引擎和 Blender 是游戏开发人员中最常用的两种软件&#xff0c;它们在游戏开发工作流程中都有自己的用途和地位。作为进入视频游戏行业的初学者&#xff0c;可能很难从数十种软件中筛选找到最适合您需求的一款。希望本指南能够缓解这一困难并帮助您决定选择哪个软件。 虚幻…

前端面试题-js(三)

31 介绍js有哪些内置对象 Object 是 JavaScript 中所有对象的⽗对象数据封装类对象&#xff1a; Object 、 Array 、 Boolean 、 Number 和 String其他对象&#xff1a; Function 、 Arguments 、 Math 、 Date 、 RegExp 、 Error 32 说⼏条写JavaScript的基本规范 不要在同…

python+pytest接口自动化之测试函数、测试类/测试方法的封装

目录 前言 测试用例封装的一般规则 测试函数的封装 测试类/方法的封装 示例代码 总结 前言 在pythonpytest 接口自动化系列中&#xff0c;我们之前的文章基本都没有将代码进行封装&#xff0c;但实际编写自动化测试脚本中&#xff0c;我们都需要将测试代码进行封装&#…

【http-server】http-server的安装、前端使用http-server启动本地dist文件服务:

文章目录 一、http-server 简介:二、安装node.js:[https://nodejs.org/en](https://nodejs.org/en)三、安装http-server:[https://www.npmjs.com/package/http-server](https://www.npmjs.com/package/http-server)四、开启服务&#xff1a;五、http-server参数&#xff1a;【1…

记录stm32c8t6使用TIM4_CH1、TIM4_CH2输出PWM波控制编码电机出现的问题

由于之前是使用PB9、PB7引脚即TIM4_ch3\TIM4_ch4&#xff0c;由于项目更改为c8t6的PB、PB7引脚&#xff08;TIM4_ch3\TIM4_ch4&#xff09; 改为配置后发现只有一边的轮子可以转到&#xff0c;明明配置没什么问题&#xff0c;编译也没有报错&#xff0c;最后将pwm的调制模式更改…

抖音seo源码搭建---PHP,vue jquery layui

抖音seo源码&#xff0c;抖音seo矩阵系统源码技术搭建&#xff0c;抖音seo源码技术开发思路梳理搭建 开发思路&#xff1a;抖音seo系统&#xff0c;抖音seo矩阵系统底层框架上支持了ai视频混剪&#xff0c;视频产出&#xff0c;视频AI制作&#xff0c;多账号多平台矩阵&#x…

Java正则表达式MatchResult的接口、Pattern类、Matcher类

Java正则表达式MatchResult的接口 java.util.regex.MatchResult接口表示匹配操作的结果。 此接口包含用于确定与正则表达式匹配的结果的查询方法。可以看到匹配边界&#xff0c;组和组边界&#xff0c;但不能通过MatchResult进行修改。 接口声明 以下是java.util.regex.Matc…

3D开发工具HOOPS 2023 SP2更新:增加了SOLIDWORKS贴花支持!

HOOPS SDK是全球领先开发商TechSoft 3D旗下的原生产品&#xff0c;专注于Web端、桌面端、移动端3D工程应用程序的开发。长期以来&#xff0c;HOOPS通过卓越的3D技术&#xff0c;帮助全球600多家知名客户推动3D软件创新&#xff0c;这些客户包括SolidWorks、SIEMENS、Oracle、Ar…

Transaction事务使用了解

1.功能概述 ​ 在wiki的解释中&#xff0c;事务是一组单元化的操作&#xff0c;这组操作可以保证要么全部成功&#xff0c;要么全部失败&#xff08;只要有一个失败的操作&#xff0c;就会把其他已经成功的操作回滚&#xff09;。 ​ 这样的解释还是不够直观&#xff0c;看下…

本地appserv外挂网址如何让外网访问?快解析端口映射

一、appserv是什么&#xff1f; AppServ 是 PHP 网页架站工具组合包&#xff0c;作者将一些网络上免费的架站资源重新包装成单一的安装程序&#xff0c;以方便初学者快速完成架站&#xff0c;AppServ 所包含的软件有&#xff1a;Apache[、Apache Monitor、PHP、MySQL、phpMyAdm…

SOEM_1(笔记,从别的博客文章学的笔记)

目录介绍&#xff1a; doc&#xff1a;帮助文档、 osal&#xff1a;主要是用于符合OSADL和实时进程创建。也就是说&#xff1a;发送EtherCAT数据包不能抖动太大&#xff0c;如果直接使用linux提供的原生线程&#xff0c;可能实时性无法满足。需要对Linux内核打上实时补丁&…

ELK-日志服务【kafka-配置使用】

kafka-01 10.0.0.21 kafka-02 10.0.0.22 kafka-03 10.0.0.23 【1】安装zk集群、配置 [rootes-01 ~]# yum -y install java maven [rootes-01 ~]# tar xf apache-zookeeper-3.5.9-bin.tar.gz -C /opt/[rootes-01 ~]# cd /opt/apache-zookeeper-3.5.9-bin/conf/ [rootes-…

复习第五课 C语言-初识数组

目录 【1】初识数组 【2】一维数组 【3】清零函数 【4】字符数组 【5】计算字符串实际长度 练习&#xff1a; 【1】初识数组 1. 概念&#xff1a;具有一定顺序的若干变量的集合 2. 定义格式&#xff1a; 数组名 &#xff1a;代表数组的首地址&#xff0c;地址常量&…

字符函数和内存函数(二)

目录 一、strtok函数 二、strerror函数 三、memcpy函数 3.1memcpy函数的认识 3.2memcpy函数的模拟实现 四、memmove函数 4.1memmove函数的认识 4.2memmove函数的模拟实现 五、memcmp函数 5.1memcmp函数的认识 5.2memcmp函数的模拟实现 六、memset函数 七、字符分类函…

PyCharm 自动添加作者信息、创建时间等信息

PyCharm 自动添加作者信息、创建时间等信息‘ 第一步 找到settings 第二步&#xff0c;找到下图所示位置输入下面代码&#xff0c;作者改成你自己的缩写&#xff0c;你也可以添加其他的 Project &#xff1a;${PROJECT_NAME} File &#xff1a;${NAME}.py IDE &…

【技能实训】DMS数据挖掘项目-Day09

文章目录 任务9【任务9.1.1】升级DataBase类为可序列化的类&#xff0c;以便在文件保存或网络中传递【任务9.1.2】升级LogRec类为可序列化的类&#xff0c;以便在文件保存或网络中传递【任务9.1.3】升级MatchedLogRec类为可序列化的类&#xff0c;以便在文件保存或网络中传递【…

StringBuffer类 StringBuilder 类

StringBuffer类 介绍 StringBuffer是一个容器&#xff0c;代表可变的字符序列&#xff0c;可以对字符串内容进行增删。 StringBuffer是可变长度的。 实现了序列化接口&#xff0c;可实现串行化&#xff08;可以将内容保存至文件或者网络传输&#xff09;&#xff1a; Serial…

关于Java的网络编程

网络的一些了解 网络通信协议 链路层&#xff1a;链路层是用于定义物理传输通道&#xff0c;通常是对某些网络连接设备的驱动协议&#xff0c;例如针对光纤、网线提供的驱动。网络层&#xff1a;网络层是整个TCP/IP协议的核心&#xff0c;它主要用于将传输的数据进行分组&…

华为申请注册盘古大模型商标;京东推出言犀大模型,率先布局产业应用

7月14日科技新闻早知道&#xff0c;一分钟速览。 1.华为申请注册盘古大模型商标&#xff1a; 据天眼查 App 显示&#xff0c;7 月 7 日&#xff0c;华为技术有限公司申请注册“华为云盘古”、“Huawei Cloud Pangu Models”文字及图形商标&#xff0c;国际分类为网站服务、社…