unet测试评估metric脚本

news/2024/3/29 22:37:56/文章来源:https://blog.csdn.net/isyoungboy/article/details/130008153

全部复制的paddleseg的代码转torch

import argparse
import logging
import osimport numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transformsfrom utils.data_loading import BasicDataset
from unet import UNet
from utils.utils import plot_img_and_mask
from torch.utils.data import DataLoader, random_split
from utils.data_loading import BasicDataset, CarvanaDataset
from tqdm import tqdm
import torch.nn.functional as F# 使用python写一个评估使用pytorch训练的unet模型的好坏,模型输出nchw格式的数据,真实标签数据为nhw格式,请计算模型的accuracy, calss precision ,class recall,kappa指标EPSILON = 1e-32def calculate_area(pred, label, num_classes, ignore_index=255):"""Calculate intersect, prediction and label areaArgs:pred (Tensor): The prediction by model.label (Tensor): The ground truth of image.num_classes (int): The unique number of target classes.ignore_index (int): Specifies a target value that is ignored. Default: 255.Returns:Tensor: The intersection area of prediction and the ground on all class.Tensor: The prediction area on all class.Tensor: The ground truth area on all class"""if len(pred.shape) == 4:pred = torch.squeeze(pred, axis=1)if len(label.shape) == 4:label = torch.squeeze(label, axis=1)if not pred.shape == label.shape:raise ValueError('Shape of `pred` and `label should be equal, ''but there are {} and {}.'.format(pred.shape,label.shape))pred_area = []label_area = []intersect_area = []mask = label != ignore_indexfor i in range(num_classes):pred_i = torch.logical_and(pred == i, mask)label_i = label == iintersect_i = torch.logical_and(pred_i, label_i)pred_area.append(torch.sum(pred_i))  label_area.append(torch.sum(label_i))  intersect_area.append(torch.sum(intersect_i))  pred_area = torch.stack(pred_area)  label_area = torch.stack(label_area)  intersect_area = torch.stack(intersect_area)  return intersect_area, pred_area, label_areadef get_args():parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,help='Percent of the data that is used as validation (0-100)')parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')parser.add_argument('--root', '-r', type=str, default=False, help='root dir')parser.add_argument('--num', '-n', type=int, default=False, help='num of classes')return parser.parse_args()dir_img_path = 'imgs'
dir_mask_path = 'masks'import metricsdef train_net(net,device,epochs: int = 5,batch_size: int = 1,learning_rate: float = 0.001,val_percent: float = 0.1,save_checkpoint: bool = True,img_scale: float = 0.5,amp: bool = False,root_dir: str = '/data/yangbo/unet/datas/data1'):train_dir_img=os.path.join(root_dir,'train',dir_img_path)train_dir_mask=os.path.join(root_dir,'train',dir_mask_path)val_dir_img=os.path.join(root_dir,'val',dir_img_path)val_dir_mask=os.path.join(root_dir,'val',dir_mask_path)# 1. Create datasettry:train_dataset = CarvanaDataset(train_dir_img, train_dir_mask, img_scale)val_dataset = CarvanaDataset(val_dir_img, val_dir_mask, img_scale)except (AssertionError, RuntimeError):train_dataset = BasicDataset(train_dir_img, train_dir_mask, img_scale)val_dataset = BasicDataset(val_dir_img, val_dir_mask, img_scale)n_val = len(val_dataset)n_train = len(train_dataset)# 3. Create data loadersloader_args = dict(batch_size=batch_size, num_workers=4, pin_memory=True)train_loader = DataLoader(train_dataset, shuffle=True, **loader_args)val_loader = DataLoader(val_dataset, shuffle=False, drop_last=True, **loader_args)# (Initialize logging)logging.info(f'''Starting training:Epochs:          {epochs}Batch size:      {batch_size}Learning rate:   {learning_rate}Training size:   {n_train}Validation size: {n_val}Checkpoints:     {save_checkpoint}Device:          {device.type}Images scaling:  {img_scale}Mixed Precision: {amp}''')# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP#optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)#scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)  # goal: maximize Dice score# 5. Begin trainingintersect_area_all=torch.zeros([1])pred_area_all=torch.zeros([1])label_area_all=torch.zeros([1])for idx,batch in tqdm(enumerate(val_loader)):images = batch['image']true_masks = batch['mask']assert images.shape[1] == net.n_channels, \f'Network has been defined with {net.n_channels} input channels, ' \f'but loaded images have {images.shape[1]} channels. Please check that ' \'the images are loaded correctly.'images = images.to(device=device, dtype=torch.float32)true_masks = true_masks.to(device=device, dtype=torch.long)with torch.no_grad():masks_pred = net(images)masks_pred=torch.argmax(masks_pred,axis=1,keepdim=True)intersect_area, pred_area, label_area=calculate_area(masks_pred,true_masks,3)intersect_area_all = intersect_area_all + intersect_areapred_area_all = pred_area_all + pred_arealabel_area_all = label_area_all + label_areametrics_input = (intersect_area_all, pred_area_all, label_area_all)class_iou, miou = metrics.mean_iou(*metrics_input)acc, class_precision, class_recall = metrics.class_measurement(*metrics_input)kappa = metrics.kappa(*metrics_input)class_dice, mdice = metrics.dice(*metrics_input)infor="[EVAL] #Images: {} mIoU: {:.4f} Acc: {:.4f} Kappa: {:.4f} Dice: {:.4f}".format(len(val_loader), miou, acc, kappa, mdice)print(infor)print("[EVAL] Class IoU: " + str(np.round(class_iou, 4)))print("[EVAL] Class Precision: " + str(np.round(class_precision, 4)))print("[EVAL] Class Recall: " + str(np.round(class_recall, 4)))if __name__ == '__main__':args = get_args()logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')logging.info(f'Using device {device}')# Change here to adapt to your data# n_channels=3 for RGB images# n_classes is the number of probabilities you want to get per pixel# 修改numclassnet = UNet(n_channels=3, n_classes=args.num, bilinear=True)net.eval()logging.info(f'Network:\n'f'\t{net.n_channels} input channels\n'f'\t{net.n_classes} output channels (classes)\n'f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')if args.load:net.load_state_dict(torch.load(args.load, map_location=device))logging.info(f'Model loaded from {args.load}')net.to(device=device)try:train_net(net=net,epochs=0,batch_size=args.batch_size,learning_rate=0,device=device,img_scale=args.scale,val_percent=args.val / 100,amp=args.amp,root_dir=args.root)except KeyboardInterrupt:torch.save(net.state_dict(), 'INTERRUPTED.pth')logging.info('Saved interrupt')

metris.py

import numpy as np
import torch
import sklearn.metrics as skmetricsdef mean_iou(intersect_area, pred_area, label_area):"""Calculate iou.Args:intersect_area (Tensor): The intersection area of prediction and ground truth on all classes.pred_area (Tensor): The prediction area on all classes.label_area (Tensor): The ground truth area on all classes.Returns:np.ndarray: iou on all classes.float: mean iou of all classes."""intersect_area = intersect_area.numpy()pred_area = pred_area.numpy()label_area = label_area.numpy()union = pred_area + label_area - intersect_areaclass_iou = []for i in range(len(intersect_area)):if union[i] == 0:iou = 0else:iou = intersect_area[i] / union[i]class_iou.append(iou)miou = np.mean(class_iou)return np.array(class_iou), mioudef class_measurement(intersect_area, pred_area, label_area):"""Calculate accuracy, calss precision and class recall.Args:intersect_area (Tensor): The intersection area of prediction and ground truth on all classes.pred_area (Tensor): The prediction area on all classes.label_area (Tensor): The ground truth area on all classes.Returns:float: The mean accuracy.np.ndarray: The precision of all classes.np.ndarray: The recall of all classes."""intersect_area = intersect_area.numpy()pred_area = pred_area.numpy()label_area = label_area.numpy()mean_acc = np.sum(intersect_area) / np.sum(pred_area)class_precision = []class_recall = []for i in range(len(intersect_area)):precision = 0 if pred_area[i] == 0 \else intersect_area[i] / pred_area[i]recall = 0 if label_area[i] == 0 \else intersect_area[i] / label_area[i]class_precision.append(precision)class_recall.append(recall)return mean_acc, np.array(class_precision), np.array(class_recall)def kappa(intersect_area, pred_area, label_area):"""Calculate kappa coefficientArgs:intersect_area (Tensor): The intersection area of prediction and ground truth on all classes..pred_area (Tensor): The prediction area on all classes.label_area (Tensor): The ground truth area on all classes.Returns:float: kappa coefficient."""intersect_area = intersect_area.numpy().astype(np.float64)pred_area = pred_area.numpy().astype(np.float64)label_area = label_area.numpy().astype(np.float64)total_area = np.sum(label_area)po = np.sum(intersect_area) / total_areape = np.sum(pred_area * label_area) / (total_area * total_area)kappa = (po - pe) / (1 - pe)return kappadef dice(intersect_area, pred_area, label_area):"""Calculate DICE.Args:intersect_area (Tensor): The intersection area of prediction and ground truth on all classes.pred_area (Tensor): The prediction area on all classes.label_area (Tensor): The ground truth area on all classes.Returns:np.ndarray: DICE on all classes.float: mean DICE of all classes."""intersect_area = intersect_area.numpy()pred_area = pred_area.numpy()label_area = label_area.numpy()union = pred_area + label_areaclass_dice = []for i in range(len(intersect_area)):if union[i] == 0:dice = 0else:dice = (2 * intersect_area[i]) / union[i]class_dice.append(dice)mdice = np.mean(class_dice)return np.array(class_dice), mdice

使用示例

python .\test2.py --root D:\pic\23\0403\851-1003339-H01\bend --scale 0.25 --load C:\Users\Admin\Desktop\fsdownload\checkpoint_epoch485.pth --num 3

结果展示

[EVAL] #Images: 74 mIoU: 0.5119 Acc: 0.9996 Kappa: 0.4405 Dice: 0.6002
[EVAL] Class IoU: [0.9997 0.4177 0.1183]
[EVAL] Class Precision: [0.9998 0.6767 0.1858]
[EVAL] Class Recall: [0.9998 0.5219 0.2456]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_286212.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【数据结构】-计数排序

🎇作者:小树苗渴望变成参天大树 🎉 作者宣言:认真写好每一篇博客 🎊作者gitee:link 如 果 你 喜 欢 作 者 的 文 章 ,就 给 作 者 点 点 关 注 吧! 文章目录前言一、计数排序二、排序算法复杂度…

《花雕学AI》18:AI绘画尝鲜Prompt Hunt,使用人工智能模型来创造、探索和分享艺术作品

引言: 人工智能是当今科技领域的热门话题,它不仅可以帮助人类解决各种实际问题,也可以激发人类的创造力和艺术感。Prompt Hunt就是一个利用人工智能模型来创造、探索和分享艺术作品的AI绘画网站。它提供了三种不同的模型,分别是S…

Java垃圾收集原理

程序计数器、虚拟机栈、本地方法栈这三个区域随线程而灭,栈中栈帧的内存大小也是在确定的。这几个区域的内存分配和回收都具有确定性,因此不需要过多考虑如何回收。 Java堆和方法区这两个区域有着很显著的不确定性 一个接口的实现类需要的内存可能不一…

用Flutter开发一款音乐App(从0到1开发一款音乐App)

Flutter Music_Listener(flutter音乐播放器) Flutter version 3.9 项目介绍 1、项目整体基于getxretrofitdiojsonserialize开发 2、封装通用控制器BaseController,类似jetpack mvvm框架中的BaseViemodel 3、封装基础无状态基类BaseStatelessWidget,结合…

十三、市场活动:全部导出

功能需求:批量导出市场活动 用户在市场活动主页面,点击"批量导出"按钮,把所有市场活动生成一个excel文件,弹出文件下载的对话框; 用户选择要保存的目录,完成导出市场活动的功能. *导出成功之后,页面不刷新 功能分析:导出市场活动 1.给批量…

Vue组件化编程【Vue】

2.Vue 组件化编程 2.1 模块与组件、模块化与组件化 2.1.1 模块 理解:向外提供特定功能的js程序,一般就是一个js文件为什么:js文件很多很复杂作用:复用js、简化js的编写、提高js运行效率。 2.1.2 组件 理解:用来实…

接口自动化【一】(抓取后台登录接口+postman请求通过+requests请求通过+json字典区别)

文章目录 前言一、requests库的使用二、json和字典的区别三、后端登录接口-请求数据生成四、接口自动化-对应电商项目中的功能五、来自postman的代码-后端登录总结前言 记录:json和字典的区别,json和字段的相互转化;postman发送请求与Python…

source insight4.0使用技巧总结

一、技巧1:查看函数调用关系 步骤 1:在主菜单中点击下图中的按钮 图 1 打开relation界面 步骤 2:在弹出的relation界面点击“设置”按钮, 图2 点击“设置”按钮 步骤3: 在“设置”界面中,“Levels”选择…

AC7811-FOC无感控制代码详解

目录 矢量控制原理 矢量控制框图 电流采样方式 电流在整个控制过程中的传递 采样关键点 三电阻 双电阻 单电阻 三者对比 坐标变换 dq轴电流的PI控制 启动方式 启动波形 脉冲注入 高频注入 Startup 预定位到指定角度 PulseInject_api hfi_api Speed loop s…

前端学习:HTML块、类、Id

目录 快 一、块元素、内联元素 二、HTML 元素 三、HTML元素 类 一、分类块级元素 二、分类行内元素 Id 一、使用 id 属性 二、 class与ID的差异 三、总结 快 一、块元素、内联元素 大多数HTML元素被定义为块级元素或内联元素。 块级元素在浏览器显示时,通常会…

FTP-----局域网内部传输文件(1)

在日常工作中,如果需要跨设备的传输文件,您需要借助USB数据线或者借助应用实现无线互联,将所需文件传输到对应设备,这一来一去,花费的时间与精力变多了,那么,怎么实现不使用第三方软件来实现跨设…

3-5年以上的功能测试如何进阶自动化?【附学习路线】

做为功能测试人员来讲,从发展方向上可分两个方面: 1、业务流程方向 2、专业技能方向。 当确定好方向后,接下来就是如何达到了。(文末自动化测试学习资料分享) 一、业务流程方向 1、熟悉底层的业务 作为功能测试工程师来讲,了解…

【C++高级】手写线程池项目-经典死锁问题分析-简历项目输出指导

作为五大池之一, 线程池的应用非常广 泛,不管是客户端程序,还是后台服务程序,掌握线程池,是提高业务处理能力的必备模块 本课程将带你从零开始,设计一个支持fixed和cached模式的线程池,玩转C11、…

IGA_PLSM3D的理解1

文章目录前言一、IgaTop3D_FAST.m给的参数二、Material properties 材料特性对Geom_Mod3D的理解对Pre_IGA3D的理解 输出1-----CtrPts: 输出2-----Ele: 输出3-----GauPts:前言 只是为方便学习,不做其他用途 一、IgaTop3D_FAST.m给的…

Python爬虫-某跨境电商(AM)搜索热词

前言 本文是该专栏的第42篇,后面会持续分享python爬虫干货知识,记得关注。 关于某跨境电商(AM),本专栏前面有单独详细介绍过,获取配送地的cookie信息以及商品库存数据,感兴趣的同学可往前翻阅。 1. python爬虫|爬取某跨境电商AM的商品库存数据(Selenium实战) 2. Seleni…

5.39 综合案例2.0 - STM32蓝牙遥控小车1(手机APP遥控)

综合案例2.0 - 蓝牙遥控小车1- 手机APP遥控成品展示案例说明器件说明连线小车源码手机遥控APPAPP使用说明成品展示 案例说明 用STM32单片机做了一辆蓝牙控制的麦轮小车,分享一下小车的原理和制作过程。 控制部分分为手机APP,语音模块控制,Ha…

15-721 chapter2 内存数据库

Background 随着时代的发展,DRAM可以容纳足够的便宜,容量也变大了。对于数据库来说,数据完全可以fit in memory,但同时面向disk的数据库架构不能很好的发挥这个特性 这张图是disk database的cpu instruction cost 想buffer pool…

第5章 继承-Java核心技术·卷1

文章目录Java与C不同基本概念继承:基于已有的类创建新的类。构造器多态定义超类变量可以引用所有的子类对象,但子类变量不能引用超类对象。子类引用的数组可以转换成超类引用的数组覆写返回子类型强制类型转换阻止继承:final类和方法多态 vs …

ROS学习-ROS简介

文章目录1.ROS1.1 ROS概念1.2 ROS特征1.3 ROS特点1.4 ROS版本1.5 ROS程序其他名词介绍1. 元操作系统2. IDL 接口定义语言一些网站1.ROS 1.1 ROS概念 ROS(Robot Operating System,机器人操作系统) ROS 是一个适用于机器人的开源的元操作系统,提供一系列…

linux驱动开发 - 04_Linux 设备树学习 - DTS语法

文章目录Linux 设备树学习 - DTS语法1 什么是设备树?2 DTS、DTB和DTC3 DTS 语法3.1 dtsi 头文件3.2 设备节点3.3 标准属性1、compatible 属性2、model 属性3、status 属性4、#address-cells 和#size-cells 属性5、reg 属性6、ranges 属性7、name 属性8、device_type…