当前位置: 代码迷 >> 综合 >> 模型评估:准确率(Accuracy),精确率(Precision),召回率(Recall),F1-Score
  详细解决方案

模型评估:准确率(Accuracy),精确率(Precision),召回率(Recall),F1-Score

热度:46   发布时间:2023-10-11 07:04:32.0

     机器学习中对分类器的评估参考以下的评价指标,主要包括准确率(Accuracy),精确率(Precision),召回率(Recall),F1-Score。接下来的描述主要是以二分类举例,即label为0和1的情况。

    (一)理解TP, TN, FP, FN

      首先需要明确这几个值的定义: 
      True Positive(真正, TP):将正类预测为正类数. 
      True Negative(真负, TN):将负类预测为负类数. 
      False Positive(假正, FP):将负类预测为正类数 →→ 误报 (Type I error). 
      False Negative(假负 , FN):将正类预测为负类数 →→ 漏报 (Type II error). 
      即:

  Positive Negative
True TP FP
False FN TN

       用代码直观的表示,如下:       

if (pred.view_as(data == 1) & (data.detach().cpu() == 1):# TP  predict & label == 1tp += 1
if (pred.view_as(data == 0) & (data.detach().cpu() == 0):# TN  predict & label == 0tn += 1
if (pred.view_as(data == 1) & (data.detach().cpu() == 0):# FN  predict == 0 & label == 1fn += 1fn_list.append(np.array(data.detach().cpu()))
if (pred.view_as(data == 0) & (data.detach().cpu() == 1):# FP  predict == 1 & label == 0fp += 1fp_list.append(np.array(data.detach().cpu()))

     (二)理解准确率(Accuracy)和精确率(Precision)

      精确率(precision)和准确率(accuracy)是不一样的,我们在做模型评估是需要搞清楚两者的定义。

      准确率是针对我们原来所有样本而言的,它表示的是所有样本有多少被准确预测了,即:

                                                                     Acc = (tp + tn) / (tp + tn + fp + fn)

      精确率是针对我们预测结果而言的,它表示的是预测为正的样本中有多少是真正的正样本。那么预测为正就有两种可能了,一种就是把正类预测为正类(TP),另一种就是把负类预测为正类(FP),即:

                                                                       P = tp / (tp + fp)

“ 预测为正例的里面有多少是对的”。

    (三)理解精确率(Precision)和召回率(Recall)

      召回率是针对我们原来的正样本而言的,它表示的是正例样本中有多少被预测正确了。那也有两种可能,一种是把原来的正类预测成正类(TP),另一种就是把原来的正类预测为负类(FN)。即:

                                                                       R = tp / (tp + fn)

      准确率召回率是广泛用于信息检索和统计学分类领域的两个度量值,用来评价结果的质量。其中精度是检索出相关文档数与检索出的文档总数的比率,衡量的是检索系统的查准率;召回率是指检索出的相关文档数和文档库中所有的相关文档数的比率,衡量的是检索系统的查全率

      一般来说,Precision就是检索出来的条目(比如:文档、网页等)有多少是准确的,Recall就是所有准确的条目有多少被检索出来了。

      正确率、召回率和 F 值是在鱼龙混杂的环境中,选出目标的重要评价指标。不妨看看这些指标的定义先:

      1. 正确率 = 提取出的正确信息条数 / 提取出的信息条数

      2. 召回率 = 提取出的正确信息条数 / 样本中的信息条数

     两者取值在0和1之间,数值越接近1,查准率或查全率就越高。

模型评估:准确率(Accuracy),精确率(Precision),召回率(Recall),F1-Score

     (四)理解F1-Score

       F1分数(F1-score)是分类问题的一个衡量指标。一些多分类问题的机器学习竞赛,常常将F1-score作为最终测评的方法。它是精确率和召回率的调和平均数,最大为1,最小为0。

       F1的计算:F1-score = 2 ? Precision ? Recall / (Precision + Recall)

     (五)使用sklearn库来实现模型评估 (以下提供计算方法)

       使用pip命令下载安装sklearn库

pip install scikit-learn

       1. 使用tn,fp,fn,tp来计算相应的指标       

from sklearn.metrics import confusion_matrixdef test(self):self.model.load_weights(self.name + "_model.pkl")values = self.model.evaluate(self.x_test, self.y_test, batch_size=self.batch_size)print("Accuracy: ", values[1])predictions = (self.model.predict(self.x_test, batch_size=self.batch_size)).round()tn, fp, fn, tp = confusion_matrix(np.argmax(self.y_test, axis=1), np.argmax(predictions, axis=1)).ravel()print('False positive rate(FP): ', fp / (fp + tn))print('False negative rate(FN): ', fn / (fn + tp))recall = tp / (tp + fn)print('Recall: ', recall)precision = tp / (tp + fp)print('Precision: ', precision)print('F1 score: ', (2 * precision * recall) / (precision + recall))

          2. 直接计算 recall和precision

from sklearn import metricsdef test(test_loader):model.eval()start = time.time()test_loss, correct, n_samples, count = 0, 0, 0, 0accuracy, recall, precision, F1 = 0, 0, 0, 0for batch_idx, data in enumerate(test_loader):count += 1for i in range(len(data)):data[i] = data[i].to(args.device)output = model(data)loss = loss_fn(output, data[4], reduction='sum')test_loss += loss.item()n_samples += len(output)pred = output.detach().cpu().max(1, keepdim=True)[1]correct += pred.eq(data[4].detach().cpu().view_as(pred)).sum().item()accuracy += metrics.accuracy_score(data[4], pred)recall += metrics.recall_score(data[4], pred, average="macro")precision += metrics.precision_score(data[4], pred, average="macro")F1 += metrics.f1_score(data[4], pred, average="macro")acc = 100. * correct / n_samplesaccuracy = 100. * accuracy / count recall = 100. * recall / countprecision = 100. * precision / countF1 = 100. * F1 / count

  关于上述的具体代码,可以参考:

  https://github.com/Messi-Q/VulDeeSmartContract

  https://github.com/Messi-Q/GraphDeeSmartContract

  相关解决方案