DenseNet模型原理及代码
1. DenseNet模型原理
- CNN史上的一个里程碑事件是ResNet模型的出现,ResNet可以训练出更深的CNN模型,从而实现更高的准确度。ResNet模型的核心是通过建立前面层与后面层之间的“短路连接”(shortcut connection),这有助于训练过程中梯度的反向传播,从而能训练出更深的CNN网络。今天我们要介绍的是DenseNet模型,它的基本思路与ResNet一致,但是它建立的是前面所有层与后面层的密集/稠密连接(dense connection),它的名称也是由此而来。DenseNet的另一大特色是通过特征在channel上的连接来实现特征重用(feature reuse)。这些特点让DenseNet在参数和计算成本更少的情形下实现比ResNet更优的性能,DenseNet也因此斩获CVPR 2017的最佳论文奖。
DenseNet 的想法很大程度上源于我们去年发表在 ECCV 上的一个叫做随机深度网络(Deep networks with stochastic depth)工作。当时我们提出了一种类似于 Dropout 的方法来改进ResNet。我们发现在训练过程中的每一步都随机地「扔掉」(drop)一些层,可以显著的提高 ResNet 的泛化性能。这个方法的成功至少带给我们两点启发:
* 首先,它说明了神经网络其实并不一定要是一个递进层级结构,也就是说网络中的某一层可以不仅仅依赖于紧邻的上一层的特征,而可以依赖于更前面层学习的特征。想像一下在随机深度网络中,当第 l 层被扔掉之后,第 l+1 层就被直接连到了第 l-1 层;当第 2 到了第 l 层都被扔掉之后,第 l+1 层就直接用到了第 1 层的特征。因此,随机深度网络其实可以看成一个具有随机密集连接的 DenseNet。
* 其次,我们在训练的过程中随机扔掉很多层也不会破坏算法的收敛,说明了 ResNet 具有比较明显的冗余性,网络中的每一层都只提取了很少的特征(即所谓的残差)。实际上,我们将训练好的 ResNet 随机的去掉几层,对网络的预测结果也不会产生太大的影响。既然每一层学习的特征这么少,能不能降低它的计算量来减小冗余呢?
DenseNet 的设计正是基于以上两点观察。我们让网络中的每一层都直接与其前面层相连,实现特征的重复利用;同时把网络的每一层设计得特别「窄」,即只学习非常少的特征图(最极端情况就是每一层只学习一个特征图),达到降低冗余性的目的。这两点也是 DenseNet 与其他网络最主要的不同。需要强调的是,第一点是第二点的前提,没有密集连接,我们是不可能把网络设计得太窄的,否则训练会出现欠拟合(under-fitting)现象,即使 ResNet 也是如此。
-
相比ResNet,DenseNet提出了一个更激进的密集连接机制:即互相连接所有的层,具体来说就是每个层都会接受其前面所有层作为其额外的输入。下图为plain network、ResNet和DenseNet的对比。可以看到,ResNet是每个层与前面的某层(一般是2~3层)短路连接在一起,连接方式是通过元素级相加。而在DenseNet中,每个层都会与前面所有层在channel维度上连接(concat)在一起(这里各个层的特征图大小是相同的,后面会有说明),并作为下一层的输入。对于一个层的网络,ResNet共包含个连接,而DenseNet共包含个连接,相比ResNet,这是一种密集连接。而且DenseNet是直接concat来自不同层的特征图,这可以实现特征重用,提升效率,这一特点是DenseNet与ResNet最主要的区别。
如果用公式表示的话,plain network在层的输出为 ;在ResNet中,增加了来自上一层输入的identity函数 ;而在DenseNet中,会连接前面所有层作为输入 。 -
DenseNet的前向传播方式如下图所示,可以更直观的理解他的密集连接方式。
2. DenseNet具体结构
-
CNN网络一般要经过Pooling或者stride>1的Conv来降低特征图的尺寸,而DenseNet的密集连接方式需要特征图大小保持一致。为了解决这个问题,DenseNet网络中使用DenseBlock+Transition的结构,其中DenseBlock是包含很多层的模块,每个层的特征图大小相同,层与层之间采用密集连接方式。而Transition模块是连接两个相邻的DenseBlock,并且通过Pooling使特征图大小降低。下图给出了DenseNet的网络结构,它共包含3个DenseBlock,各个DenseBlock之间通过Transition连接在一起。
-
在DenseBlock中,各个层的特征图大小一致,可以在channel维度上连接。值得注意的一点是,与ResNet不同,所有DenseBlock中各个层卷积之后均输出 个特征图,即得到的特征图的channel数为 ,或者说采用 个卷积核。 在DenseNet称为growth rate,这是一个超参数。假定输入层的特征图的channel数为 ,那么第 层输入的channel数为 ,因此随着层数增加,DenseBlock的输入会非常多,不过这是由于特征重用所造成的,每个层仅有 个特征是自己独有的。
-
由于后面层的输入会非常大,DenseBlock内部可以采用bottleneck层来减少计算量,主要是利用1x1 Conv,如图7所示,即BN+ReLU+1x1Conv+BN+ReLU+3x3Conv(pre-activation),称为DenseNet-B结构。其中1x1 Conv得到 个特征图它起到的作用是降低特征数量,从而提升计算效率
-
对于Transition层,它主要是连接两个相邻的DenseBlock,并且降低特征图大小。Transition层包括一个1x1的卷积和2x2的AvgPooling,结构为BN+ReLU+1x1Conv+2x2 AvgPooling。另外,Transition层可以起到压缩模型的作用。假定Transition的上接DenseBlock得到的特征图channels数为 ,Transition层可以产生 个特征(通过卷积层),其中 是压缩系数(compression rate)。当 时,特征个数经过Transition层没有变化,即无压缩,而当压缩系数小于1时,这种结构称为DenseNet-C,文中使用 。对于使用bottleneck层的DenseBlock结构和压缩系数小于1的Transition组合结构称为DenseNet-BC。
-
对于ImageNet数据集,图片输入大小为 ,网络结构采用包含4个DenseBlock的DenseNet-BC,其首先是一个stride=2的7x7卷积层(卷积核数为 ),注意此处的卷积层的内部排列为正常的Conv+BN+ReLU。然后是一个stride=2的3x3 MaxPooling层,后面才进入DenseBlock。ImageNet数据集所采用的网络配置如下表所示:
-
部分实验结果对比如下图
-
DenseNet的不足
3. DenseNet实现代码
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class Bottleneck(nn.Module):
def __init__(self, in_planes, growth_rate):
super(Bottleneck, self).__init__()
self.bn1 = nn.BatchNorm2d(in_planes)
self.relu1 = nn.ReLU(inplace=True)
self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1)
self.bn2 = nn.BatchNorm2d(4*growth_rate)
self.relu2 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1)
def forward(self, x):
out = self.conv1(self.relu1(self.bn1(x)))
out = self.conv2(self.relu2(self.bn2(out)))
out = torch.cat([out,x], 1)
return out
class Transition(nn.Module):
def __init__(self, in_planes, out_planes):
super(Transition, self).__init__()
self.bn = nn.BatchNorm2d(in_planes)
self.relu = nn.ReLU(inplace=True)
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1)
def forward(self, x):
out = self.conv(self.relu(self.bn(x)))
out = F.avg_pool2d(out, 2)
return out
class DenseNet(nn.Module):
def __init__(self, nblocks, growth_rate, reduction, num_classes):
super(DenseNet, self).__init__()
self.growth_rate = growth_rate
num_planes = 2 * growth_rate
self.basic_conv = nn.Sequential(
nn.Conv2d(3, num_planes, kernel_size=7, stride=2, padding=3),
nn.BatchNorm2d(num_planes),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
self.dense1 = self._make_dense_layers(num_planes, nblocks[0])
num_planes += nblocks[0] * growth_rate
out_planes = int(math.floor(num_planes * reduction))
self.trans1 = Transition(num_planes, out_planes)
num_planes = out_planes
self.dense2 = self._make_dense_layers(num_planes, nblocks[1])
num_planes += nblocks[1] * growth_rate
out_planes = int(math.floor(num_planes * reduction))
self.trans2 = Transition(num_planes, out_planes)
num_planes = out_planes
self.dense3 = self._make_dense_layers(num_planes, nblocks[2])
num_planes += nblocks[2] * growth_rate
out_planes = int(math.floor(num_planes * reduction))
self.trans3 = Transition(num_planes, out_planes)
num_planes = out_planes
self.dense4 = self._make_dense_layers(num_planes, nblocks[3])
num_planes += nblocks[3] * growth_rate
self.gap = nn.AdaptiveAvgPool2d(1)
self.linear = nn.Linear(num_planes, num_classes)
def _make_dense_layers(self, in_planes, nblock):
layers = []
for i in range(nblock):
layers.append(Bottleneck(in_planes, self.growth_rate))
in_planes += self.growth_rate
return nn.Sequential(*layers)
def forward(self, x):
out = self.basic_conv(x)
out = self.trans1(self.dense1(out))
out = self.trans2(self.dense2(out))
out = self.trans3(self.dense3(out))
out = self.dense4(out)
out = self.gap(out)
out = out.view(out.size(0), -1)
out = self.linear(out)
return out
def DenseNet121():
return DenseNet([6,12,24,16], growth_rate=32, reduction=0.5, num_classes=1000)
def DenseNet169():
return DenseNet([6,12,32,32], growth_rate=32, reduction=0.5, num_classes=1000)
def DenseNet201():
return DenseNet([6,12,48,32], growth_rate=32, reduction=0.5, num_classes=1000)
def DenseNet265():
return DenseNet([6,12,64,48], growth_rate=32, reduction=0.5, num_classes=1000)
net = DenseNet121()
x = torch.randn(1,3,224,224)
y = net(x)
print(y.size())
上一篇: Java集合类汇总
下一篇: 对Slim 框架进行总结