目录
- 前言
- 一、原理
- 二、实现
前言
马上要找工作了,想总结下自己做过的几个小项目。
先总结下实验室之前的一个病虫害检测相关的项目。选用的baseline是SSD,代码是在这个仓库的基础上改的 lufficc/SSD.这个仓库写的ssd还是很牛的,github有1.3k个star。
选择这个版本的代码,主要有两个原因:
- 它的backbone代码是支持直接加载pytorch官方预训练权重的,所以很方便我做实验
代码高度模块化,类似mmdetection和Detectron2,写的很高级,不过对初学者不是很友好,但是很能提高工程代码能力。 - 原仓库主要实现了SSD-VGG16、SSD-Mobilenet-V2、SSD-Mobilenet-V3、SSD-EfficientNet等网络,在我数据集上几个改进版本都还不如SSD-VGG16效果好,所以我在原仓库的基础上进行了自己的实验,加了一些也不算很高级的trick吧,主要是在我的数据集上确实好使,疯狂调参,哈哈哈。
同系列讲解:
【项目一、xxx病虫害检测项目】1、SSD原理和源码分析.
【项目一、xxx病虫害检测项目】2、网络结构尝试改进:Resnet50、SE、CBAM、Feature Fusion.
第三篇,关于损失函数的尝试:Focal loss。
代码已全部上传GitHub HuKai-cv/FFSSD-ResNet.
一、原理
来自论文:Focal Loss for Dense Object Detection.
背景:one-stage目标检测网络通常精度会低于two stage目标检测模型,原因是:1、one-stage网络正负样本(前景背景)不均衡;2、难易样本不平衡,被背景中有大量easy样本。导致模型训练效率低下、模型精度降低
Focal Loss主要是在CE Loss的基础上做了两件事。第一件是要平衡正负样本,做法是在CE损失的基础上乘以一个α权重(实验最佳0.25),就是正样本:负样本=1:3,这样就解决了正负样本的一个不均衡问题;第二件事是平衡难易样本,因为负样本远大于样本,肯定会有很多的easy负样本,解决办法就是在CE Loss的基础上乘以(1?pt)^γ,可以发现难样本带入公式的损失值几乎不变,简单样本带入损失值会缩小很多倍,所以就起到了一个平衡难易样本的一个作用。
作用:1、平衡正负样本(前景背景);2、平衡难易样本
缺点:1、需要调参,α和γ两个超参不好确定
二、实现
对应代码在ssd/modeling/box_head/loss.py中:
class FocalLoss(nn.Module):def __init__(self, gamma=0, alpha=None, size_average=True):super(FocalLoss, self).__init__()self.gamma = gammaself.alpha = alphaif isinstance(alpha,(float,int,long)): self.alpha = torch.Tensor([alpha,1-alpha])if isinstance(alpha,list): self.alpha = torch.Tensor(alpha)self.size_average = size_averagedef forward(self, input, target):if input.dim()>2:input = input.view(input.size(0),input.size(1),-1) # N,C,H,W => N,C,H*Winput = input.transpose(1,2) # N,C,H*W => N,H*W,Cinput = input.contiguous().view(-1,input.size(2)) # N,H*W,C => N*H*W,Ctarget = target.view(-1,1)logpt = F.log_softmax(input)logpt = logpt.gather(1,target)logpt = logpt.view(-1)pt = logpt.data.exp()if self.alpha is not None:if self.alpha.type()!=input.data.type():self.alpha = self.alpha.type_as(input.data)at = self.alpha.gather(0,target.data.view(-1))logpt = logpt * atloss = -1 * (1-pt)**self.gamma * logptif self.size_average: return loss.mean()else: return loss.sum()