来自:
https://arxiv.org/abs/1708.02002
目录
- 一、提出背景
- 二、设计思路
- 三、总结优缺点
- 四、PyTorch实现
- Reference
一、提出背景
目前目标检测的框架一般分为两种:基于候选区域的two-stage的检测框架(比如fast r-cnn系列),基于回归的one-stage的检测框架(yolo,ssd这种),two-stage的速度较慢但效果好,one-stage的速度快但效果差一些。
对于one-stage的检测器准确率不高的问题,论文作者给出了解释:由于正负样本不均衡的问题(感觉理解成简单-难分样本不均衡比较好)。 什么意思呢,就是说one-stage中能够匹配到目标的候选框(正样本)个数一般只用十几个或几十个,而没匹配到的候选框(负样本)大概有104?10510^4 - 10^5104?105个。而负样本大多数都是简单易分的,对训练起不到什么作用,但是数量太多会淹没掉少数但是对训练有帮助的困难样本。
其实two-stage的目标检测也存在着正负样本不均衡的问题,我们都知道,two-stage目标检测是将整个检测过程分成两部分,第一部分先选出一些候选框(如faster rcnn 的rpn),大概2000个左右,在第二阶段再进行筛选,虽然这时正负样本也是存在不均衡的,但是(十几个或几十个:2000) 相对(十几个或几十个:104?10510^4 - 10^5104?105),这时候好了太多了,所以我们的Focal Loss主要针对的是one-stage的目标检测算法。
那么正负样本不均衡,会带来什么问题呢?
- 训练效率低下。 training is inefficient as most locations are easy negatives that contribute no useful learning signal;
- 模型精度变低。 过多的负样本会主导训练,使模型退化。en masse,the easy negatives can overwhelm training and lead to degenerate models.
针对上述问题,一般的解决方法是难例挖掘(hard negative mining),不过该论文提出了一种新型的Loss函数,试着解决这个问题。
二、设计思路
Focal Loss设计的一个主要的思路就是:希望那些hard examples对损失的贡献变大,使网络更倾向于从这些样本上学习。防止由于easy examples过多,主导整个损失函数。
作者先以二分类为例进行说明:
先看看我们最常用的交叉熵损失函数:
其中 y为真实标签,p为预测概率。
为了简便也可以写为:
要对类别不均衡问题对loss的贡献进行一个控制,即加上一个控制权重即可,最初作者的想法即如下这样,对于属于少数类别的样本,增大α即可:
αt={αy=11?α其他;其中α∈[0,1]\alpha_t=\begin{cases} \alpha & y=1 \\ 1-\alpha & 其他 \end{cases}; 其中\alpha\in[0, 1]αt?={ α1?α?y=1其他?;其中α∈[0,1]
注意:这里的αt\alpha_tαt?并不是正负样本的比例,而是一个超参数,用来平衡正负样本的权重。
但是上式只是解决了正负样本之间的平衡问题,并没有区分易分/难分样本,因此就有了下面的公式:
分析:
- 简单样本: 容易预测正确。当y=1(正), ppp->1,ptp_tpt? ->1,(1?pt)γ(1-p_t)^\gamma(1?pt?)γ->0, loss小;当y为其他时(负),ppp->0,ptp_tpt?->1,(1?pt)γ(1-p_t)^\gamma(1?pt?)γ->0, loss小;所以综合来看,当样本为简单样本的时候,损失会比原来的损失小很多倍。
- 复杂样本: 容易预测错误。当y=1(正),ppp->0,ptp_tpt? ->0,(1?pt)γ(1-p_t)^\gamma(1?pt?)γ->1,loss下降一点点(几乎不变);当y为其他时(负),ppp->1,ptp_tpt?->0,(1?pt)γ(1-p_t)^\gamma(1?pt?)γ->1,loss下降一点点(几乎不变)。所以综合来看,当样本为复杂样本时,损失和原来的损失差不多,不会小太多。
举例:前面4行是简单样本(数量很多),使用FL损失函数使其损失值下降了很多倍(相比CE损失函数);而后面两个复杂样本(数量较少),使用FL损失函数后损失值只下降了很少倍。
所以γ\gammaγ参数是用来区分易分/难分样本的。它可以通过降低简单样本(数量多)的损失权重,使损失函数更加专注于困难样本(数量),防止简单样本主导整个损失函数。
综合两个方面,最终的损失函数为:
pt={py=11?p其他p_t=\begin{cases} p & y=1 \\ 1-p & 其他 \end{cases}pt?={ p1?p?y=1其他?
αt={αy=1(正样本)1?α其他(负样本);其中α∈[0,1]\alpha_t=\begin{cases} \alpha & y=1(正样本) \\ 1-\alpha & 其他(负样本) \end{cases}; 其中\alpha\in[0, 1]αt?={ α1?α?y=1(正样本)其他(负样本)?;其中α∈[0,1]
其中αt\alpha_tαt?来协调正负样本之间的平衡,γ\gammaγ来降低简单样本的权重,使损失函数更关注困难样本。
举例说明:
如上图,横坐标代表ptp_tpt?,纵坐标表示各种样本所占的loss权重。对于正样本,我们希望ppp越接近1越好,也就是ptp_tpt?越接近1越好;对于负样本,我们希望ppp越接近0越好,也就是ptp_tpt?越接近1越好。所以不管是正样本还是负样本,我们总是希望他预测得到的ptp_tpt?越大越好。如上图所示,pt∈[0.6,1]p_t\in[0.6, 1]pt?∈[0.6,1]就是我们预测效果比较好的样本(也就是易分样本)了。
显然可以想象这部分的样本数量很多,所以占比是比较高的(如图中蓝色线区域),我们用(1?pt)γ(1-p_t)^\gamma(1?pt?)γ来降低易分样本的损失占比 / 损失贡献(如图其他颜色的曲线)。
三、总结优缺点
优点:
- 解决了one-stage object detection中图片中正负样本(前景和背景)不均衡的问题;
- 降低简单样本的权重,使损失函数更关注困难样本;
缺点:
- 模型很容易收到噪声干扰:会将噪声当成复杂样本,使模型过拟合退化;
- 模型的初期,数量多的一类可能主导整个loss,所以训练初期可能训练不稳定;
- 两个参数αt\alpha_tαt?和γ\gammaγ具体的值很难定义,需要自己调参,调的不好可能效果会更差(论文中的αt\alpha_tαt?=0.25,γ\gammaγ=2最好)。
四、PyTorch实现
在yolo_v3_spp的实现代码。
class FocalLoss(nn.Module):# Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)def __init__(self, loss_fcn, gamma=2, alpha=0.25):super(FocalLoss, self).__init__()self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()self.gamma = gamma # 参数gammaself.alpha = alpha # 参数alpha# reduction: 控制损失输出模式 sum/mean/none 这里定义的交叉熵损失BCE都是meanself.reduction = loss_fcn.reductionself.loss_fcn.reduction = 'none' # 不知道这句有什么用? required to apply FL to each elementdef forward(self, pred, true):loss = self.loss_fcn(pred, true) # 普通BCE Loss# p_t = torch.exp(-loss)# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability# TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.pypred_prob = torch.sigmoid(pred) # prob from logits 如果模型最后没有 nn.Sigmoid(),那么这里就需要对预测结果计算一次 Sigmoid 操作# ture=0,p_t=1-p; true=1, p_t=pp_t = true * pred_prob + (1 - true) * (1 - pred_prob)# ture=0, alpha_factor=1-alpha; true=1,alpha_factor=alphaalpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)modulating_factor = (1.0 - p_t) ** self.gamma# loss = focal loss(代入公式即可)loss *= alpha_factor * modulating_factorif self.reduction == 'mean': # 一般是meanreturn loss.mean()elif self.reduction == 'sum':return loss.sum()else: # 'none'return loss
Reference
- https://blog.csdn.net/Code_Mart/article/details/89736187
- https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py