
Mxnet (18): Densely Connected Networks (DenseNet)


ResNet significantly changed the view of how to parameterize the functions in deep networks. DenseNet (densely connected convolutional network) is to some extent a logical extension of this. Its main difference from ResNet is shown in the figure below:

[Figure: the main difference between ResNet (addition of cross-layer connections) and DenseNet (concatenation)]

  • ResNet: addition
  • DenseNet: concatenation

$$x \to [x, f_1(x), f_2([x, f_1(x)]), f_3([x, f_1(x), f_2([x, f_1(x)])]), \ldots]$$

The key difference is that in DenseNet the output of module B is not added to the output of module A as in ResNet, but concatenated with it along the channel dimension. As a result, the output of module A is passed directly into all layers following module B; in this design, module A is connected directly to every later layer. The name "densely connected" comes from the fact that the dependency graph between variables becomes very dense.
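To make the difference concrete, here is a minimal sketch (the shapes are made up for illustration) contrasting the two combination rules on dummy feature maps:

from mxnet import np, npx
npx.set_np()

A = np.random.uniform(size=(1, 3, 8, 8))    # output of "module A"
B = np.random.uniform(size=(1, 3, 8, 8))    # output of "module B"

# ResNet combines by element-wise addition: the channel count stays at 3
print((A + B).shape)                         # (1, 3, 8, 8)
# DenseNet combines by concatenation along the channel axis: channels add up
print(np.concatenate((A, B), axis=1).shape)  # (1, 6, 8, 8)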

The main building blocks of DenseNet are dense blocks and transition layers. The former define how inputs and outputs are concatenated, while the latter keep the number of channels from growing too large.

1. Dense block

DenseNet uses ResNet's improved "batch normalization, activation, and convolution" structure.

from d2l import mxnet as d2l
from mxnet import np, npx, init, gluon, autograd
from mxnet.gluon import nn
import plotly.graph_objs as go
npx.set_np()

ctx = npx.gpu() if npx.num_gpus() > 0 else npx.cpu()

def conv_block(num_channels):
    # BN -> ReLU -> 3x3 conv: the "pre-activation" ordering from the improved ResNet
    block = nn.Sequential()
    block.add(
        nn.BatchNorm(),
        nn.Activation('relu'),
        nn.Conv2D(num_channels, kernel_size=3, padding=1)
    )
    return block
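As a quick sanity check (not part of the original text): a conv_block changes only the channel count, since kernel_size=3 with padding=1 preserves the height and width.

blk = conv_block(10)
blk.initialize()
X = np.random.uniform(size=(4, 3, 8, 8))
print(blk(X).shape)  # (4, 10, 8, 8): 3 input channels -> 10, spatial size unchanged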
  • A dense block consists of multiple conv_blocks, each using the same number of output channels. In the forward pass, however, the input and output of each block are concatenated along the channel dimension.
class DenseBlock(nn.Block):
    def __init__(self, num_convs, num_channels, **kwargs):
        super().__init__(**kwargs)
        self.net = nn.Sequential()
        for _ in range(num_convs):
            self.net.add(conv_block(num_channels))
            
    def forward(self, X):
        for block in self.net:
            Y = block(X)
            # axis=1: concatenate the input and output along the channel dimension
            X = np.concatenate((X, Y), axis=1)
        return X
  • In the following example we define a DenseBlock instance with 2 convolution blocks of 10 output channels each. Given an input with 3 channels, we obtain an output with 3 + 2 × 10 = 23 channels. The number of channels of the convolution blocks controls how much the number of output channels grows relative to the number of input channels; this is also referred to as the growth rate.
blk = DenseBlock(2, 10)
blk.initialize()
X = np.random.uniform(size=(4, 3, 8, 8))
Y = blk(X)
Y.shape

# (4, 23, 8, 8)

2. Transition layer

Every dense block increases the number of channels, so stacking too many of them makes the model overly complex. A transition layer is used to control the model's complexity: it reduces the number of channels with a 1×1 convolution layer and halves the height and width with an average pooling layer of stride 2.

def transition_block(num_channels):
    block = nn.Sequential()
    block.add(
        nn.BatchNorm(), 
        nn.Activation('relu'),
        nn.Conv2D(num_channels, kernel_size=1),
        nn.AvgPool2D(pool_size=2, strides=2)
    )
    return block
  • Apply a transition layer with 10 channels to the output of the dense block above. This reduces the number of output channels to 10 and halves the height and width.
blk = transition_block(10)
blk.initialize()
blk(Y).shape

# (4, 10, 4, 4)

3. The DenseNet model

DenseNet first uses the same single convolution layer and max pooling layer as ResNet.

DenseNet = nn.Sequential()
DenseNet.add(
    nn.Conv2D(64, kernel_size=7, strides=2, padding=3),
    nn.BatchNorm(),
    nn.Activation('relu'),
    nn.MaxPool2D(pool_size=3, strides=2, padding=1)
)

Similar to the 4 residual blocks that ResNet uses next, DenseNet uses 4 dense blocks. As with ResNet, we can set the number of convolution layers in each dense block; here we use 4, consistent with the ResNet-18 of the previous section. The number of channels per convolution layer in a dense block (i.e., the growth rate) is set to 32, so each dense block adds 4 × 32 = 128 channels.

In ResNet, the height and width are reduced between modules by residual blocks with a stride of 2. Here we instead use transition layers, which halve the height and width and also halve the number of channels.
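Before writing the helper, it may help to trace the channel count by hand; this small illustrative calculation mirrors the loop below:

# Channel bookkeeping across the four stages (growth rate 32, 4 convs per block)
num_channels = 64
for i in range(4):
    num_channels += 4 * 32          # each dense block adds 4 * 32 = 128 channels
    print(f'after dense block {i+1}: {num_channels}')
    if i != 3:
        num_channels //= 2          # transition layer halves the channels
        print(f'after transition {i+1}: {num_channels}')
# 64 -> 192 -> 96 -> 224 -> 112 -> 240 -> 120 -> 248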

# num_channels: the current number of channels
def add_blocks(net, num_channels=64, growth_rate=32,
               num_convs_in_dense_blocks=[4, 4, 4, 4]):
    for i, num_convs in enumerate(num_convs_in_dense_blocks):
        net.add(DenseBlock(num_convs, growth_rate))
        # Output channel count after the dense block just added
        num_channels += num_convs * growth_rate
        # Between dense blocks, insert a transition layer that halves the
        # number of channels (a "sandwich" structure)
        if i != len(num_convs_in_dense_blocks) - 1:
            num_channels //= 2
            net.add(transition_block(num_channels))

add_blocks(DenseNet)

As in ResNet, a global pooling layer and a fully-connected layer are attached at the end to produce the output.

DenseNet.add(
    nn.BatchNorm(), 
    nn.Activation('relu'), 
    nn.GlobalAvgPool2D(),
    nn.Dense(10)
)
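To verify the architecture end to end, we can push a dummy 96×96 single-channel image through the network and print the output shape of each top-level layer (a standard shape-trace sanity check, not in the original post; the initialize call here is only for the trace, as train() re-initializes the parameters later):

X = np.random.uniform(size=(1, 1, 96, 96))
DenseNet.initialize()
for layer in DenseNet:
    X = layer(X)
    print(layer.name, 'output shape:\t', X.shape)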

4. Training

The code is unchanged from before; only the model is different:

def get_workers(num):
    # Multi-process data loading is not supported on Windows
    return 0 if __import__('sys').platform.startswith('win') else num

def loader(data, batch_size, shuffle=True, workers = 6):
    return gluon.data.DataLoader(data,batch_size, shuffle=shuffle,
                                   num_workers=get_workers(workers))

def load_data(batch_size, resize=None):
    
    dataset = gluon.data.vision
    trans = [dataset.transforms.Resize(resize)] if resize else []
    trans.append(dataset.transforms.ToTensor())
    trans = dataset.transforms.Compose(trans)
    mnist_train = dataset.FashionMNIST(train=True).transform_first(trans)
    mnist_test = dataset.FashionMNIST(train=False).transform_first(trans)
    return loader(mnist_train, batch_size), loader(mnist_test, batch_size, False)    
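A quick look at one batch confirms the resize takes effect (an illustrative check; it triggers a FashionMNIST download on first run):

train_iter, test_iter = load_data(64, resize=96)
for X, y in train_iter:
    print(X.shape, y.shape)  # (64, 1, 96, 96) (64,)
    break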


def accuracy(y_hat, y): 
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.astype(y.dtype) == y
    return float(cmp.sum())
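Note that accuracy returns the count of correct predictions rather than a rate; the division by the total happens in the callers. A tiny made-up example:

y_hat = np.array([[0.1, 0.8, 0.1], [0.9, 0.05, 0.05]])  # predicted scores
y = np.array([1, 2])
print(accuracy(y_hat, y))  # 1.0: only the first prediction matches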

def train_epoch(net, train_iter, loss, updater):
    
    l_sum = acc_rate = total = 0
    
    if isinstance(updater, gluon.Trainer):
        updater = updater.step
        
    for X,y in train_iter:
        X = X.as_in_ctx(ctx)
        y = y.as_in_ctx(ctx)
        with autograd.record():
            pre_y = net(X)
            l = loss(pre_y, y)
        l.backward()
        updater(y.size)
        l_sum += float(l.sum())
        acc_rate += accuracy(pre_y, y)
        total += y.size
    return l_sum/total, acc_rate/total

def evaluate_accuracy(net, data_iter):  

    match_num = total_num = 0
    for X, y in data_iter:
        X = X.as_in_ctx(ctx)
        y = y.as_in_ctx(ctx)
        match_num += accuracy(net(X), y)
        total_num += y.size
    return match_num / total_num

import time
def train(net, train_iter, test_iter, epochs, lr):
    
    net.initialize(force_reinit=True, ctx=ctx, init=init.Xavier())
    loss = gluon.loss.SoftmaxCrossEntropyLoss()
    trainer = gluon.Trainer(net.collect_params(), 'sgd',  {'learning_rate': lr})
    l_lst, acc_lst, test_acc_lst = [], [], []
    timer = 0
    print("----------------start------------------")
    for epoch in range(epochs):
        start = time.time()
        l, acc = train_epoch(net, train_iter, loss, trainer)
        timer += time.time()-start
        test_acc = evaluate_accuracy(net, test_iter)
        print(f'[epoch {epoch+1}] loss {l:.3f}, train acc {acc:.3f}, ' f'test acc {test_acc:.3f}')
        l_lst.append(l)
        acc_lst.append(acc)
        test_acc_lst.append(test_acc)
    print(f'loss {l:.3f}, train acc {acc:.3f}, test acc {test_acc:.3f}')
    print(f'{timer:.1f} sec, on {str(ctx)}')
    draw_graph([l_lst, acc_lst, test_acc_lst])
    

def draw_graph(result):
    data = []
    colors = ['aquamarine', 'orange', 'hotpink']
    names = ['train loss', 'train acc', 'test acc']
    symbols = ['circle-open', 'cross-open', 'triangle-up-open']
    for i, info in enumerate(result):
        trace = go.Scatter(
            x = list(range(1, num_epochs+1)),
            y = info,
            mode = 'lines+markers',
            name = names[i],
            marker = {
                'color':colors[i],
                'symbol':symbols[i],
            },
        )
        data.append(trace)
    fig = go.Figure(data = data)
    fig.update_layout(xaxis_title='epochs', width=800, height=480)
    fig.show()

As before, we train on the Fashion-MNIST dataset, this time with DenseNet. Because the model is relatively complex, the images are resized to 96×96 to keep things simple, and since GPU memory is limited the batch size is set to 64.

lr, num_epochs, batch_size = 0.1, 10, 64
train_iter, test_iter = load_data(batch_size, resize=96)
train(DenseNet, train_iter, test_iter, num_epochs, lr)
  • The accuracy is fairly close to that of ResNet.
[Figure: per-epoch training log and curves of train loss, train accuracy, and test accuracy]

5. Prediction

Feed some data to the trained model to see how its predictions look.

import plotly.express as px
from plotly.subplots import make_subplots
def get_fashion_mnist_labels(labels): 
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

def show_images(imgs, num_rows, num_cols, titles=None): 
    colorscales = px.colors.named_colorscales()
    fig = make_subplots(num_rows, num_cols, subplot_titles=titles)
    for i, img in enumerate(imgs):
        fig.add_trace(go.Heatmap(z=img.asnumpy()[::-1], showscale=False, colorscale=colorscales[i+3]), 1, i+1)
        fig.update_xaxes(visible=False,row=1, col=i+1)
        fig.update_yaxes(visible=False, row=1, col=i+1)
    fig.update_layout(height=280)
    fig.show()

def predict(net, test_iter, stop, shape=(28,28) ,n=8):
    for i,(X,y) in enumerate(test_iter):
        if (i==stop) :
            break
    X,y = X.as_in_ctx(ctx), y.as_in_ctx(ctx)
    trues = get_fashion_mnist_labels(y)
    preds = get_fashion_mnist_labels(net(X).argmax(axis=1))
    titles = [f"true: {t} <br> pre: {p}" for t, p in zip(trues, preds)]
    show_images(X[:n].reshape((n, shape[0], shape[1])), 1, n, titles=titles[:n])


predict(DenseNet, test_iter, 20, (96,96))

[Figure: eight Fashion-MNIST test images with their true and predicted labels]

6. References

https://d2l.ai/chapter_convolutional-modern/densenet.html

7. Code

github