当前位置: 代码迷 >> 综合 >> Retinaface代码记录(四)(网络结构)
  详细解决方案

Retinaface代码记录(四)(网络结构)

热度:44   发布时间:2023-11-22 03:44:50.0

一、写在开头

这次主要记录关于Retinaface的网络结构部分。

下面是代码地址:
Retinaface代码地址

主要包括的脚本为:
retinaface.py
net.py

也欢迎阅读本系列的第一篇博客Retinaface代码记录(一),它可以帮助读者从整体上把握和理解本篇博客的内容。

二、主要内容

如Fig1所示,这是Retinaface的网络结构概况图。这里采用的骨干网络是ResNet50或MobileNet,如Fig2所示。然后是FPN,即特征金字塔网络(Feature Pyramid Network),一种用于多尺度目标检测(object detection)的特征提取结构。多数object detection算法只采用顶层特征做预测,而我们知道:低层特征的语义信息比较少,但目标位置准确;高层特征的语义信息比较丰富,但目标位置比较粗略。另外,虽然也有些算法采用多尺度特征融合的方式,但一般是用融合后的特征做预测;FPN不同的地方在于,预测是在不同特征层上独立进行的。常见的多尺度特征利用方式有以下几种,如Fig3所示。最后是SSH(Single Stage Headless)上下文模块,如Fig4所示。

Fig1:
在这里插入图片描述
Fig2:

在这里插入图片描述
在这里插入图片描述
Fig3:
(a)图像金字塔,即将图像缩放到不同的尺度(scale),再由不同尺度的图像分别生成对应尺度的特征。这种方法的缺点在于增加了时间成本。有些算法会在测试时采用图像金字塔。
(b)SPP-Net、Fast R-CNN、Faster R-CNN采用的是这种方式,即仅采用网络最后一层的特征做预测。
(c)SSD(Single Shot MultiBox Detector)采用的是这种多尺度特征的方式:没有上采样过程,也不做特征融合,而是直接从网络不同层抽取不同尺度的特征分别做预测,这种方式不会增加额外的计算量。但SSD算法没有用到足够低层的特征(在SSD中,最低层的特征是VGG网络的conv4_3),而足够低层的特征对检测小物体很有帮助。
(d)即FPN,顶层特征通过上采样和低层特征做融合,而且每层都是独立预测的。

在这里插入图片描述
Fig4
在这里插入图片描述

retinaface.py:

import torch
import torch.nn as nn
import torchvision.models.detection.backbone_utils as backbone_utils
import torchvision.models._utils as _utils
import torch.nn.functional as F
from collections import OrderedDict

from models.net import MobileNetV1 as MobileNetV1
from models.net import FPN as FPN
from models.net import SSH as SSH


class ClassHead(nn.Module):
    """Per-level 2-way classification head (presumably face vs. background).

    A 1x1 conv maps ``inchannels`` to ``num_anchors * 2`` channels; the map
    is then reshaped from (N, num_anchors * 2, H, W) to
    (N, H * W * num_anchors, 2), i.e. one row of 2 logits per anchor.
    """

    def __init__(self, inchannels=512, num_anchors=3):
        super(ClassHead, self).__init__()
        self.num_anchors = num_anchors
        self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2,
                                 kernel_size=(1, 1), stride=1, padding=0)

    def forward(self, x):
        out = self.conv1x1(x)
        # Move channels last so each location's anchors are contiguous
        # before flattening into one row per anchor.
        out = out.permute(0, 2, 3, 1).contiguous()
        return out.view(out.shape[0], -1, 2)


class BboxHead(nn.Module):
    """Per-level box-regression head: 4 values per anchor.

    Output shape: (N, H * W * num_anchors, 4).
    """

    def __init__(self, inchannels=512, num_anchors=3):
        super(BboxHead, self).__init__()
        self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4,
                                 kernel_size=(1, 1), stride=1, padding=0)

    def forward(self, x):
        out = self.conv1x1(x)
        out = out.permute(0, 2, 3, 1).contiguous()
        return out.view(out.shape[0], -1, 4)


class LandmarkHead(nn.Module):
    """Per-level landmark-regression head: 10 values per anchor.

    Output shape: (N, H * W * num_anchors, 10) -- presumably five (x, y)
    facial landmark points; confirm against the decoding code.
    """

    def __init__(self, inchannels=512, num_anchors=3):
        super(LandmarkHead, self).__init__()
        self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10,
                                 kernel_size=(1, 1), stride=1, padding=0)

    def forward(self, x):
        out = self.conv1x1(x)
        out = out.permute(0, 2, 3, 1).contiguous()
        return out.view(out.shape[0], -1, 10)


class RetinaFace(nn.Module):
    """RetinaFace: backbone -> FPN -> one SSH per level -> prediction heads.

    Args:
        cfg: network settings. Keys read here: ``name`` ('mobilenet0.25' or
            'Resnet50'), ``pretrain``, ``return_layers``, ``in_channel`` and
            ``out_channel``.
        phase: 'train' makes ``forward`` return raw class logits; any other
            value applies a softmax over the last dimension.

    Raises:
        ValueError: if ``cfg`` is None or ``cfg['name']`` names an
            unsupported backbone.
    """

    def __init__(self, cfg = None, phase = 'train'):
        super(RetinaFace, self).__init__()
        if cfg is None:
            raise ValueError("RetinaFace requires a cfg mapping, got None")
        self.phase = phase

        backbone = self._build_backbone(cfg)
        # Expose the backbone's intermediate feature maps listed in
        # cfg['return_layers'] (returned as an ordered mapping).
        self.body = _utils.IntermediateLayerGetter(backbone, cfg['return_layers'])

        in_channels_stage2 = cfg['in_channel']
        in_channels_list = [
            in_channels_stage2 * 2,
            in_channels_stage2 * 4,
            in_channels_stage2 * 8,
        ]
        out_channels = cfg['out_channel']
        self.fpn = FPN(in_channels_list, out_channels)
        self.ssh1 = SSH(out_channels, out_channels)
        self.ssh2 = SSH(out_channels, out_channels)
        self.ssh3 = SSH(out_channels, out_channels)

        self.ClassHead = self._make_class_head(fpn_num=3, inchannels=cfg['out_channel'])
        self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
        self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])

    @staticmethod
    def _build_backbone(cfg):
        """Instantiate the backbone named by ``cfg['name']``, optionally pretrained."""
        name = cfg['name']
        if name == 'mobilenet0.25':
            backbone = MobileNetV1()
            if cfg['pretrain']:
                checkpoint = torch.load("./weights/mobilenetV1X0.25_pretrain.tar",
                                        map_location=torch.device('cpu'))
                # Checkpoint keys carry a "module." prefix (presumably saved
                # from an nn.DataParallel wrapper). Strip it only where present
                # instead of blindly dropping the first 7 characters.
                prefix = "module."
                new_state_dict = OrderedDict(
                    (k[len(prefix):] if k.startswith(prefix) else k, v)
                    for k, v in checkpoint['state_dict'].items()
                )
                backbone.load_state_dict(new_state_dict)
            return backbone
        if name == 'Resnet50':
            import torchvision.models as models
            return models.resnet50(pretrained=cfg['pretrain'])
        # Previously fell through with backbone=None and failed obscurely
        # inside IntermediateLayerGetter.
        raise ValueError(
            f"Unsupported backbone {name!r}; expected 'mobilenet0.25' or 'Resnet50'"
        )

    def _make_class_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        """One ClassHead per pyramid level."""
        return nn.ModuleList([ClassHead(inchannels, anchor_num) for _ in range(fpn_num)])

    def _make_bbox_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        """One BboxHead per pyramid level."""
        return nn.ModuleList([BboxHead(inchannels, anchor_num) for _ in range(fpn_num)])

    def _make_landmark_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        """One LandmarkHead per pyramid level."""
        return nn.ModuleList([LandmarkHead(inchannels, anchor_num) for _ in range(fpn_num)])

    def forward(self, inputs):
        """Run detection on a batch.

        Returns:
            (bbox_regressions, classifications, ldm_regressions), each the
            per-level head outputs concatenated along dim 1. In any phase other
            than 'train', classifications are softmax-normalized over dim -1.
        """
        out = self.body(inputs)

        # FPN
        fpn = self.fpn(out)

        # SSH: one context module per pyramid level.
        feature1 = self.ssh1(fpn[0])
        feature2 = self.ssh2(fpn[1])
        feature3 = self.ssh3(fpn[2])
        features = [feature1, feature2, feature3]

        bbox_regressions = torch.cat(
            [self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1)
        classifications = torch.cat(
            [self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1)
        ldm_regressions = torch.cat(
            [self.LandmarkHead[i](feature) for i, feature in enumerate(features)], dim=1)

        if self.phase == 'train':
            output = (bbox_regressions, classifications, ldm_regressions)
        else:
            output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions)
        return output

net.py:

import time
import torch
import torch.nn as nn
import torchvision.models._utils as _utils
import torchvision.models as models
import torch.nn.functional as F
from torch.autograd import Variable
from collections.abc import Mapping


def conv_bn(inp, oup, stride=1, leaky=0):
    """3x3 conv (padding 1, no bias) -> BatchNorm -> LeakyReLU.

    ``leaky`` is the LeakyReLU negative slope; 0 makes it a plain ReLU.
    """
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    )


def conv_bn_no_relu(inp, oup, stride):
    """3x3 conv (padding 1, no bias) -> BatchNorm, with no activation."""
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
    )


def conv_bn1X1(inp, oup, stride, leaky=0):
    """1x1 conv (no padding, no bias) -> BatchNorm -> LeakyReLU(``leaky``)."""
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    )


def conv_dw(inp, oup, stride, leaky=0.1):
    """Depthwise-separable block: 3x3 depthwise conv then 1x1 pointwise conv.

    Each conv is followed by BatchNorm and LeakyReLU(``leaky``).
    """
    return nn.Sequential(
        nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    )


class SSH(nn.Module):
    """SSH context module with three parallel branches.

    The branches stack one, two and three 3x3 convs respectively (the last
    two share their first conv, ``conv5X5_1``), giving 3x3 / 5x5 / 7x7
    receptive fields. Their outputs have ``out_channel // 2``,
    ``out_channel // 4`` and ``out_channel // 4`` channels, are concatenated
    along dim 1 and passed through ReLU; spatial size is preserved.

    Raises:
        ValueError: if ``out_channel`` is not divisible by 4.
    """

    def __init__(self, in_channel, out_channel):
        super(SSH, self).__init__()
        # Explicit check rather than ``assert``, which ``python -O`` strips.
        if out_channel % 4 != 0:
            raise ValueError(
                f"SSH out_channel must be divisible by 4, got {out_channel}")
        leaky = 0.1 if out_channel <= 64 else 0

        # NOTE: attribute names (including the lower-case ``conv7x7_3``) are
        # state_dict keys -- keep them so pretrained checkpoints still load.
        self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1)

        self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky)
        self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)

        self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky)
        self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)

    def forward(self, input):
        conv3X3 = self.conv3X3(input)

        conv5X5_1 = self.conv5X5_1(input)
        conv5X5 = self.conv5X5_2(conv5X5_1)

        # The 7x7 branch reuses the first 5x5 conv's output.
        conv7X7_2 = self.conv7X7_2(conv5X5_1)
        conv7X7 = self.conv7x7_3(conv7X7_2)

        out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
        return F.relu(out)


class FPN(nn.Module):
    """Three-level feature pyramid: 1x1 lateral convs plus a top-down merge.

    Args:
        in_channels_list: channel counts of the three input feature maps,
            presumably ordered from highest to lowest resolution (the
            top-down path resizes level 3 to level 2 and level 2 to level 1).
        out_channels: channel count of every output level.
    """

    def __init__(self, in_channels_list, out_channels):
        super(FPN, self).__init__()
        leaky = 0.1 if out_channels <= 64 else 0
        self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky)
        self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky)
        self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky)

        self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
        self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)

    def forward(self, input):
        """Fuse three feature maps top-down.

        Args:
            input: a mapping (e.g. the OrderedDict produced by
                IntermediateLayerGetter) whose values, in iteration order, are
                the three feature maps; a plain sequence of three tensors is
                also accepted.

        Returns:
            [output1, output2, output3], each with ``out_channels`` channels
            and the spatial size of the corresponding input map.
        """
        if isinstance(input, Mapping):
            feats = list(input.values())
        else:
            feats = list(input)

        output1 = self.output1(feats[0])
        output2 = self.output2(feats[1])
        output3 = self.output3(feats[2])

        # Top-down: nearest-upsample the coarser level, add, then smooth.
        up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode="nearest")
        output2 = self.merge2(output2 + up3)

        up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode="nearest")
        output1 = self.merge1(output1 + up2)

        return [output1, output2, output3]


class MobileNetV1(nn.Module):
    """MobileNetV1-style backbone (channels 8 -> 256) with a 1000-way classifier.

    ``stage1``..``stage3`` are the feature stages; ``avg`` + ``fc`` form the
    classification head. The trailing comments give the receptive field (in
    input pixels) after each layer.
    """

    def __init__(self):
        super(MobileNetV1, self).__init__()
        self.stage1 = nn.Sequential(
            conv_bn(3, 8, 2, leaky=0.1),  # 3
            conv_dw(8, 16, 1),    # 7
            conv_dw(16, 32, 2),   # 11
            conv_dw(32, 32, 1),   # 19
            conv_dw(32, 64, 2),   # 27
            conv_dw(64, 64, 1),   # 43
        )
        self.stage2 = nn.Sequential(
            conv_dw(64, 128, 2),   # 43 + 16 = 59
            conv_dw(128, 128, 1),  # 59 + 32 = 91
            conv_dw(128, 128, 1),  # 91 + 32 = 123
            conv_dw(128, 128, 1),  # 123 + 32 = 155
            conv_dw(128, 128, 1),  # 155 + 32 = 187
            conv_dw(128, 128, 1),  # 187 + 32 = 219
        )
        self.stage3 = nn.Sequential(
            conv_dw(128, 256, 2),  # 219 + 32 = 251
            conv_dw(256, 256, 1),  # 251 + 64 = 315
        )
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, 1000)

    def forward(self, x):
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.avg(x)
        # After global pooling each sample is (256, 1, 1); flatten to (N, 256).
        x = x.view(-1, 256)
        return self.fc(x)

三、结尾

上面就是根据代码整理的网络结构记录。由于笔者对网络细节及其优劣势了解有限,如有不当之处,欢迎指出。

下面两篇博客分别详细介绍了FPN和SSH,上文也参考了其中的内容,有兴趣的读者可以看一下。
FPN网络结构
SSH网络结构