欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Stacked Hourglass Network(ECCV'16) 论文笔记+pytorch代码阅读

程序员文章站 2022-07-11 23:46:22
...

pytorch 代码:https://github.com/princeton-vl/pytorch_stacked_hourglass

论文原理

Summary

论文设计了一个新的框架Stacked Hourglass Network,通过提取和融合多尺度特征,来更好的捕获人体关键点的各种空间关系(spatial relationship)。

Motivation

Hourglass模块设计的初衷就是为了捕捉每个尺度下的信息,因为捕捉像脸,手这些部分的时候需要局部的特征,而最后对人体姿态进行预测的时候又需要整体的信息。(感受野较大的feature map可以捕捉到更高阶的特征和全局上下文)

  • Hourglass Layer:The network captures and consolidates information across all scales of the image. 通过一步步的下采样获取不同尺度的feature map,再通过上采样和skip layers对不同尺度下的特征进行融合;与其他结构相比有更加对称的结构

  • multiple iterative stages:Hourglass层可以在局部和全局上下文信息中提取feature,并生成预测。之后用多个Hourglass迭代,可以对高阶特征进行多次处理,来进行进一步评估,并重新估计高阶的空间关系(spatial relationship)

  • Intermediate Supervision:在每个stage都计算loss, 避免深度多阶段网络中常见的梯度消失问题。

Model

本论文网络的设计深刻的体现出从 Block 到 Layer 最后形成网络的方法。

Resiual Model

Stacked Hourglass Network(ECCV'16) 论文笔记+pytorch代码阅读
论文使用Residual模块来提取特征,本文中残差模块不改变输入的高度和高度,仅改变Channel,而且保证输出Channel始终为256。

实际上残差模块是BottleNeck。作者进行一系列的尝试,从卷积核较大的标准卷积到一些新方法,像Residual,Inception模块。

Hourglass

Stacked Hourglass Network(ECCV'16) 论文笔记+pytorch代码阅读
Hourglass模块由上下两路构成,下路–获得较小尺度的特征
上路:利用Residual提取原尺度的特征,与下路相加后便得到融合的多尺度的特征
上采样:最近邻插值

下路:获得较小尺度的特征
下采样:max pooling (2*2)
第一个Residual:用来获取下采样之后的特征
第二个Residual:可以换成Hourglass层,得到高阶Hourglass
第三个Residual:我猜测是用来保证对称结构,hhhh

多阶Hourglass
Stacked Hourglass Network(ECCV'16) 论文笔记+pytorch代码阅读
Stacked Hourglass Network(ECCV'16) 论文笔记+pytorch代码阅读
原文使用了8个4阶Hourglass。。。
输入64×6464\times64的feature map,经过4次下采样后获得8×88\times8的feature map,然后进行4次上采样。在整个过程中,Channel=256。

完整网络

Stacked Hourglass Network(ECCV'16) 论文笔记+pytorch代码阅读
Stacked Hourglass:最终一共使用了8个沙漏网络。每个沙漏网络的输入都为 64×6464\times64

输入图片大小为256×256256\times256,一开始经过一次7×77\times7 stride=2 的卷积(padding=(kernel_size-1)//2),紧接着跟随一个residual module和 max pooling将像素值从128下降到64(减少hourglass内部计算量)。其中所有的残差模块输出256个特征图。

单个 Hourglass 能够做到提取特征图在不同尺度上的信息,但是仍然不能在预测时(显式地)考虑不同关键点之间的关系。所以才需要进一步设计 Stacked Hourglass。

损失函数

每个关键点的 Ground Truth 定义为以该关键点为峰值位置的 2D 高斯函数。Loss Function 定义为 Ground Truth 与预测得到的 heatmap 之间的均方误差。

Intermediate Supervision

中继监督的思想在更早的网络里就已经被提出了。GoogLeNet V1 就在网络中部和中后部额外设置了全连接层分类支路,和最终的输出一样,这些支路对 Loss 有一定的贡献。当时的想法是,由于深层网络的多次下采样操作,一定尺度上的特征信息会丢失,所以才设计了这些中继监督位点。与 Stacked Hourglass 不同,GoogLeNet 中的这些中继监督位点的输出并没有再返回网络里。
Stacked Hourglass Network(ECCV'16) 论文笔记+pytorch代码阅读
x输入Hourglass后得到(B, 256,64,64)的输出值,通过用于预测的1×11\times1卷积(蓝框部分)改变Channel,使其与GroundTruth一致,即得到预测(B, 16, 64, 64),该预测可以计算Loss。最后将预测、Hourglass的输出值,以及输入值融合\相加(通过1×11\times1卷积保证Channel一致),送到下一个Hourglass。

直观上,将中继监督的输出重新返回网络中,起到了一种“让网络对当前的特征和预测结果进行再评估”的作用。通过重复地将多个 Hourglass 和中继监督串联,网络能够显式地学习每个预测目标之间的关系,越靠后的预测越能够结合所有关键点的位置信息,做出更准确的关键点位置预测。有了这种特性,网络就基本不会预测出一些从解剖学上不成立的人体姿势。

按元素相加虽然直观上看比较奇怪,但实际上,concatenate到一起之后,再通过一次卷积降维,那次卷积的最后其实也是按元素相加的操作。所以这里直接按元素相加,可以当作之前的卷积层已经得到了可相加的特征,这并没有什么不妥。

GroundTruth
每个关键点的 Ground Truth 定义为以该关键点为峰值位置的 2D 高斯函数。

首先介绍MPII数据集。该数据集主要用于单个人的姿态估计,但它确实为同一图像中的多个人提供关节注释。对于每个人,它给出了16个关节的坐标,比如左脚踝或右肩膀。
Stacked Hourglass Network(ECCV'16) 论文笔记+pytorch代码阅读
关于GroundTruth的另一件重要的事情是高斯分布。当我们生成GroundTruth的heatmap时,我们不只是为关节坐标分配1,并为所有其他像素分配0。这将使GroundTruth过于稀疏,难以了解。如果模型预测只差几个像素,也是值得鼓励的。
Stacked Hourglass Network(ECCV'16) 论文笔记+pytorch代码阅读
用高斯函数对关键点处理,使其中心值最大,中心周围区域值逐渐减小。左图是单个关键点的heatmap,右图是把所有16个关节放在一张heatmap中。
Stacked Hourglass Network(ECCV'16) 论文笔记+pytorch代码阅读

预测

与直接回归相比,使用热图的一个缺点是粒度(granularity)。例如,使用256×256256\times256输入,我们将得到一个64×6464\times64的热图来表示关键点位置。四倍的缩小比例似乎不是很糟糕。然而,我们通常首先将较大的图像(如720×480720\times480)调整256×256256\times256输入。在这种情况下,64x64的热图太粗糙了。为了缓解这个问题,研究人员提出了一个有趣的想法。我们不只是使用最大值的像素,我们还考虑了相邻的最大值像素。由于某个相邻像素也很高,因此它推断实际的关键点位置可能是朝向相邻像素的方向。听起来很熟悉,对吧?这很像梯度下降法,它也指向最优解。

消融实验

首先设计了几组网络,来讨论中间监督和stacked Hourglass
Stacked Hourglass Network(ECCV'16) 论文笔记+pytorch代码阅读
为了探索stacked Hourglass设计的效果,我们必须证明性能的变化是框架的设计,而不是由于更大、更深的网络。

在图9中比较了2堆叠、4堆叠和8堆叠网络的验证精度,它们有相同的参数,都包含中间预测。
Stacked Hourglass Network(ECCV'16) 论文笔记+pytorch代码阅读

代码

建立模型

基本层

from torch import nn
Pool = nn.MaxPool2d

def batchnorm(x):  
    return nn.BatchNorm2d(x.size()[1])(x)
class Conv(nn.Module):
    """
        卷积层(包含BN和ReLU)
        参数:inp_dim 输入Channel
                 out_dim 输出Channel    
    """
    def __init__(self, inp_dim, out_dim, kernel_size=3, stride = 1, bn = False, relu = True):
        super(Conv, self).__init__()
        self.inp_dim = inp_dim
        self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding=(kernel_size-1)//2, bias=True)
        self.relu = None
        self.bn = None
        if relu:
            self.relu = nn.ReLU()
        if bn:
            self.bn = nn.BatchNorm2d(out_dim)

    def forward(self, x):
        assert x.size()[1] == self.inp_dim, "{} {}".format(x.size()[1], self.inp_dim)
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

Residual Model

class Residual(nn.Module):
    """ 
        Residual层 (实际上是BottleNeck)
        参数:inp_dim 输入Channel
              out_dim 输出Channel
    """
    def __init__(self, inp_dim, out_dim):
        super(Residual, self).__init__()
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(inp_dim)
        self.conv1 = Conv(inp_dim, int(out_dim/2), 1, relu=False)
        self.bn2 = nn.BatchNorm2d(int(out_dim/2))
        self.conv2 = Conv(int(out_dim/2), int(out_dim/2), 3, relu=False)
        self.bn3 = nn.BatchNorm2d(int(out_dim/2))
        self.conv3 = Conv(int(out_dim/2), out_dim, 1, relu=False)
        self.skip_layer = Conv(inp_dim, out_dim, 1, relu=False)
        if inp_dim == out_dim:
            self.need_skip = False
        else:
            self.need_skip = True

    def forward(self, x):
        if self.need_skip:
            residual = self.skip_layer(x)
        else:
            residual = x
        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)
        out += residual
        return out 

Hourglass Model


class Hourglass(nn.Module):
    """
        Residual层不改变输入feature的W、H
        下采样主要由 pool 完成
    """
    def __init__(self, n, f, bn=None, increase=0):
        super(Hourglass, self).__init__()
        nf = f + increase
        self.up1 = Residual(f, f)
        # Lower branch
        self.pool1 = Pool(2, 2)
        self.low1 = Residual(f, nf)
        self.n = n
        # Recursive hourglass
        if self.n > 1:
            self.low2 = Hourglass(n-1, nf, bn=bn)
        else:
            self.low2 = Residual(nf, nf)
        self.low3 = Residual(nf, f)
        self.up2 = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        up1  = self.up1(x)
        pool1 = self.pool1(x)
        low1 = self.low1(pool1)
        low2 = self.low2(low1)
        low3 = self.low3(low2)
        up2  = self.up2(low3)
        return up1 + up2

完整网络

class UnFlatten(nn.Module):
    def forward(self, input):
        return input.view(-1, 256, 4, 4)

class Merge(nn.Module):
    def __init__(self, x_dim, y_dim):
        super(Merge, self).__init__()
        self.conv = Conv(x_dim, y_dim, 1, relu=False, bn=False)

class PoseNet(nn.Module):
    """
        'nstack': 8,
        'inp_dim': 256,
        'oup_dim': 16,
        'num_parts': 16,
        'increase': 0,
    """
    def __init__(self, nstack, inp_dim, oup_dim, bn=False, increase=0, **kwargs):
        super(PoseNet, self).__init__()
        
        self.nstack = nstack
        self.pre = nn.Sequential(
            Conv(3, 64, 7, 2, bn=True, relu=True),
            Residual(64, 128),
            Pool(2, 2),
            Residual(128, 128),
            Residual(128, inp_dim)
        )
        
        self.hgs = nn.ModuleList( [
        nn.Sequential(
            Hourglass(4, inp_dim, bn, increase),
        ) for i in range(nstack)] )       
        self.features = nn.ModuleList( [
        
        nn.Sequential(
            Residual(inp_dim, inp_dim),
            Conv(inp_dim, inp_dim, 1, bn=True, relu=True)
        ) for i in range(nstack)] )       
        self.outs = nn.ModuleList( [Conv(inp_dim, oup_dim, 1, relu=False, bn=False) for i in range(nstack)] )
        self.merge_features = nn.ModuleList( [Merge(inp_dim, inp_dim) for i in range(nstack-1)] )
        self.merge_preds = nn.ModuleList( [Merge(oup_dim, inp_dim) for i in range(nstack-1)] )
        self.nstack = nstack
        self.heatmapLoss = HeatmapLoss()

    def forward(self, imgs):
        ## our posenet
        x = imgs.permute(0, 3, 1, 2) #x of size 1,3,inpdim,inpdim
        x = self.pre(x)                      # x :(B, 256, 64, 64)
        combined_hm_preds = []    # (i, B, 16, 64, 64 )
        for i in range(self.nstack):
            hg = self.hgs[i](x)                 #(B, 256, 64, 64)
            feature = self.features[i](hg) #(B, 256, 64, 64)
            preds = self.outs[i](feature)  #(B, 256, 16, 16)
            combined_hm_preds.append(preds)
            if i < self.nstack - 1:
                x = x + self.merge_preds[i](preds) + self.merge_features[i](feature)   # 将 估计 与 特征 以及 x 融合
        return torch.stack(combined_hm_preds, 1) # 形成新的数组  (B, i,16, 64, 64 )

    def calc_loss(self, combined_hm_preds, heatmaps):
        combined_loss = []
        for i in range(self.nstack):
            combined_loss.append(self.heatmapLoss(combined_hm_preds[0][:,i], heatmaps)) 
        combined_loss = torch.stack(combined_loss, dim=1)
        return combined_loss     

损失函数

import torch
class HeatmapLoss(torch.nn.Module):
    """
    loss for detection heatmap
    """
    def __init__(self):
        super(HeatmapLoss, self).__init__()

    def forward(self, pred, gt):
        l = ((pred - gt)**2)
        l = l.mean(dim=3).mean(dim=2).mean(dim=1)
        return l ## l of dim bsize: torch.size([Bsize])
        

参考博客
https://blog.csdn.net/shenxiaolu1984/article/details/51428392
https://blog.csdn.net/u013841196/article/details/81048237