欢迎您访问程序员文章站!本站旨在为大家提供和分享程序员计算机编程知识。
您现在的位置是: 首页

复现Dense Extreme Inception Network(pytorch)

程序员文章站 2022-03-29 22:47:35
...

github地址:https://github.com/xavysp/DexiNed/tree/master/DexiNed-Pytorch
论文地址:https://arxiv.org/abs/1909.01955
数据集:https://www.kaggle.com/xavysp/biped

摘要

这篇论文提出了一种基于深度学习的边缘检测算法,受到 HED(Holistically-Nested Edge Detection)和 Xception 网络的启发。该方法能够生成符合人眼视觉感受的细边缘图(thin edge-map),可以直接用于任何边缘检测任务,无需事先训练或微调。
复现Dense Extreme Inception Network(pytorch)
该论文的主要贡献:提出了一种鲁棒的CNN边缘检测架构,简称为DexiNed:Dense Extreme Inception Network for Edge Detection。这个模型是从头开始训练的,没有预先训练过的权重。

模型结构

复现Dense Extreme Inception Network(pytorch)
复现Dense Extreme Inception Network(pytorch)

论文实验结果

复现Dense Extreme Inception Network(pytorch)
复现Dense Extreme Inception Network(pytorch)

相关代码

config.py

class Config:
    """Static settings shared by the dataset and the trainer."""

    # ---- dataset ----
    # Per-channel means subtracted from every image; the dataset only uses
    # the first three entries when four are listed.
    mean_pixel_values = [104.00699, 116.66877, 122.67892, 137.86]
    # Every image/label pair is resized to img_width x img_height.
    img_height = 400
    img_width = 400
    train_root = 'data/BIPED/edges/imgs/train/rgbr/real'
    # NOTE(review): this points at the edge-map tree, not an image tree --
    # confirm that validating on edge maps is really intended.
    valid_root = 'data/BIPED/edges/edge_maps/train/rgbr/real'
    # Directory that receives the predictions rendered during validation.
    valid_output_dir = 'valid_temp'

    # ---- hyper parameters ----
    batch_size, num_workers, num_epochs = 2, 0, 25

    # Directory that receives the saved models.
    model_output = 'result'

dataset.py

from torch.utils.data import DataLoader, Dataset
import torch
import cv2 as cv
import numpy as np
import os


class BIPEDDataset(Dataset):
    """Image / edge-map pairs read from disk and converted to tensors.

    Every entry of ``img_root`` is treated as an image. Its label path is
    derived from the image path (see ``_label_path_for``). Both are resized
    to ``config.img_width`` x ``config.img_height``.

    Args:
        img_root: directory that holds the input images.
        mode: stored on the instance; not otherwise used by this class.
        config: object providing ``mean_pixel_values``, ``img_width`` and
            ``img_height``. Required despite the ``None`` default.

    Raises:
        ValueError: if ``config`` is ``None``.
    """

    def __init__(self, img_root, mode='train', config=None):
        if config is None:
            # The original dereferenced None and died with an AttributeError.
            raise ValueError('BIPEDDataset requires a config object')
        self.img_root = img_root
        self.mode = mode
        # os.listdir returns entries in arbitrary order; sort them so that an
        # index always maps to the same file.
        self.imgList = sorted(os.listdir(img_root))
        self.config = config
        means = config.mean_pixel_values
        # A four-entry mean list carries an extra value; keep the first three.
        self.mean_bgr = means[0:3] if len(means) == 4 else means

    def __len__(self):
        """Return the number of files found in ``img_root``."""
        return len(self.imgList)

    @staticmethod
    def _label_path_for(img_path):
        """Map an image path to its edge-map path.

        ``imgs`` is replaced by ``edge_maps`` anywhere in the path. Only a
        trailing ``.jpg`` extension is swapped for ``.png``, so a directory
        whose name merely contains "jpg" is left alone.
        """
        stem, ext = os.path.splitext(img_path.replace('imgs', 'edge_maps'))
        if ext == '.jpg':
            ext = '.png'
        return stem + ext

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            dict with ``images`` (channels-first float tensor), ``labels``
            (float tensor with a leading singleton channel), ``file_name``
            (file name without its extension) and ``image_shape``
            (``[height, width]`` of the image before resizing).

        Raises:
            OSError: if the image or its label cannot be read.
        """
        entry = self.imgList[idx]
        # splitext keeps dots inside the stem ("a.b.jpg" -> "a.b").
        file_name = os.path.splitext(entry)[0]
        imgPath = os.path.join(self.img_root, entry)
        labelPath = self._label_path_for(imgPath)

        # cv.imread signals failure by returning None instead of raising.
        image = cv.imread(imgPath, cv.IMREAD_COLOR)
        if image is None:
            raise OSError('cannot read image: %s' % imgPath)
        label = cv.imread(labelPath, cv.IMREAD_GRAYSCALE)
        if label is None:
            raise OSError('cannot read label: %s' % labelPath)

        image_shape = [image.shape[0], image.shape[1]]
        image, label = self.transform(img=image, gt=label)
        return dict(images=image, labels=label, file_name=file_name, image_shape=image_shape)

    def transform(self, img, gt):
        """Normalise and resize one image / label pair.

        The label is scaled from [0, 255] to [0, 1]. The per-channel means
        are subtracted from the image. Both are resized to the configured
        size and returned as float tensors: image as (C, H, W), label as
        (1, H, W).
        """
        gt = np.array(gt, dtype=np.float32)
        if len(gt.shape) == 3:
            # Multi-channel label: keep the first channel only.
            gt = gt[:, :, 0]
        gt /= 255.

        img = np.array(img, dtype=np.float32)
        img -= self.mean_bgr

        # cv.resize takes dsize as (width, height).
        target_size = (self.config.img_width, self.config.img_height)
        img = cv.resize(img, dsize=target_size)
        gt = cv.resize(gt, dsize=target_size)

        img = img.transpose((2, 0, 1))  # HWC -> CHW
        img = torch.from_numpy(img.copy()).float()
        gt = torch.from_numpy(np.array([gt])).float()
        return img, gt


if __name__=='__main__':
    # Smoke test: iterate the training images once and print the size of
    # each batch together with the file names it contains.
    from config import Config

    root = 'data/BIPED/edges/imgs/train/rgbr/real'
    loader = DataLoader(BIPEDDataset(root, config=Config()),
                        batch_size=2, num_workers=0)
    for batch in loader:
        images = batch['images']
        labels = batch['labels']
        print(images.size(), labels.size(),  batch['file_name'])

losses.py

import torch
import torch.nn.functional as F

def _weighted_cross_entropy_loss(preds, edges):
    """ Calculate sum of weighted cross entropy loss. """
    # Reference:
    #   hed/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
    #   https://github.com/s9xie/hed/issues/7
    mask = (edges > 0.5).float()
    b, c, h, w = mask.shape
    num_pos = torch.sum(mask, dim=[1, 2, 3]).float()  # Shape: [b,].
    num_neg = c * h * w - num_pos                     # Shape: [b,].
    weight = torch.zeros_like(mask)
    weight[edges > 0.5]  = num_neg / (num_pos + num_neg)
    weight[edges <= 0.5] = num_pos / (num_pos + num_neg)
    # Calculate loss.
    losses = F.binary_cross_entropy_with_logits(
        preds.float(), edges.float(), weight=weight, reduction='none')
    loss = torch.sum(losses) / b
    return loss

def weighted_cross_entropy_loss(preds, edges):
    """Class-balanced sigmoid cross entropy, summed and divided by batch size.

    Within each sample, edge pixels (``edges > 0.5``) are weighted by that
    sample's fraction of non-edge pixels, and non-edge pixels by its
    fraction of edge pixels.

    Args:
        preds: logits with the same 4-D shape as ``edges``.
        edges: 4-D target tensor of shape (b, c, h, w).

    Returns:
        Scalar tensor: the sum of all weighted losses divided by ``b``.
    """
    # Reference:
    #   hed/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
    #   https://github.com/s9xie/hed/issues/7
    mask = (edges > 0.5).float()
    b, c, h, w = mask.shape
    num_pos = torch.sum(mask, dim=[1, 2, 3], keepdim=True)  # Shape: [b, 1, 1, 1].
    pos_fraction = num_pos / (c * h * w)
    neg_fraction = 1.0 - pos_fraction
    # Build the weight map position-wise. The original used masked_scatter_,
    # which consumes its source *sequentially*, so every sample after the
    # first was weighted with sample 0's ratios.
    weight = mask * neg_fraction + (1.0 - mask) * pos_fraction
    # Calculate loss (preds are logits; the sigmoid is applied inside).
    losses = F.binary_cross_entropy_with_logits(
        preds.float(), edges.float(), weight=weight, reduction='none')
    return torch.sum(losses) / b

model.py

import torch
import torch.nn as nn
import torch.nn.functional as F

class _DenseLayer(nn.Sequential):
    def __init__(self, input_features, out_features):
        super(_DenseLayer, self).__init__()
        # self.add_module('relu2', nn.ReLU(inplace=True)),
        self.add_module('conv1', nn.Conv2d(input_features, out_features,
                        kernel_size=1, stride=1, bias=True)),
        self.add_module('norm1', nn.BatchNorm2d(out_features)),
        self.add_module('relu1', nn.ReLU(inplace=True)),
        self.add_module('conv2', nn.Conv2d(out_features, out_features,
                        kernel_size=3, stride=1, padding=1, bias=True)),
        self.add_module('norm2', nn.BatchNorm2d(out_features))
        # double check the norm1 comment if necessary and put norm after conv2

    def forward(self, x):
        x1, x2 = x
        # maybe I should put here a RELU
        new_features = super(_DenseLayer, self).forward(x1) # F.relu()
        return 0.5 * (new_features + x2), x2

class _DenseBlock(nn.Sequential):
    """Chain of ``num_layers`` dense layers named ``denselayer1..N``.

    Only the first layer takes ``input_features`` channels; each later layer
    takes the ``out_features`` channels produced by its predecessor.
    """

    def __init__(self, num_layers, input_features, out_features):
        super().__init__()
        for position in range(1, num_layers + 1):
            in_width = input_features if position == 1 else out_features
            self.add_module('denselayer%d' % position,
                            _DenseLayer(in_width, out_features))

class UpConvBlock(nn.Module):
    """Stack of ``up_scale`` upsampling stages ending in a 1-channel map.

    Args:
        in_features: channel count of the incoming feature map.
        up_scale: number of x2 upsampling stages to stack.
        mode: ``'deconv'`` (transposed convolutions) or ``'pixel_shuffle'``.

    Raises:
        ValueError: if ``mode`` is not one of the two supported values.
    """

    def __init__(self, in_features, up_scale, mode='deconv'):
        super(UpConvBlock, self).__init__()
        self.up_factor = 2
        self.constant_features = 16

        # Raise explicitly: the original used ``assert``, which is stripped
        # under ``python -O`` and would then fail later with an unrelated
        # TypeError.
        if mode == 'deconv':
            layers = self.make_deconv_layers(in_features, up_scale)
        elif mode == 'pixel_shuffle':
            layers = self.make_pixel_shuffle_layers(in_features, up_scale)
        else:
            raise ValueError(
                "unsupported mode %r; expected 'deconv' or 'pixel_shuffle'" % (mode,))
        self.features = nn.Sequential(*layers)

    def make_deconv_layers(self, in_features, up_scale):
        """Build ``up_scale`` stages of 1x1 conv -> ReLU -> stride-2 deconv."""
        # The kernel size depends only on up_scale, so compute it once.
        kernel_size = 2 ** up_scale
        layers = []
        for i in range(up_scale):
            out_features = self.compute_out_features(i, up_scale)
            layers.append(nn.Conv2d(in_features, out_features, 1))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.ConvTranspose2d(
                out_features, out_features, kernel_size, stride=2))
            in_features = out_features
        return layers

    def make_pixel_shuffle_layers(self, in_features, up_scale):
        """Build ``up_scale`` stages of PixelShuffle -> 1x1 conv -> ReLU."""
        layers = []
        for i in range(up_scale):
            out_features = self.compute_out_features(i, up_scale)
            # PixelShuffle(r) divides the channel count by r ** 2.
            in_features = int(in_features / (self.up_factor ** 2))
            layers.append(nn.PixelShuffle(self.up_factor))
            layers.append(nn.Conv2d(in_features, out_features, 1))
            # NOTE(review): the original guarded this with ``if i < up_scale``,
            # which is always true, so the last stage gets a ReLU as well.
            # Behaviour is kept; confirm whether the final ReLU is intended.
            layers.append(nn.ReLU(inplace=True))
            in_features = out_features
        return layers

    def compute_out_features(self, idx, up_scale):
        """Return 1 for the last stage, ``constant_features`` otherwise."""
        return 1 if idx == up_scale - 1 else self.constant_features

    def forward(self, x):
        """Apply the upsampling stack to ``x``."""
        return self.features(x)

class SingleConvBlock(nn.Module):
    """A 1x1 convolution with the given stride, followed by batch norm."""

    def __init__(self, in_features, out_features, stride):
        super().__init__()
        self.conv = nn.Conv2d(in_features, out_features,
                              kernel_size=1, stride=stride)
        self.bn = nn.BatchNorm2d(out_features)

    def forward(self, x):
        """Project ``x`` with the 1x1 conv and normalise; no activation."""
        projected = self.conv(x)
        return self.bn(projected)

class DoubleConvBlock(nn.Module):
    """Two 3x3 conv -> batch norm -> ReLU stages applied back to back.

    Only the first convolution uses ``stride``. ``out_features`` falls back
    to ``mid_features`` when it is not given.
    """

    def __init__(self, in_features, mid_features,out_features=None, stride=1):
        super().__init__()
        final_features = mid_features if out_features is None else out_features
        self.conv1 = nn.Conv2d(in_features, mid_features,
                               kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(mid_features)
        self.conv2 = nn.Conv2d(mid_features, final_features,
                               kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(final_features)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run both conv/norm/activation stages in order."""
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            x = self.relu(norm(conv(x)))
        return x

class DexiNet(nn.Module):
    """DexiNed edge-detection network.

    Two double-conv blocks are followed by four dense blocks. Blocks are
    linked by strided 1x1 "side" convolutions and "pre_dense" skip
    projections. Every block output is upsampled by its own UpConvBlock and
    the six resulting maps are fused by a 1x1 convolution.
    """
    def __init__(self):
        super(DexiNet, self).__init__()
        # Main path: two plain conv blocks, then four dense blocks.
        self.block_1 = DoubleConvBlock(3, 32, 64, stride=2)
        self.block_2 = DoubleConvBlock(64, 128)
        self.dblock_3 = _DenseBlock(2, 128, 256)
        self.dblock_4 = _DenseBlock(3, 256, 512)
        self.dblock_5 = _DenseBlock(3, 512, 512)
        self.dblock_6 = _DenseBlock(3, 512, 256)
        # Shared stride-2 pooling applied after blocks 2, 3 and 4.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Side projections added to the (pooled) output of the next block.
        # NOTE: side_5 is constructed but never called in forward().
        self.side_1 = SingleConvBlock(64, 128, 2)
        self.side_2 = SingleConvBlock(128, 256, 2)
        self.side_3 = SingleConvBlock(256, 512, 2)
        self.side_4 = SingleConvBlock(512, 512, 1)
        self.side_5 = SingleConvBlock(512, 256, 1)

        # Projections producing the skip tensor fed to each dense block.
        self.pre_dense_2 = SingleConvBlock(128, 256, 2) # by me, for left skip block4
        self.pre_dense_3 = SingleConvBlock(128, 256, 1)
        self.pre_dense_4 = SingleConvBlock(256, 512, 1)
        self.pre_dense_5_0 = SingleConvBlock(256, 512, 2)
        self.pre_dense_5 = SingleConvBlock(512, 512, 1)
        self.pre_dense_6 = SingleConvBlock(512, 256, 1)

        # One upsampling head per block; the second argument is the number
        # of upsampling stages stacked in that head.
        self.up_block_1 = UpConvBlock(64, 1)
        self.up_block_2 = UpConvBlock(128, 1)
        self.up_block_3 = UpConvBlock(256, 2)
        self.up_block_4 = UpConvBlock(512, 3)
        self.up_block_5 = UpConvBlock(512, 4)
        self.up_block_6 = UpConvBlock(256, 4)
        # Fuses the six single-channel side outputs into one map.
        self.block_cat = nn.Conv2d(6, 1, kernel_size=1)

    def slice(self, tensor, slice_shape):
        """Crop the last two dimensions of ``tensor`` to ``slice_shape``.

        ``slice_shape`` is a ``(height, width)`` pair; the crop keeps the
        leading ``height`` rows and ``width`` columns.
        """
        height, width = slice_shape
        return tensor[..., :height, :width]

    def forward(self, x):
        """Run the network on a 4-D input batch.

        Returns:
            A list of seven tensors: the six upsampled block outputs, each
            cropped to the input's height and width, followed by their
            fusion through ``block_cat``.
        """
        assert len(x.shape) == 4, x.shape
        # Block 1
        block_1 = self.block_1(x)
        block_1_side = self.side_1(block_1)

        # Block 2
        block_2 = self.block_2(block_1)
        block_2_down = self.maxpool(block_2)
        block_2_add = block_2_down + block_1_side
        block_2_side = self.side_2(block_2_add)

        # Block 3
        # Dense blocks take [features, skip] and return (output, skip).
        block_3_pre_dense = self.pre_dense_3(block_2_down)
        block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense])
        block_3_down = self.maxpool(block_3)
        block_3_add = block_3_down + block_2_side
        block_3_side = self.side_3(block_3_add)

        # Block 4
        block_4_pre_dense_256 = self.pre_dense_2(block_2_down)
        block_4_pre_dense = self.pre_dense_4(block_4_pre_dense_256 + block_3_down)
        block_4, _ = self.dblock_4([block_3_add, block_4_pre_dense])
        block_4_down = self.maxpool(block_4)
        block_4_add = block_4_down + block_3_side
        block_4_side = self.side_4(block_4_add)

        # Block 5
        block_5_pre_dense_512 = self.pre_dense_5_0(block_4_pre_dense_256)
        block_5_pre_dense = self.pre_dense_5(block_5_pre_dense_512 + block_4_down )
        block_5, _ = self.dblock_5([block_4_add, block_5_pre_dense])
        # No pooling after block 5: its output is added to the side branch directly.
        block_5_add = block_5 + block_4_side
#        block_5_side = self.side_5(block_5_add)

        # Block 6
        block_6_pre_dense = self.pre_dense_6(block_5)
#        block_5_pre_dense_256 = self.pre_dense_6(block_5_add) # if error uncomment
        block_6, _ = self.dblock_6([block_5_add, block_6_pre_dense])

        # upsampling blocks
        # Each upsampled map is cropped to the input's height and width.
        height, width = x.shape[-2:]
        slice_shape = (height, width)
        out_1 = self.slice(self.up_block_1(block_1), slice_shape)
        out_2 = self.slice(self.up_block_2(block_2), slice_shape)
        out_3 = self.slice(self.up_block_3(block_3), slice_shape)
        out_4 = self.slice(self.up_block_4(block_4), slice_shape)
        out_5 = self.slice(self.up_block_5(block_5), slice_shape)
        out_6 = self.slice(self.up_block_6(block_6), slice_shape)
        results = [out_1, out_2, out_3, out_4, out_5, out_6]
        # print(out_1.shape)
        # concatenate multiscale outputs
        block_cat = torch.cat(results, dim=1)  # Bx6xHxW
        block_cat = self.block_cat(block_cat)  # Bx1xHxW

        # return results
        # The fused map goes last, so callers can read it as results[-1].
        results.append(block_cat)
        return results

main.py

from torch import nn
from torch.utils.data import DataLoader
from dataset import BIPEDDataset
from losses import *
from config import Config
from cyclicLR import CyclicCosAnnealingLR
from model import DexiNet
import torchgeometry as tgm
import numpy as np
import time
import os
import cv2 as cv
import tqdm


def weight_init(m):
    """Initialise one module in place; meant for ``model.apply(weight_init)``.

    * ``nn.Conv2d``: weights ~ N(0, 0.01). A weight whose second dimension
      is 1 is re-drawn from N(0, 1). The fusion layer, recognised by its
      (1, 6, 1, 1) weight shape, is set to the constant 0.2.
    * ``nn.ConvTranspose2d``: weights ~ N(0, 0.01). A weight whose second
      dimension is 1 is re-drawn from N(0, 0.1).
    * Biases, when present, are zeroed. Other module types are untouched.

    The original compared ``shape[1]`` (an int) with ``torch.Size([1])``.
    That is always False, so both "second dimension is 1" branches were
    dead code; the comparison is now against the integer 1.
    """
    if isinstance(m, nn.Conv2d):
        torch.nn.init.normal_(m.weight, mean=0, std=0.01)
        if m.weight.data.shape[1] == 1:
            torch.nn.init.normal_(m.weight, mean=0.0,)
        # for fusion layer
        if m.weight.data.shape == torch.Size([1, 6, 1, 1]):
            torch.nn.init.constant_(m.weight, 0.2)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)

    if isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, mean=0, std=0.01)
        if m.weight.data.shape[1] == 1:
            torch.nn.init.normal_(m.weight, std=0.1)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)


def mkdir(path):
    """Create ``path`` and any missing parents; do nothing if it exists.

    ``exist_ok=True`` replaces the original exists-then-create check, which
    could raise if another process created the directory between the two
    calls.
    """
    os.makedirs(path, exist_ok=True)


class Trainer(object):
    """Builds the model, optimizer and scheduler and runs training.

    Args:
        cfg: configuration object; this class reads ``train_root``,
            ``valid_root``, ``batch_size``, ``num_workers``, ``num_epochs``,
            ``valid_output_dir`` and ``model_output`` from it.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = DexiNet().to(self.device).apply(weight_init)
        self.criterion = weighted_cross_entropy_loss
        self.optimizer = torch.optim.Adam(params=self.model.parameters(), lr=0.003, weight_decay=0.0001)
        milestones = [5 + x * 30 for x in range(5)]
        self.scheduler = CyclicCosAnnealingLR(self.optimizer, milestones=milestones, eta_min=5e-5)
        mkdir(cfg.model_output)

    def build_loader(self):
        """Return ``(train_loader, valid_loader)``; only training is shuffled."""
        train_dataset = BIPEDDataset(self.cfg.train_root, config=self.cfg)
        valid_dataset = BIPEDDataset(self.cfg.valid_root, config=self.cfg)

        train_loader = DataLoader(train_dataset,
                                  batch_size=self.cfg.batch_size,
                                  num_workers=self.cfg.num_workers,
                                  shuffle=True)
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=self.cfg.batch_size,
                                  num_workers=self.cfg.num_workers,
                                  shuffle=False)
        return train_loader, valid_loader

    def _batch_loss(self, sample_batched):
        """Run the model on one batch; return ``(preds_list, loss)``.

        The loss is the sum of the criterion over every model output,
        divided by the batch size.
        """
        images = sample_batched['images'].to(self.device)  # BxCxHxW
        labels = sample_batched['labels'].to(self.device)  # Bx1xHxW when produced by BIPEDDataset

        preds_list = self.model(images)
        loss = sum([self.criterion(preds, labels) for preds in preds_list])
        # NOTE(review): weighted_cross_entropy_loss already divides by the
        # batch size, so this divides a second time. Kept as in the original.
        loss = loss / images.shape[0]  # the batch size
        return preds_list, loss

    def _fit_batches(self, epoch, indexed_batches, num_batches):
        """Take one optimisation step per ``(batch_id, batch)`` pair."""
        self.model.train()
        for batch_id, sample_batched in indexed_batches:
            _, loss = self._batch_loss(sample_batched)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            print(time.ctime(), 'training, Epoch: {0} Sample {1}/{2} Loss: {3}' \
                  .format(epoch, batch_id, num_batches, loss.item()), end='\r')

    def train_one_epoch(self, epoch, dataloader):
        """Train on every batch of ``dataloader`` once, with a tqdm bar."""
        self._fit_batches(epoch, tqdm.tqdm(enumerate(dataloader)), len(dataloader))

    def validation(self, epoch, dataloader):
        """Evaluate on every batch and save the fused predictions.

        The original returned from inside the loop, so only the first batch
        was ever evaluated. It also ran with autograd enabled and handed
        back a graph-attached tensor.

        Returns:
            float: mean per-batch loss over the whole loader.

        Raises:
            ValueError: if ``dataloader`` yields no batches.
        """
        self.model.eval()
        loss_total = 0.0
        num_seen = 0
        with torch.no_grad():
            for batch_id, sample_batched in enumerate(dataloader):
                preds_list, loss = self._batch_loss(sample_batched)
                loss_total += loss.item()
                num_seen += 1

                print(time.ctime(), 'validation, Epoch: {0} Sample {1}/{2} Loss: {3}' \
                      .format(epoch, batch_id, len(dataloader), loss.item()), end='\r')

                # The last model output is the fused prediction.
                self.save_image_bacth_to_disk(preds_list[-1], sample_batched['file_name'])

        if num_seen == 0:
            raise ValueError('validation dataloader produced no batches')
        return loss_total / num_seen

    def save_image_bacth_to_disk(self, tensor, file_names):
        """Write each prediction of a 4-D logit batch as ``<file_name>.png``.

        Logits pass through a sigmoid and are then inverted, so high
        responses are written as dark pixels.
        """
        output_dir = self.cfg.valid_output_dir
        mkdir(output_dir)
        assert len(tensor.shape) == 4, tensor.shape
        for tensor_image, file_name in zip(tensor, file_names):
            image_vis = tgm.utils.tensor_to_image(torch.sigmoid(tensor_image))[..., 0]
            image_vis = (255.0 * (1.0 - image_vis)).astype(np.uint8)  #
            output_file_name = os.path.join(output_dir, f"{file_name}.png")
            cv.imwrite(output_file_name, image_vis)

    def train(self):
        """Train for ``cfg.num_epochs`` epochs.

        After each epoch the model is validated, and it is saved whenever
        the validation loss improves on the best seen so far.
        """
        train_loader, valid_loader = self.build_loader()
        best_loss = 1000000
        for epoch in range(self.cfg.num_epochs):
            self.scheduler.step(epoch)

            # Plain enumerate (no tqdm) keeps this loop's console output
            # as it was.
            self._fit_batches(epoch, enumerate(train_loader), len(train_loader))

            valid_loss = self.validation(epoch, valid_loader)
            if valid_loss < best_loss:
                torch.save(self.model, os.path.join(self.cfg.model_output, f'epoch{epoch}_model.pth'))
                print(f'find optimal model, loss {best_loss}==>{valid_loss}')
                best_loss = valid_loss


if __name__=='__main__':
    # Entry point: train with the default configuration.
    Trainer(Config()).train()



遥感影像切片
复现Dense Extreme Inception Network(pytorch)
复现Dense Extreme Inception Network(pytorch)
预测结果
复现Dense Extreme Inception Network(pytorch)

相关标签: 边缘检测