Stacked Hourglass Network(ECCV'16) 论文笔记+pytorch代码阅读
pytorch 代码:https://github.com/princeton-vl/pytorch_stacked_hourglass
论文原理
Summary
论文设计了一个新的框架Stacked Hourglass Network,通过提取和融合多尺度特征,来更好的捕获人体关键点的各种空间关系(spatial relationship)。
Motivation
Hourglass模块设计的初衷就是为了捕捉每个尺度下的信息,因为捕捉像脸,手这些部分的时候需要局部的特征,而最后对人体姿态进行预测的时候又需要整体的信息。(感受野较大的feature map可以捕捉到更高阶的特征和全局上下文)
-
Hourglass Layer:The network captures and consolidates information across all scales of the image. 通过一步步的下采样获取不同尺度的feature map,再通过上采样和skip layers对不同尺度下的特征进行融合;与其他结构相比有更加对称的结构
-
multiple iterative stages:Hourglass层可以在局部和全局上下文信息中提取feature,并生成预测。之后用多个Hourglass迭代,可以对高阶特征进行多次处理,来进行进一步评估,并重新估计高阶的空间关系(spatial relationship)
-
Intermediate Supervision:在每个stage都计算loss, 避免深度多阶段网络中常见的梯度消失问题。
Model
本论文网络的设计深刻的体现出从 Block 到 Layer 最后形成网络的方法。
Resiual Model
论文使用Residual模块来提取特征,本文中残差模块不改变输入的高度和高度,仅改变Channel,而且保证输出Channel始终为256。
实际上残差模块是BottleNeck。作者进行一系列的尝试,从卷积核较大的标准卷积到一些新方法,像Residual,Inception模块。
Hourglass
Hourglass模块由上下两路构成,下路–获得较小尺度的特征
上路:利用Residual提取原尺度的特征,与下路相加后便得到融合的多尺度的特征
上采样:最近邻插值
下路:获得较小尺度的特征
下采样:max pooling (2*2)
第一个Residual:用来获取下采样之后的特征
第二个Residual:可以换成Hourglass层,得到高阶Hourglass
第三个Residual:我猜测是用来保证对称结构,hhhh
多阶Hourglass
原文使用了8个4阶Hourglass。。。
输入的feature map,经过4次下采样后获得的feature map,然后进行4次上采样。在整个过程中,Channel=256。
完整网络
Stacked Hourglass:最终一共使用了8个沙漏网络。每个沙漏网络的输入都为 。
输入图片大小为,一开始经过一次 stride=2 的卷积(padding=(kernel_size-1)//2),紧接着跟随一个residual module和 max pooling将像素值从128下降到64(减少hourglass内部计算量)。其中所有的残差模块输出256个特征图。
单个 Hourglass 能够做到提取特征图在不同尺度上的信息,但是仍然不能在预测时(显式地)考虑不同关键点之间的关系。所以才需要进一步设计 Stacked Hourglass。
损失函数
每个关键点的 Ground Truth 定义为以该关键点为峰值位置的 2D 高斯函数。Loss Function 定义为 Ground Truth 与预测得到的 heatmap 之间的均方误差。
Intermediate Supervision
中继监督的思想在更早的网络里就已经被提出了。GoogLeNet V1 就在网络中部和中后部额外设置了全连接层分类支路,和最终的输出一样,这些支路对 Loss 有一定的贡献。当时的想法是,由于深层网络的多次下采样操作,一定尺度上的特征信息会丢失,所以才设计了这些中继监督位点。与 Stacked Hourglass 不同,GoogLeNet 中的这些中继监督位点的输出并没有再返回网络里。
x输入Hourglass后得到(B, 256,64,64)的输出值,通过用于预测的卷积(蓝框部分)改变Channel,使其与GroundTruth一致,即得到预测(B, 16, 64, 64),该预测可以计算Loss。最后将预测、Hourglass的输出值,以及输入值融合\相加(通过卷积保证Channel一致),送到下一个Hourglass。
直观上,将中继监督的输出重新返回网络中,起到了一种“让网络对当前的特征和预测结果进行再评估”的作用。通过重复地将多个 Hourglass 和中继监督串联,网络能够显式地学习每个预测目标之间的关系,越靠后的预测越能够结合所有关键点的位置信息,做出更准确的关键点位置预测。有了这种特性,网络就基本不会预测出一些从解剖学上不成立的人体姿势。
按元素相加虽然直观上看比较奇怪,但实际上,concatenate到一起之后,再通过一次卷积降维,那次卷积的最后其实也是按元素相加的操作。所以这里直接按元素相加,可以当作之前的卷积层已经得到了可相加的特征,这并没有什么不妥。
GroundTruth
每个关键点的 Ground Truth 定义为以该关键点为峰值位置的 2D 高斯函数。
首先介绍MPII数据集。该数据集主要用于单个人的姿态估计,但它确实为同一图像中的多个人提供关节注释。对于每个人,它给出了16个关节的坐标,比如左脚踝或右肩膀。
关于GroundTruth的另一件重要的事情是高斯分布。当我们生成GroundTruth的heatmap时,我们不只是为关节坐标分配1,并为所有其他像素分配0。这将使GroundTruth过于稀疏,难以了解。如果模型预测只差几个像素,也是值得鼓励的。
用高斯函数对关键点处理,使其中心值最大,中心周围区域值逐渐减小。左图是单个关键点的heatmap,右图是把所有16个关节放在一张heatmap中。
预测
与直接回归相比,使用热图的一个缺点是粒度(granularity)。例如,使用输入,我们将得到一个的热图来表示关键点位置。四倍的缩小比例似乎不是很糟糕。然而,我们通常首先将较大的图像(如)调整输入。在这种情况下,64x64的热图太粗糙了。为了缓解这个问题,研究人员提出了一个有趣的想法。我们不只是使用最大值的像素,我们还考虑了相邻的最大值像素。由于某个相邻像素也很高,因此它推断实际的关键点位置可能是朝向相邻像素的方向。听起来很熟悉,对吧?这很像梯度下降法,它也指向最优解。
消融实验
首先设计了几组网络,来讨论中间监督和stacked Hourglass
为了探索stacked Hourglass设计的效果,我们必须证明性能的变化是框架的设计,而不是由于更大、更深的网络。
在图9中比较了2堆叠、4堆叠和8堆叠网络的验证精度,它们有相同的参数,都包含中间预测。
代码
建立模型
基本层
from torch import nn
Pool = nn.MaxPool2d
def batchnorm(x):
return nn.BatchNorm2d(x.size()[1])(x)
class Conv(nn.Module):
"""
卷积层(包含BN和ReLU)
参数:inp_dim 输入Channel
out_dim 输出Channel
"""
def __init__(self, inp_dim, out_dim, kernel_size=3, stride = 1, bn = False, relu = True):
super(Conv, self).__init__()
self.inp_dim = inp_dim
self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding=(kernel_size-1)//2, bias=True)
self.relu = None
self.bn = None
if relu:
self.relu = nn.ReLU()
if bn:
self.bn = nn.BatchNorm2d(out_dim)
def forward(self, x):
assert x.size()[1] == self.inp_dim, "{} {}".format(x.size()[1], self.inp_dim)
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
if self.relu is not None:
x = self.relu(x)
return x
Residual Model
class Residual(nn.Module):
"""
Residual层 (实际上是BottleNeck)
参数:inp_dim 输入Channel
out_dim 输出Channel
"""
def __init__(self, inp_dim, out_dim):
super(Residual, self).__init__()
self.relu = nn.ReLU()
self.bn1 = nn.BatchNorm2d(inp_dim)
self.conv1 = Conv(inp_dim, int(out_dim/2), 1, relu=False)
self.bn2 = nn.BatchNorm2d(int(out_dim/2))
self.conv2 = Conv(int(out_dim/2), int(out_dim/2), 3, relu=False)
self.bn3 = nn.BatchNorm2d(int(out_dim/2))
self.conv3 = Conv(int(out_dim/2), out_dim, 1, relu=False)
self.skip_layer = Conv(inp_dim, out_dim, 1, relu=False)
if inp_dim == out_dim:
self.need_skip = False
else:
self.need_skip = True
def forward(self, x):
if self.need_skip:
residual = self.skip_layer(x)
else:
residual = x
out = self.bn1(x)
out = self.relu(out)
out = self.conv1(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn3(out)
out = self.relu(out)
out = self.conv3(out)
out += residual
return out
Hourglass Model
class Hourglass(nn.Module):
"""
Residual层不改变输入feature的W、H
下采样主要由 pool 完成
"""
def __init__(self, n, f, bn=None, increase=0):
super(Hourglass, self).__init__()
nf = f + increase
self.up1 = Residual(f, f)
# Lower branch
self.pool1 = Pool(2, 2)
self.low1 = Residual(f, nf)
self.n = n
# Recursive hourglass
if self.n > 1:
self.low2 = Hourglass(n-1, nf, bn=bn)
else:
self.low2 = Residual(nf, nf)
self.low3 = Residual(nf, f)
self.up2 = nn.Upsample(scale_factor=2, mode='nearest')
def forward(self, x):
up1 = self.up1(x)
pool1 = self.pool1(x)
low1 = self.low1(pool1)
low2 = self.low2(low1)
low3 = self.low3(low2)
up2 = self.up2(low3)
return up1 + up2
完整网络
class UnFlatten(nn.Module):
def forward(self, input):
return input.view(-1, 256, 4, 4)
class Merge(nn.Module):
def __init__(self, x_dim, y_dim):
super(Merge, self).__init__()
self.conv = Conv(x_dim, y_dim, 1, relu=False, bn=False)
class PoseNet(nn.Module):
"""
'nstack': 8,
'inp_dim': 256,
'oup_dim': 16,
'num_parts': 16,
'increase': 0,
"""
def __init__(self, nstack, inp_dim, oup_dim, bn=False, increase=0, **kwargs):
super(PoseNet, self).__init__()
self.nstack = nstack
self.pre = nn.Sequential(
Conv(3, 64, 7, 2, bn=True, relu=True),
Residual(64, 128),
Pool(2, 2),
Residual(128, 128),
Residual(128, inp_dim)
)
self.hgs = nn.ModuleList( [
nn.Sequential(
Hourglass(4, inp_dim, bn, increase),
) for i in range(nstack)] )
self.features = nn.ModuleList( [
nn.Sequential(
Residual(inp_dim, inp_dim),
Conv(inp_dim, inp_dim, 1, bn=True, relu=True)
) for i in range(nstack)] )
self.outs = nn.ModuleList( [Conv(inp_dim, oup_dim, 1, relu=False, bn=False) for i in range(nstack)] )
self.merge_features = nn.ModuleList( [Merge(inp_dim, inp_dim) for i in range(nstack-1)] )
self.merge_preds = nn.ModuleList( [Merge(oup_dim, inp_dim) for i in range(nstack-1)] )
self.nstack = nstack
self.heatmapLoss = HeatmapLoss()
def forward(self, imgs):
## our posenet
x = imgs.permute(0, 3, 1, 2) #x of size 1,3,inpdim,inpdim
x = self.pre(x) # x :(B, 256, 64, 64)
combined_hm_preds = [] # (i, B, 16, 64, 64 )
for i in range(self.nstack):
hg = self.hgs[i](x) #(B, 256, 64, 64)
feature = self.features[i](hg) #(B, 256, 64, 64)
preds = self.outs[i](feature) #(B, 256, 16, 16)
combined_hm_preds.append(preds)
if i < self.nstack - 1:
x = x + self.merge_preds[i](preds) + self.merge_features[i](feature) # 将 估计 与 特征 以及 x 融合
return torch.stack(combined_hm_preds, 1) # 形成新的数组 (B, i,16, 64, 64 )
def calc_loss(self, combined_hm_preds, heatmaps):
combined_loss = []
for i in range(self.nstack):
combined_loss.append(self.heatmapLoss(combined_hm_preds[0][:,i], heatmaps))
combined_loss = torch.stack(combined_loss, dim=1)
return combined_loss
损失函数
import torch
class HeatmapLoss(torch.nn.Module):
"""
loss for detection heatmap
"""
def __init__(self):
super(HeatmapLoss, self).__init__()
def forward(self, pred, gt):
l = ((pred - gt)**2)
l = l.mean(dim=3).mean(dim=2).mean(dim=1)
return l ## l of dim bsize: torch.size([Bsize])
参考博客
https://blog.csdn.net/shenxiaolu1984/article/details/51428392
https://blog.csdn.net/u013841196/article/details/81048237