当前位置: 代码迷 >> 综合 >> 【trick 3】Focal Loss —— 解决one-stage目标检测中正负样本不均衡的问题
  详细解决方案

【trick 3】Focal Loss —— 解决one-stage目标检测中正负样本不均衡的问题

热度:92   发布时间:2023-12-14 09:39:08.0

来自:

https://arxiv.org/abs/1708.02002

目录

  • 一、提出背景
  • 二、设计思路
  • 三、总结优缺点
  • 四、PyTorch实现
  • Reference

一、提出背景

目前目标检测的框架一般分为两种:基于候选区域的two-stage的检测框架(比如fast r-cnn系列),基于回归的one-stage的检测框架(yolo,ssd这种),two-stage的速度较慢但效果好,one-stage的速度快但效果差一些。

对于one-stage的检测器准确率不高的问题,论文作者给出了解释:由于正负样本不均衡的问题(感觉理解成简单-难分样本不均衡比较好)。 什么意思呢,就是说one-stage中能够匹配到目标的候选框(正样本)个数一般只用十几个或几十个,而没匹配到的候选框(负样本)大概有104?10510^4 - 10^5104?105个。而负样本大多数都是简单易分的,对训练起不到什么作用,但是数量太多会淹没掉少数但是对训练有帮助的困难样本。

其实two-stage的目标检测也存在着正负样本不均衡的问题,我们都知道,two-stage目标检测是将整个检测过程分成两部分,第一部分先选出一些候选框(如faster rcnn 的rpn),大概2000个左右,在第二阶段再进行筛选,虽然这时正负样本也是存在不均衡的,但是(十几个或几十个:2000) 相对(十几个或几十个:104?10510^4 - 10^5104?105),这时候好了太多了,所以我们的Focal Loss主要针对的是one-stage的目标检测算法。

那么正负样本不均衡,会带来什么问题呢?

  1. 训练效率低下。 training is inefficient as most locations are easy negatives that contribute no useful learning signal;
  2. 模型精度变低。 过多的负样本会主导训练,使模型退化。en masse,the easy negatives can overwhelm training and lead to degenerate models.

针对上述问题,一般的解决方法是难例挖掘(hard negative mining),不过该论文提出了一种新型的Loss函数,试着解决这个问题。

二、设计思路

Focal Loss设计的一个主要的思路就是:希望那些hard examples对损失的贡献变大,使网络更倾向于从这些样本上学习。防止由于easy examples过多,主导整个损失函数。

作者先以二分类为例进行说明:
先看看我们最常用的交叉熵损失函数:

CE(y,p)={?log(p)y=1?log(1?p)其他CE(y,p)=\begin{cases} -log(p) & y=1 \\ -log(1-p) & 其他 \end{cases}CE(y,p)={ ?log(p)?log(1?p)?y=1?

其中 y为真实标签,p为预测概率。
为了简便也可以写为:

pt={py=11?p其他p_t=\begin{cases} p & y=1 \\ 1-p & 其他 \end{cases}pt?={ p1?p?y=1? and rewirte CE(y,p)=CE(pt)=?log(pt)CE(y,p)=CE(p_t)=-log(p_t)CE(y,p)=CE(pt?)=?log(pt?)

要对类别不均衡问题对loss的贡献进行一个控制,即加上一个控制权重即可,最初作者的想法即如下这样,对于属于少数类别的样本,增大α即可:

CE(pt)=?αtlog(pt)CE(p_t)=- \alpha_t log(p_t)CE(pt?)=?αt?log(pt?)
αt={αy=11?α其他;其中α∈[0,1]\alpha_t=\begin{cases} \alpha & y=1 \\ 1-\alpha & 其他 \end{cases}; 其中\alpha\in[0, 1]αt?={ α1?α?y=1?;α[0,1]

注意:这里的αt\alpha_tαt?并不是正负样本的比例,而是一个超参数,用来平衡正负样本的权重。

但是上式只是解决了正负样本之间的平衡问题,并没有区分易分/难分样本,因此就有了下面的公式:

FL(pt)=?(1?pt)γlog(pt)FL(p_t)=- (1-p_t)^\gamma log(p_t)FL(pt?)=?(1?pt?)γlog(pt?)

分析:

  1. 简单样本: 容易预测正确。当y=1(正), ppp->1,ptp_tpt? ->1,(1?pt)γ(1-p_t)^\gamma(1?pt?)γ->0, loss小;当y为其他时(负),ppp->0,ptp_tpt?->1,(1?pt)γ(1-p_t)^\gamma(1?pt?)γ->0, loss小;所以综合来看,当样本为简单样本的时候,损失会比原来的损失小很多倍。
  2. 复杂样本: 容易预测错误。当y=1(正),ppp->0,ptp_tpt? ->0,(1?pt)γ(1-p_t)^\gamma(1?pt?)γ->1,loss下降一点点(几乎不变);当y为其他时(负),ppp->1,ptp_tpt?->0,(1?pt)γ(1-p_t)^\gamma(1?pt?)γ->1,loss下降一点点(几乎不变)。所以综合来看,当样本为复杂样本时,损失和原来的损失差不多,不会小太多。
    举例:前面4行是简单样本(数量很多),使用FL损失函数使其损失值下降了很多倍(相比CE损失函数);而后面两个复杂样本(数量较少),使用FL损失函数后损失值只下降了很少倍。
    在这里插入图片描述

所以γ\gammaγ参数是用来区分易分/难分样本的。它可以通过降低简单样本(数量多)的损失权重,使损失函数更加专注于困难样本(数量),防止简单样本主导整个损失函数。

综合两个方面,最终的损失函数为:

FL(pt)=?αt(1?pt)γlog(pt)FL(p_t)=- \alpha_t (1-p_t)^\gamma log(p_t)FL(pt?)=?αt?(1?pt?)γlog(pt?)
pt={py=11?p其他p_t=\begin{cases} p & y=1 \\ 1-p & 其他 \end{cases}pt?={ p1?p?y=1?
αt={αy=1(正样本)1?α其他(负样本);其中α∈[0,1]\alpha_t=\begin{cases} \alpha & y=1(正样本) \\ 1-\alpha & 其他(负样本) \end{cases}; 其中\alpha\in[0, 1]αt?={ α1?α?y=1()()?;α[0,1]

其中αt\alpha_tαt?来协调正负样本之间的平衡,γ\gammaγ来降低简单样本的权重,使损失函数更关注困难样本。

举例说明:
在这里插入图片描述
如上图,横坐标代表ptp_tpt?,纵坐标表示各种样本所占的loss权重。对于正样本,我们希望ppp越接近1越好,也就是ptp_tpt?越接近1越好;对于负样本,我们希望ppp越接近0越好,也就是ptp_tpt?越接近1越好。所以不管是正样本还是负样本,我们总是希望他预测得到的ptp_tpt?越大越好。如上图所示,pt∈[0.6,1]p_t\in[0.6, 1]pt?[0.6,1]就是我们预测效果比较好的样本(也就是易分样本)了。
显然可以想象这部分的样本数量很多,所以占比是比较高的(如图中蓝色线区域),我们用(1?pt)γ(1-p_t)^\gamma(1?pt?)γ来降低易分样本的损失占比 / 损失贡献(如图其他颜色的曲线)。

三、总结优缺点

优点:

  1. 解决了one-stage object detection中图片中正负样本(前景和背景)不均衡的问题;
  2. 降低简单样本的权重,使损失函数更关注困难样本;

缺点:

  1. 模型很容易收到噪声干扰:会将噪声当成复杂样本,使模型过拟合退化;
  2. 模型的初期,数量多的一类可能主导整个loss,所以训练初期可能训练不稳定;
  3. 两个参数αt\alpha_tαt?γ\gammaγ具体的值很难定义,需要自己调参,调的不好可能效果会更差(论文中的αt\alpha_tαt?=0.25,γ\gammaγ=2最好)。

    在这里插入图片描述

四、PyTorch实现

在yolo_v3_spp的实现代码。

class FocalLoss(nn.Module):# Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)def __init__(self, loss_fcn, gamma=2, alpha=0.25):super(FocalLoss, self).__init__()self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()self.gamma = gamma   # 参数gammaself.alpha = alpha   # 参数alpha# reduction: 控制损失输出模式 sum/mean/none 这里定义的交叉熵损失BCE都是meanself.reduction = loss_fcn.reductionself.loss_fcn.reduction = 'none'  # 不知道这句有什么用? required to apply FL to each elementdef forward(self, pred, true):loss = self.loss_fcn(pred, true)  # 普通BCE Loss# p_t = torch.exp(-loss)# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability# TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.pypred_prob = torch.sigmoid(pred)  # prob from logits 如果模型最后没有 nn.Sigmoid(),那么这里就需要对预测结果计算一次 Sigmoid 操作# ture=0,p_t=1-p; true=1, p_t=pp_t = true * pred_prob + (1 - true) * (1 - pred_prob)# ture=0, alpha_factor=1-alpha; true=1,alpha_factor=alphaalpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)modulating_factor = (1.0 - p_t) ** self.gamma# loss = focal loss(代入公式即可)loss *= alpha_factor * modulating_factorif self.reduction == 'mean': # 一般是meanreturn loss.mean()elif self.reduction == 'sum':return loss.sum()else:  # 'none'return loss

Reference

  1. https://blog.csdn.net/Code_Mart/article/details/89736187
  2. https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py
  相关解决方案