The evaluation code below is copied from PaddleSeg and ported to PyTorch.
import argparse
import logging
import os

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from unet import UNet
from utils.utils import plot_img_and_mask
from torch.utils.data import DataLoader, random_split
from utils.data_loading import BasicDataset, CarvanaDataset
from tqdm import tqdm
import metrics

# Evaluate a UNet model trained with PyTorch: the model outputs NCHW scores,
# the ground-truth labels are NHW; compute accuracy, class precision,
# class recall and the kappa coefficient.

EPSILON = 1e-32


def calculate_area(pred, label, num_classes, ignore_index=255):
    """Calculate intersect, prediction and label area.

    Args:
        pred (Tensor): The prediction by model.
        label (Tensor): The ground truth of image.
        num_classes (int): The unique number of target classes.
        ignore_index (int): Specifies a target value that is ignored. Default: 255.

    Returns:
        Tensor: The intersection area of prediction and ground truth on all classes.
        Tensor: The prediction area on all classes.
        Tensor: The ground truth area on all classes.
    """
    if len(pred.shape) == 4:
        pred = torch.squeeze(pred, dim=1)
    if len(label.shape) == 4:
        label = torch.squeeze(label, dim=1)
    if not pred.shape == label.shape:
        raise ValueError('Shape of `pred` and `label` should be equal, '
                         'but there are {} and {}.'.format(pred.shape, label.shape))

    pred_area = []
    label_area = []
    intersect_area = []
    mask = label != ignore_index
    for i in range(num_classes):
        pred_i = torch.logical_and(pred == i, mask)
        label_i = label == i
        intersect_i = torch.logical_and(pred_i, label_i)
        pred_area.append(torch.sum(pred_i))
        label_area.append(torch.sum(label_i))
        intersect_area.append(torch.sum(intersect_i))

    pred_area = torch.stack(pred_area)
    label_area = torch.stack(label_area)
    intersect_area = torch.stack(intersect_area)
    return intersect_area, pred_area, label_area


def get_args():
    parser = argparse.ArgumentParser(description='Evaluate the UNet on images and target masks')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
    parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
    parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
    parser.add_argument('--root', '-r', type=str, default=False, help='root dir')
    parser.add_argument('--num', '-n', type=int, default=False, help='num of classes')
    return parser.parse_args()


dir_img_path = 'imgs'
dir_mask_path = 'masks'


def train_net(net,
              device,
              epochs: int = 5,
              batch_size: int = 1,
              learning_rate: float = 0.001,
              val_percent: float = 0.1,
              save_checkpoint: bool = True,
              img_scale: float = 0.5,
              amp: bool = False,
              root_dir: str = '/data/yangbo/unet/datas/data1'):
    train_dir_img = os.path.join(root_dir, 'train', dir_img_path)
    train_dir_mask = os.path.join(root_dir, 'train', dir_mask_path)
    val_dir_img = os.path.join(root_dir, 'val', dir_img_path)
    val_dir_mask = os.path.join(root_dir, 'val', dir_mask_path)

    # 1. Create datasets
    try:
        train_dataset = CarvanaDataset(train_dir_img, train_dir_mask, img_scale)
        val_dataset = CarvanaDataset(val_dir_img, val_dir_mask, img_scale)
    except (AssertionError, RuntimeError):
        train_dataset = BasicDataset(train_dir_img, train_dir_mask, img_scale)
        val_dataset = BasicDataset(val_dir_img, val_dir_mask, img_scale)

    n_val = len(val_dataset)
    n_train = len(train_dataset)

    # 3. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_args)
    val_loader = DataLoader(val_dataset, shuffle=False, drop_last=True, **loader_args)

    # (Initialize logging)
    logging.info(f'''Starting evaluation:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_checkpoint}
        Device:          {device.type}
        Images scaling:  {img_scale}
        Mixed Precision: {amp}
    ''')

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    # optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
    # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)  # goal: maximize Dice score

    # 5. Accumulate per-class areas over the validation set
    intersect_area_all = torch.zeros([1])
    pred_area_all = torch.zeros([1])
    label_area_all = torch.zeros([1])
    for idx, batch in tqdm(enumerate(val_loader)):
        images = batch['image']
        true_masks = batch['mask']

        assert images.shape[1] == net.n_channels, \
            f'Network has been defined with {net.n_channels} input channels, ' \
            f'but loaded images have {images.shape[1]} channels. Please check that ' \
            'the images are loaded correctly.'

        images = images.to(device=device, dtype=torch.float32)
        true_masks = true_masks.to(device=device, dtype=torch.long)

        with torch.no_grad():
            masks_pred = net(images)
            masks_pred = torch.argmax(masks_pred, dim=1, keepdim=True)
            intersect_area, pred_area, label_area = calculate_area(masks_pred, true_masks, net.n_classes)
            # Accumulate on CPU so the metrics functions can call .numpy() later,
            # even when the evaluation runs on the GPU.
            intersect_area_all = intersect_area_all + intersect_area.cpu()
            pred_area_all = pred_area_all + pred_area.cpu()
            label_area_all = label_area_all + label_area.cpu()

    metrics_input = (intersect_area_all, pred_area_all, label_area_all)
    class_iou, miou = metrics.mean_iou(*metrics_input)
    acc, class_precision, class_recall = metrics.class_measurement(*metrics_input)
    kappa = metrics.kappa(*metrics_input)
    class_dice, mdice = metrics.dice(*metrics_input)

    infor = "[EVAL] #Images: {} mIoU: {:.4f} Acc: {:.4f} Kappa: {:.4f} Dice: {:.4f}".format(
        len(val_loader), miou, acc, kappa, mdice)
    print(infor)
    print("[EVAL] Class IoU: " + str(np.round(class_iou, 4)))
    print("[EVAL] Class Precision: " + str(np.round(class_precision, 4)))
    print("[EVAL] Class Recall: " + str(np.round(class_recall, 4)))


if __name__ == '__main__':
    args = get_args()

    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # Change here to adapt to your data
    # n_channels=3 for RGB images
    # n_classes is the number of probabilities you want to get per pixel
    # (set it via --num to match your dataset)
    net = UNet(n_channels=3, n_classes=args.num, bilinear=True)
    net.eval()

    logging.info(f'Network:\n'
                 f'\t{net.n_channels} input channels\n'
                 f'\t{net.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

    if args.load:
        net.load_state_dict(torch.load(args.load, map_location=device))
        logging.info(f'Model loaded from {args.load}')

    net.to(device=device)
    try:
        train_net(net=net,
                  epochs=0,
                  batch_size=args.batch_size,
                  learning_rate=0,
                  device=device,
                  img_scale=args.scale,
                  val_percent=args.val / 100,
                  amp=args.amp,
                  root_dir=args.root)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
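Before running a full evaluation, it can be worth sanity-checking calculate_area on a tiny hand-made batch. The snippet below is a minimal sketch (the 4x4 tensors, the class count of 3, and the `from test2 import ...` line are assumptions for illustration; adjust the module name to wherever the script above is saved). It mimics the evaluation loop: argmax over the NCHW scores, then per-class area counting.

import torch

from test2 import calculate_area  # assumes the evaluation script above is saved as test2.py

# Fake scores for one 4x4 image with 3 classes (NCHW) and a matching NHW label map.
logits = torch.randn(1, 3, 4, 4)
label = torch.randint(0, 3, (1, 4, 4))

# Same steps as the evaluation loop: argmax over the channel axis, then count areas per class.
pred = torch.argmax(logits, dim=1, keepdim=True)  # -> (1, 1, 4, 4)
intersect, pred_area, label_area = calculate_area(pred, label, num_classes=3)
print(intersect, pred_area, label_area)

# Basic invariants: every pixel is counted exactly once, and the intersection
# never exceeds either the predicted or the ground-truth area of a class.
assert pred_area.sum().item() == 16 and label_area.sum().item() == 16
assert torch.all(intersect <= pred_area) and torch.all(intersect <= label_area)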
metrics.py
import numpy as np
import torch
import sklearn.metrics as skmetrics


def mean_iou(intersect_area, pred_area, label_area):
    """Calculate iou.

    Args:
        intersect_area (Tensor): The intersection area of prediction and ground truth on all classes.
        pred_area (Tensor): The prediction area on all classes.
        label_area (Tensor): The ground truth area on all classes.

    Returns:
        np.ndarray: iou on all classes.
        float: mean iou of all classes.
    """
    intersect_area = intersect_area.numpy()
    pred_area = pred_area.numpy()
    label_area = label_area.numpy()
    union = pred_area + label_area - intersect_area
    class_iou = []
    for i in range(len(intersect_area)):
        if union[i] == 0:
            iou = 0
        else:
            iou = intersect_area[i] / union[i]
        class_iou.append(iou)
    miou = np.mean(class_iou)
    return np.array(class_iou), miou


def class_measurement(intersect_area, pred_area, label_area):
    """Calculate accuracy, class precision and class recall.

    Args:
        intersect_area (Tensor): The intersection area of prediction and ground truth on all classes.
        pred_area (Tensor): The prediction area on all classes.
        label_area (Tensor): The ground truth area on all classes.

    Returns:
        float: The mean accuracy.
        np.ndarray: The precision of all classes.
        np.ndarray: The recall of all classes.
    """
    intersect_area = intersect_area.numpy()
    pred_area = pred_area.numpy()
    label_area = label_area.numpy()

    mean_acc = np.sum(intersect_area) / np.sum(pred_area)
    class_precision = []
    class_recall = []
    for i in range(len(intersect_area)):
        precision = 0 if pred_area[i] == 0 \
            else intersect_area[i] / pred_area[i]
        recall = 0 if label_area[i] == 0 \
            else intersect_area[i] / label_area[i]
        class_precision.append(precision)
        class_recall.append(recall)

    return mean_acc, np.array(class_precision), np.array(class_recall)


def kappa(intersect_area, pred_area, label_area):
    """Calculate the kappa coefficient.

    Args:
        intersect_area (Tensor): The intersection area of prediction and ground truth on all classes.
        pred_area (Tensor): The prediction area on all classes.
        label_area (Tensor): The ground truth area on all classes.

    Returns:
        float: kappa coefficient.
    """
    intersect_area = intersect_area.numpy().astype(np.float64)
    pred_area = pred_area.numpy().astype(np.float64)
    label_area = label_area.numpy().astype(np.float64)

    total_area = np.sum(label_area)
    po = np.sum(intersect_area) / total_area
    pe = np.sum(pred_area * label_area) / (total_area * total_area)
    kappa = (po - pe) / (1 - pe)
    return kappa


def dice(intersect_area, pred_area, label_area):
    """Calculate DICE.

    Args:
        intersect_area (Tensor): The intersection area of prediction and ground truth on all classes.
        pred_area (Tensor): The prediction area on all classes.
        label_area (Tensor): The ground truth area on all classes.

    Returns:
        np.ndarray: DICE on all classes.
        float: mean DICE of all classes.
    """
    intersect_area = intersect_area.numpy()
    pred_area = pred_area.numpy()
    label_area = label_area.numpy()
    union = pred_area + label_area
    class_dice = []
    for i in range(len(intersect_area)):
        if union[i] == 0:
            dice = 0
        else:
            dice = (2 * intersect_area[i]) / union[i]
        class_dice.append(dice)
    mdice = np.mean(class_dice)
    return np.array(class_dice), mdice
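Since the four metric functions only consume the three accumulated area tensors, they are easy to verify with hand-counted numbers. The example below is a toy illustration with invented areas for a 100-pixel, 2-class image in which the background dominates; it also shows why an evaluation can report a very high Acc next to a much lower Kappa, as in the results further down.

import torch

import metrics

# Hand-counted areas for a toy 100-pixel, 2-class case (numbers are made up):
# class 0 (background) covers 95 pixels, class 1 only 5 and is segmented poorly.
intersect_area = torch.tensor([94, 2])   # correctly predicted pixels per class
pred_area = torch.tensor([97, 3])        # pixels predicted as each class
label_area = torch.tensor([95, 5])       # ground-truth pixels per class

class_iou, miou = metrics.mean_iou(intersect_area, pred_area, label_area)
acc, class_precision, class_recall = metrics.class_measurement(intersect_area, pred_area, label_area)
kappa = metrics.kappa(intersect_area, pred_area, label_area)
class_dice, mdice = metrics.dice(intersect_area, pred_area, label_area)

print(acc)              # 0.96   -- pixel accuracy looks high because the background dominates
print(kappa)            # ~0.48  -- much lower once chance agreement is removed
print(class_iou)        # [~0.959, ~0.333]
print(class_precision)  # [~0.969, ~0.667]
print(class_recall)     # [~0.989,  0.4]
print(class_dice)       # [~0.979,  0.5]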
Usage example
python .\test2.py --root D:\pic\23\0403\851-1003339-H01\bend --scale 0.25 --load C:\Users\Admin\Desktop\fsdownload\checkpoint_epoch485.pth --num 3
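The --root directory is expected to follow the layout that train_net builds with os.path.join: train/ and val/ subfolders, each containing the folders named by dir_img_path and dir_mask_path (only the val split is actually iterated during evaluation). A sketch of the assumed layout:

<root>/
├── train/
│   ├── imgs/    # training images
│   └── masks/   # training masks
└── val/
    ├── imgs/    # validation images (these are evaluated)
    └── masks/   # validation masks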
Results
[EVAL] #Images: 74 mIoU: 0.5119 Acc: 0.9996 Kappa: 0.4405 Dice: 0.6002
[EVAL] Class IoU: [0.9997 0.4177 0.1183]
[EVAL] Class Precision: [0.9998 0.6767 0.1858]
[EVAL] Class Recall: [0.9998 0.5219 0.2456]