当前位置: 代码迷 >> 综合 >> 【论文复现】DenseNet(2018)
  详细解决方案

【论文复现】DenseNet(2018)

热度:65   发布时间:2023-12-14 09:36:41.0

在这里插入图片描述

论文: https://arxiv.org/pdf/1608.06993.pdf.
PyTorch实现代码:github链接.

一、背景(动机)

随着卷积神经网络变得越来越深,一个新的研究问题出现了:当输入信息(梯度信息)经过许多层之后,在它到达网络末尾(开端)时,它会消失和“洗净”。这种现象就是我们常说的梯度消失或者说是梯度弥散问题。

对神经网络结构的探索一直是神经网络研究的重要组成部分。近年来也取得了很大的进展,比如对网络支路的探索如ResNet、Highway Networks、Stochastic depth、FractalNets等提供了一个解决思路:那就是建立一个从early layers到later layers的短路径。且被实验证明了这种思路是可行的,有效的。基于这些研究,DenseNet作者从中得到了启发。

DenseNet的思路:既然这种创建从早期层到后期层的短路径(本质是层与层之间的信息流通或者说是特征重用)的方法有用,那我就让各层之间传递信息的能力最大化。即:每一层从前面所有层获得额外的输入,本身又作为这层后面所有层的输入,这样就将所有层直接相互连接起来。如下图:
在这里插入图片描述

二、网络架构

CNN网络一般要经过Pooling或者stride>1的Conv来降低特征图的大小,而DenseNet的密集连接方式需要特征图大小保持一致。为了解决这个问题,DenseNet网络中使用DenseBlock+Transition的结构,其中DenseBlock是包含很多层的模块,每个层的特征图大小相同,层与层之间采用密集连接方式。而Transition模块是连接两个相邻的DenseBlock,并且通过Pooling使特征图大小降低。

DenseNet的网络就是由一系列的dense block和连接这些block所用的transition layer 组成。如下图:
在这里插入图片描述
下面具体介绍网络的具体实现细节。

2.1、DenseBlock

在DenseBlock中,各个层的特征图大小一致,可以在channel维度上进行concat连接。DenseBlock中的非线性组合函数H(?)H (?)H(?)采用的是一系列的特征操作。

需要注意的是Dense Block采用了激活函数在前、卷积层在后的顺序,即BN-ReLU-Conv的顺序,这种方式也被称为pre-activation。通常的模型relu等激活函数处于卷积conv、批归一化batchnorm之后,即Conv-BN-ReLU,也被称为post-activation。作者证明,如果采用post-activation设计,性能会变差。

比如在DenseNet中就是采用的是 BN + ReLU + Conv(3x3) + Dropout(0.2) 的结构。如下图:
在这里插入图片描述

另外值得注意的一点是,与ResNet不同,所有DenseBlock中各个层卷积均采用 k 个卷积核,所有每个卷积层输出的feature map的channel都为k。k在DenseNet称为growth rate,这是一个超参数。一般情况下使用较小的 k(论文里k=32),就可以得到较佳的性能。假定输入层的特征图的channel数为 k0k_0k0?,那么第 lll 层输入的channel数为 k0+k(l?1)k_0+k(l-1)k0?+k(l?1) ,因此随着层数增加,尽管 kkk 设定得较小,DenseBlock的输入也会非常多,不过这是由于特征重用所造成的,每个层仅有 kkk 个特征是自己独有的。所以,一个L层的DenseNet网络一共有L(L+1)2\frac{L(L+1)}{2}2L(L+1)? 个concat连接。
?
由于后面层的输入会非常大,DenseBlock内部可以采用bottleneck层来减少计算量,主要是在原有的结构中增加1x1 Conv,即 BN + ReLU + Conv(1x1)(filternum=4k) + Dropout(0.2) + BN + ReLU + Conv(3x3)+ Dropout(0.2) ,称为DenseNet-B结构。其中 1x1 Conv得到 4k 个特征图它起到的作用是降低特征数量,从而提升计算效率。如下图:
在这里插入图片描述
在这里插入图片描述

还有一个DenseBlock-BC就不多做介绍了,和DenseBlock-B的设计是一样的思路。结构:BN + ReLU + Conv(1x1)(filternum=4k) + Dropout(0.2) + BN + ReLU + Conv(3x3)+ Dropout(0.2)

2.2、Transition

对于Transition层,它主要是连接两个相邻的DenseBlock,并且降低特征图大小。Transition层包括一个1x1的卷积和2x2的AvgPooling,DenseNet和DenseNet-B的Transition结构都为 BN + ReLU + Conv(1x1)(filternum=m) + Dropout(0.2) + Pooling(2x2)。Transition层还可以起到压缩模型的作用。假定Transition的上接DenseBlock得到的特征图channels数为 m ,Transition层可以产生? θ m ? 个特征(通过卷积层),其中θ ∈ ( 0 , 1 ] 是压缩系数(compression rate)。当 θ = 1时,特征个数经过Transition层没有变化,即无压缩。而当压缩系数小于1时,这种结构称为DenseNet-C,paper中的 θ = 0.5,即:BN + ReLU + Conv(1x1)( filternum=θ\thetaθm, 其中0<θ<10<\theta<10<θ<1, paper中 θ=0.5\theta=0.5θ=0.5) + Dropout(0.2) + Pooling(2x2)。对于使用bottleneck层的DenseBlock结构和压缩系数小于1的Transition组合结构称为DenseNet-BC。

2.3、总结

paper中提出了三种DenseNet网络架构,分别是DenseNet、DenseNet-B、DenseNet-BC。

  1. DenseNet:
    dense block: BN + ReLU + Conv(3x3) + Dropout(0.2)
    transition: BN + ReLU + Conv(1x1)(filternum=m) + Dropout(0.2) + Pooling(2x2)
  2. DenseNet-B
    dense block: BN + ReLU + Conv(1x1)(filternum=4k) + BN + ReLU + Conv(3x3)+ Dropout(0.2)
    transition: BN + ReLU + Conv(1x1)(filternum=m) + Dropout(0.2) + Pooling(2x2)
  3. DenseNet-BC
    dense block: BN + ReLU + Conv(1x1)(filternum=4k) + BN + ReLU + Conv(3x3)+ Dropout(0.2)
    transition: BN + ReLU + Conv(1x1)(filternum=θ\thetaθm, 其中0<θ<10<\theta<10<θ<1, paper中 θ=0.5\theta=0.5θ=0.5) + Dropout(0.2) + Pooling(2x2)

如前所示,DenseNet的网络结构主要由DenseBlock和Transition组成,结构图如下:
在这里插入图片描述
论文里整体的网络配置如下表:
在这里插入图片描述

三、ResNet与DenseNet的比较

DenseNet 与 ResNet的区别主要是feature map融合的方式不同。

  1. ResNet中feature map融合的方式是Sum,即要求要进行特征融合所有feature map在C, H, W三个维度完全相同,然后所有feature map在x相应位置上直接元素相加,相加后新生成的feature map的shape和原来的shape相同。用公式可表示为:
    Xl=Hl(Xl?1)+Xl?1X_l = H_l(X_{l-1})+X_{l-1}Xl?=Hl?(Xl?1?)+Xl?1?
    其中Xl?1X_{l-1}Xl?1?表示第 l?1l-1l?1 层的输出;Hl(Xl?1)H_l(X_{l-1})Hl?(Xl?1?)表示第lll层的一系列非线性变换(通常为BN、ReLU、Conv、Pooling等组合而成);XlX_{l}Xl?表示第 lll 层的输出。具体结构如下图5:
    在这里插入图片描述

思考: ResNet虽然可以将信息从前面层传到后面层,但是ResNet直接元素相加的这种方式,可能会阻碍了信息在网络中的流动能力。所有我们采用下面的这种Concat方式。

  1. DenseNet中feature map融合的方式是Concat,即要求要进行特征融合所有feature map在 H, W两个维度完全相同,所有feature map在channel维度上进行叠加,叠加之后新生成的feature map的channel为所有feature map的channel之和,H、W维度不变。用公式可表示为:
    Xl=Hl([X0,X1,…,Xl?1])X_l=H_l([X_0, X_1,…,X_{l-1}])Xl?=Hl?([X0?,X1?,Xl?1?])
    其中X0X_0X0?为模型的输入;X1,…,Xl?1X_1,…,X_{l-1}X1?Xl?1?是各层的输出;[X0,X1,…,Xl?1][X_0, X_1,…,X_{l-1}][X0?,X1?,Xl?1?]就是上面说的concat操作得到的新的feature map;HlH_lHl?为第lll层的非线性操作。具体结构如下图6:

在这里插入图片描述

四、优缺点

优点:

  1. 看网络结构好像网络的参数很多,但是实际上网络参数更少,提高训练效率
  2. 增强了特征在各个层之间的流动(特征重用),因为每一层都与初始输入层还有最后的由loss function得到的梯度直接相连
  3. 进行反向传播的时候,减轻了梯度弥散的问题,使模型不容易过拟合
  4. 由于每层可以直达最后的误差信号,实现了隐式的“deep supervision”。误差信号可以很容易地传播到较早的层,所以较早的层可以从最终分类层获得直接监管(监督)。

缺点:

  1. DenseNet的不足在于由于需要进行多次Concatnate操作,数据需要被复制多次,显存容易增加得很快,需要一定的显存优化技术。有一种更高效的实现,感兴趣的可以参考: Memory-Efficient Implementation of DenseNets.
  2. 另外,DenseNet是一种更为特殊的网络,ResNet则相对一般化一些,因此ResNet的应用范围更广泛。

五、PyTorch实现

这里实现的代码是我从PyTorch官方下载的代码,用到的DenseNet-BC结构,我对上面的网络结构细化了一下如下表:
在这里插入图片描述

注意:每个DenseBlock的输出shape的计算很重要,理解这部分代码是整个DenseNet最精华的部分。具体代码在下面的两个cat操作。

DenseNet-BC代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from tensorboard import summary
from torch import Tensor
from typing import Any, List, Tuple
from collections import OrderedDictclass _DenseLayer(nn.Module):"""DenseBlock中的内部结构 DenseLayer: BN + ReLU + Conv(1x1) + BN + ReLU + Conv(3x3)"""def __init__(self,num_input_features: int,growth_rate: int,bn_size: int,drop_rate: float,memory_efficient: bool = False):""":param num_input_features: 输入channel:param growth_rate: 论文中的 k = 32:param bn_size: 1x1卷积的filternum = bn_size * k 通常bn_size=4:param drop_rate: dropout 失活率:param memory_efficient: Memory-efficient版的densenet 默认是不使用的"""super(_DenseLayer, self).__init__()self.add_module("norm1", nn.BatchNorm2d(num_input_features))self.add_module("relu1", nn.ReLU(inplace=True))self.add_module("conv1", nn.Conv2d(in_channels=num_input_features,out_channels=bn_size * growth_rate,kernel_size=1,stride=1,bias=False))self.add_module("norm2", nn.BatchNorm2d(bn_size * growth_rate))self.add_module("relu2", nn.ReLU(inplace=True))self.add_module("conv2", nn.Conv2d(bn_size * growth_rate,growth_rate,kernel_size=3,stride=1,padding=1,bias=False))self.drop_rate = drop_rateself.memory_efficient = memory_efficientdef bn_function(self, inputs: List[Tensor]) -> Tensor: # inputs: [16,64,56,56](输入) [16,32,56,56] [16,32,56,56]# 以DenseBlock=1为例 每一次都会叠加上一个DenseLayer的输出(channel=32) 64+32*5=224# [16,64,56,56] [16,96,56,56] [16,128,56,56] [16,160,56,56] [16,192,56,56] [16,224,56,56]# 注意:这个concat和之后的DenseBlock中的concat非常重要,理解这两句就能理解DenseNet中密集连接的精髓concat_features = torch.cat(inputs, 1) # 以DenseBlock=1为例 bottleneck_output: 一直都是生成[16,128,56,56]bottleneck_output = self.conv1(self.relu1(self.norm1(concat_features)))return bottleneck_output@staticmethoddef any_requires_grad(inputs: List[Tensor]) -> bool:"""判断是否需要更新梯度(training)"""for tensor in inputs:if tensor.requires_grad:return Truereturn False@torch.jit.unuseddef call_checkpoint_bottleneck(self, inputs: List[Tensor]) -> Tensor:"""torch.utils.checkpoint: 用计算换内存(节省内存)。 详情可看: https://arxiv.org/abs/1707.06990torch.utils.checkpoint并不保存中间激活值,而是在反向传播时重新计算它们。 它可以应用于模型的任何部分。具体而言,在前向传递中,function将以torch.no_grad()的方式运行,即不存储中间激活值 相反,前向传递将保存输入元组和function参数。在反向传播时,检索保存的输入和function参数,然后再次对函数进行正向计算,现在跟踪中间激活值,然后使用这些激活值计算梯度。"""def closure(*inp):return self.bn_function(inp)return cp.checkpoint(closure, *inputs)def forward(self, inputs: Tensor) -> Tensor:if isinstance(inputs, Tensor):  # 确保inputs的格式满足要求prev_features = [inputs]else:prev_features = inputs# 判断是否使用memory_efficient的densenet and 是否需要更新梯度(training)# torch.utils.checkpoint不适用于torch.autograd.grad(),而仅适用于torch.autograd.backward()if self.memory_efficient and self.any_requires_grad(prev_features):# torch.jit 模式下不合适用memory_efficientif torch.jit.is_scripting():raise Exception("memory efficient not supported in JIT")# 调用efficient densenet 思路:用计算换显存bottleneck_output = self.call_checkpoint_bottleneck(prev_features)else:# 调用普通的densenet# 以DenseBlock=1为例 bottleneck_output: 一直都是生成[16,128,56,56]bottleneck_output = self.bn_function(prev_features)# 以DenseBlock=1为例 new_features: 一直都是生成[16,32,56,56]new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))if self.drop_rate > 0:new_features = F.dropout(new_features,p=self.drop_rate,training=self.training)return new_featuresclass _DenseBlock(nn.ModuleDict):_version = 2def __init__(self,num_layers: int,num_input_features: int,bn_size: int,growth_rate: int,drop_rate: float,memory_efficient: bool = False):""":param num_layers: 该DenseBlock中DenseLayer的个数:param num_input_features: 该DenseBlock的输入Channel,每经过一个DenseBlock都会进行叠加叠加方式:num_features = num_features + num_layers * growth_rate:param bn_size: 1x1卷积的filternum = bn_size*k 通常bn_size=4:param growth_rate: 指的是论文中的k 小点比较好 论文中是32:param drop_rate: dropout rate after each dense layer:param memory_efficient: If True, uses checkpointing. Much more memory efficient"""super(_DenseBlock, self).__init__()for i in range(num_layers):layer = _DenseLayer(num_input_features + i * growth_rate,growth_rate=growth_rate,bn_size=bn_size,drop_rate=drop_rate,memory_efficient=memory_efficient)self.add_module("denselayer%d" % (i + 1), layer)def forward(self, init_features: Tensor) -> Tensor:# 以DenseBlock1 为例子:features=7个List [16,64,56,56](输入) [16,32,56,56]x6# 每一个[16,32,56,56]都是用当前DenseBlock的之前所有的信息concat之后得到的features = [init_features]for name, layer in self.items():  # 遍历该DenseBlock的所有DenseLayernew_features = layer(features)features.append(new_features)  # 将该层从输出append进new_features,传给下一个DenseBlock,在bn_function中append# 获取之前所有层(DenseLayer)的信息 并传给下一个Transition层 # 以DenseBlock=1为例 传给Transition数据为 [16,64+32*6,56,56]return torch.cat(features, 1)   class _Transition(nn.Sequential):def __init__(self,num_input_features: int,num_output_features: int):super(_Transition, self).__init__()self.add_module("norm", nn.BatchNorm2d(num_input_features))self.add_module("relu", nn.ReLU(inplace=True))self.add_module("conv", nn.Conv2d(num_input_features,num_output_features,kernel_size=1,stride=1,bias=False))self.add_module("pool", nn.AvgPool2d(kernel_size=2, stride=2))class DenseNet(nn.Module):"""Densenet-BC"""def __init__(self,growth_rate: int = 32,block_config: Tuple[int, int, int, int] = (6, 12, 24, 16),num_init_features: int = 64,bn_size: int = 4,drop_rate: float = 0,num_classes: int = 1000,memory_efficient: bool = False):""":param growth_rate: 指的是论文中的k 小点比较好 论文中是32:param block_config: 每一个DenseBlock中_DenseLayer的个数:param num_init_features: 整个网络第一个卷积(Conv0)的kernel_size = 64:param bn_size: 1x1卷积的filternum = bn_size*k 通常bn_size=4:param drop_rate: dropout rate after each dense layer 一般为0 不用的:param num_classes: 数据集类别数:param memory_efficient: If True, uses checkpointing. Much more memory efficient"""super(DenseNet, self).__init__()# first conv0+bn0+relu0+pool0self.features = nn.Sequential(OrderedDict([("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),("norm0", nn.BatchNorm2d(num_init_features)),("relu0", nn.ReLU(inplace=True)),("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),]))# num_features:DenseBlock的输入Channel,每经过一个DenseBlock都会进行叠加# 叠加方式:num_features = num_features + num_layers * growth_ratenum_features = num_init_features# each dense blockfor i, num_layers in enumerate(block_config):block = _DenseBlock(num_layers=num_layers,num_input_features=num_features,bn_size=bn_size,growth_rate=growth_rate,drop_rate=drop_rate,memory_efficient=memory_efficient)self.features.add_module("denseblock%d" % (i + 1), block)num_features = num_features + num_layers * growth_rate# Transition个数 = DenseBlock个数-1 DenseBlock4后面是没有Transition的# 经过Transition channel直接 // 2if i != len(block_config) - 1:trans = _Transition(num_input_features=num_features,num_output_features=num_features // 2)self.features.add_module("transition%d" % (i + 1), trans)num_features = num_features // 2# finnal batch normself.features.add_module("norm5", nn.BatchNorm2d(num_features))# fc layerself.classifier = nn.Linear(num_features, num_classes)# init weightsfor m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.constant_(m.bias, 0)def forward(self, x: Tensor) -> Tensor:features = self.features(x)out = F.relu(features, inplace=True)out = F.adaptive_avg_pool2d(out, (1, 1))out = torch.flatten(out, 1)out = self.classifier(out)return outdef densenet121(**kwargs: Any) -> DenseNet:# Top-1 error: 25.35%# 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth'return DenseNet(growth_rate=32,  # k=32block_config=(6, 12, 24, 16),  # 每一个 DenseBlock 中 _DenseLayer的个数num_init_features=64,  # 第一个Dense Block的输入channel**kwargs)def densenet169(**kwargs: Any) -> DenseNet:# Top-1 error: 24.00%# 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth'return DenseNet(growth_rate=32,block_config=(6, 12, 32, 32),num_init_features=64,**kwargs)def densenet201(**kwargs: Any) -> DenseNet:# Top-1 error: 22.80%# 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth'return DenseNet(growth_rate=32,block_config=(6, 12, 48, 32),num_init_features=64,**kwargs)def densenet161(**kwargs: Any) -> DenseNet:# Top-1 error: 22.35%# 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth'return DenseNet(growth_rate=48,block_config=(6, 12, 36, 24),num_init_features=96,**kwargs)if __name__ == '__main__':"""测试模型"""device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = densenet121()# summary(model, (3, 224, 224))model_pre_weight_path = "./weights/densenet121_pre.pth"model.load_state_dict(torch.load(model_pre_weight_path, map_location=device), strict=False)in_channel = model.classifier.in_features  # 得到fc层的输入channelmodel.classifier = nn.Linear(in_channel, 5)print(model)

1、代码中DenseBlock的一个append和两个concat操作是整个程序的核心操作,充分理解了这三句代码。才能理解什么是密集连接(DenseNet)。
2、这个官方的代码并没有使用DropOut,Transition中根本没写,DenseBlock是写了也没调用。

六、实验结果

论文中在cifar10 和 cifar100 上的实验结果:
在这里插入图片描述

总结

感觉这篇文章的精髓就在于特征重用,使得特征可以在网络中持续传播,可能网络因此就学比较好。而且所提出的网络连接方式(concat)也有助于缓解梯度消失的问题。而引入的瓶颈层结合增长率又控制了网络的宽度,可以极大地降低网络参数的总量,进而能以较少的计算代价就达到当前最好的水平。最后,由于 DenseNet 的特征重用使得网络能够更紧凑地学习,加上特征冗余的降低,所以它可能会是个很好的视觉任务特征抽取器。

Reference

链接1: link.

链接2: Memory-Efficient Implementation of DenseNets csdn.
链接3: Pytorch checkpoint.

https://arxiv.org/pdf/1608.06993.pdf

  相关解决方案