当前位置: 代码迷 >> 综合 >> 【项目一、xxx病虫害检测项目】3、损失函数尝试:Focal loss
  详细解决方案

【项目一、xxx病虫害检测项目】3、损失函数尝试:Focal loss

热度:73   发布时间:2023-12-14 09:19:29.0

目录

  • 前言
  • 一、原理
  • 二、实现

前言

马上要找工作了,想总结下自己做过的几个小项目。

先总结下实验室之前的一个病虫害检测相关的项目。选用的baseline是SSD,代码是在这个仓库的基础上改的 lufficc/SSD.这个仓库写的ssd还是很牛的,github有1.3k个star。

选择这个版本的代码,主要有两个原因:

  1. 它的backbone代码是支持直接加载pytorch官方预训练权重的,所以很方便我做实验
    代码高度模块化,类似mmdetection和Detectron2,写的很高级,不过对初学者不是很友好,但是很能提高工程代码能力。
  2. 原仓库主要实现了SSD-VGG16、SSD-Mobilenet-V2、SSD-Mobilenet-V3、SSD-EfficientNet等网络,在我数据集上几个改进版本都还不如SSD-VGG16效果好,所以我在原仓库的基础上进行了自己的实验,加了一些也不算很高级的trick吧,主要是在我的数据集上确实好使,疯狂调参,哈哈哈。

同系列讲解:
【项目一、xxx病虫害检测项目】1、SSD原理和源码分析.
【项目一、xxx病虫害检测项目】2、网络结构尝试改进:Resnet50、SE、CBAM、Feature Fusion.

第三篇,关于损失函数的尝试:Focal loss。

代码已全部上传GitHub HuKai-cv/FFSSD-ResNet.

一、原理

来自论文:Focal Loss for Dense Object Detection.

背景:one-stage目标检测网络通常精度会低于two stage目标检测模型,原因是:1、one-stage网络正负样本(前景背景)不均衡;2、难易样本不平衡,被背景中有大量easy样本。导致模型训练效率低下、模型精度降低

Focal Loss主要是在CE Loss的基础上做了两件事。第一件是要平衡正负样本,做法是在CE损失的基础上乘以一个α权重(实验最佳0.25),就是正样本:负样本=1:3,这样就解决了正负样本的一个不均衡问题;第二件事是平衡难易样本,因为负样本远大于样本,肯定会有很多的easy负样本,解决办法就是在CE Loss的基础上乘以(1?pt)^γ,可以发现难样本带入公式的损失值几乎不变,简单样本带入损失值会缩小很多倍,所以就起到了一个平衡难易样本的一个作用。

在这里插入图片描述

作用:1、平衡正负样本(前景背景);2、平衡难易样本

缺点:1、需要调参,α和γ两个超参不好确定

二、实现

对应代码在ssd/modeling/box_head/loss.py中:

class FocalLoss(nn.Module):def __init__(self, gamma=0, alpha=None, size_average=True):super(FocalLoss, self).__init__()self.gamma = gammaself.alpha = alphaif isinstance(alpha,(float,int,long)): self.alpha = torch.Tensor([alpha,1-alpha])if isinstance(alpha,list): self.alpha = torch.Tensor(alpha)self.size_average = size_averagedef forward(self, input, target):if input.dim()>2:input = input.view(input.size(0),input.size(1),-1)  # N,C,H,W => N,C,H*Winput = input.transpose(1,2)    # N,C,H*W => N,H*W,Cinput = input.contiguous().view(-1,input.size(2))   # N,H*W,C => N*H*W,Ctarget = target.view(-1,1)logpt = F.log_softmax(input)logpt = logpt.gather(1,target)logpt = logpt.view(-1)pt = logpt.data.exp()if self.alpha is not None:if self.alpha.type()!=input.data.type():self.alpha = self.alpha.type_as(input.data)at = self.alpha.gather(0,target.data.view(-1))logpt = logpt * atloss = -1 * (1-pt)**self.gamma * logptif self.size_average: return loss.mean()else: return loss.sum()
  相关解决方案