当前位置: 代码迷 >> 综合 >> pytorch下实现mIou(mean intersection over union)和pA(pixel accuracy)
  详细解决方案

pytorch下实现mIou(mean intersection over union)和pA(pixel accuracy)

热度:37   发布时间:2023-12-14 15:01:17.0

mIou

import torch
import numpy as np
def Iou(input,target,classNum):''':param input: [b,h,w]:param target: [b,h,w]:param classNum: scalar:return:'''inputTmp = torch.zeros([input.shape[0],classNum,input.shape[1],input.shape[2]])#创建[b,c,h,w]大小的0矩阵targetTmp = torch.zeros([target.shape[0],classNum,target.shape[1],target.shape[2]])#同上input = input.unsqueeze(1)#将input维度扩充为[b,1,h,w]target = target.unsqueeze(1)#同上inputOht = inputTmp.scatter_(index=input,dim=1,value=1)#input作为索引,将0矩阵转换为onehot矩阵targetOht = targetTmp.scatter_(index=target,dim=1,value=1)#同上batchMious = []#为该batch中每张图像存储一个mioumul = inputOht * targetOht#乘法计算后,其中1的个数为intersectionfor i in range(input.shape[0]):#遍历图像ious = []for j in range(classNum):#遍历类别,包括背景intersection = torch.sum(mul[i][j])union = torch.sum(inputOht[i][j]) + torch.sum(targetOht[i][j]) - intersection + 1e-6iou = intersection / unionious.append(iou)miou = np.mean(ious)#计算该图像的mioubatchMious.append(miou)return batchMious

pA:对单张图像直接计算pa,没有进行分类计算取平均

def Pa(input,target):''':param input: [b,h,w]:param target: [b,h,w]:param classNum: scalar:return:'''tmp = input == targetreturn (torch.sum(tmp).float() / input.nelement())

 

  相关解决方案