复现Dense Extreme Inception Network(pytorch)
程序员文章站
2022-03-29 22:47:35
...
github地址:https://github.com/xavysp/DexiNed/tree/master/DexiNed-Pytorch
论文地址:https://arxiv.org/abs/1909.01955
数据集:https://www.kaggle.com/xavysp/biped
摘要
这篇paper是基于深度学习的边缘检测算法,受到HED(Holistically-Nested Edge Detection)和Xception 网络的启发。该方法生成人眼可能看到的薄边缘地图,可以用于任何边缘检测任务,无需经过长时间训练或微调过程。
该论文的主要贡献:提出了一种鲁棒的CNN边缘检测架构,简称为DexiNed:Dense Extreme Inception Network for Edge Detection。这个模型是从头开始训练的,没有预先训练过的权重。
模型结构
论文实验结果
相关代码
config.py
class Config(object):
    """Paths and hyper-parameters for training DexiNed on the BIPED dataset."""
    # --- dataset ---
    # Per-channel pixel means in BGR order; the 4th value is an extra
    # (gray-channel) mean that the dataset loader drops when it keeps
    # only the first three entries.
    mean_pixel_values = [104.00699, 116.66877, 122.67892, 137.86]
    img_width = 400   # network input width after resize
    img_height = 400  # network input height after resize
    train_root = 'data/BIPED/edges/imgs/train/rgbr/real'
    # NOTE(review): despite the name, this points at the *train* edge-map
    # directory, not a validation image directory — confirm intended.
    valid_root = 'data/BIPED/edges/edge_maps/train/rgbr/real'
    valid_output_dir = 'valid_temp'  # where validation predictions are written
    # --- hyper parameters ---
    batch_size = 2
    num_workers = 0   # DataLoader worker processes (0 = load in main process)
    num_epochs = 25
    model_output = 'result'  # directory for saved model checkpoints
dataset.py
from torch.utils.data import DataLoader, Dataset
import torch
import cv2 as cv
import numpy as np
import os
class BIPEDDataset(Dataset):
    """BIPED edge-detection dataset.

    Loads image/edge-map pairs; the ground-truth path is derived from the
    image path by swapping the 'imgs' directory for 'edge_maps' and the
    extension for '.png'.
    """

    def __init__(self, img_root, mode='train', config=None):
        """
        Args:
            img_root: directory containing the input images.
            mode: dataset split tag (stored but currently unused — TODO confirm).
            config: object providing mean_pixel_values, img_width, img_height.
        """
        self.img_root = img_root
        self.mode = mode
        self.imgList = os.listdir(img_root)
        self.config = config
        # Keep only the BGR means; a 4th entry (gray-channel mean) is dropped.
        self.mean_bgr = config.mean_pixel_values[0:3] if len(config.mean_pixel_values) == 4 \
            else config.mean_pixel_values

    def __len__(self):
        return len(self.imgList)

    def __getitem__(self, idx):
        # splitext is robust to file names containing extra dots,
        # unlike the previous split('.')[0].
        file_name = os.path.splitext(self.imgList[idx])[0]
        imgPath = os.path.join(self.img_root, self.imgList[idx])
        # Ground-truth maps live under 'edge_maps' with a .png extension.
        # Replacing only the extension (not every 'jpg' substring) avoids
        # corrupting directory names that happen to contain 'jpg'.
        labelPath = os.path.splitext(imgPath.replace('imgs', 'edge_maps'))[0] + '.png'
        # load data; cv.imread returns None on failure, so fail loudly here
        # instead of crashing later inside transform().
        image = cv.imread(imgPath, cv.IMREAD_COLOR)
        if image is None:
            raise FileNotFoundError(f'could not read image: {imgPath}')
        label = cv.imread(labelPath, cv.IMREAD_GRAYSCALE)
        if label is None:
            raise FileNotFoundError(f'could not read label: {labelPath}')
        image_shape = [image.shape[0], image.shape[1]]
        image, label = self.transform(img=image, gt=label)
        return dict(images=image, labels=label, file_name=file_name, image_shape=image_shape)

    def transform(self, img, gt):
        """Normalize, resize and convert an (image, edge-map) pair to tensors.

        Returns:
            img: float32 CxHxW tensor with the BGR mean subtracted.
            gt: float32 1xHxW tensor scaled to [0, 1].
        """
        gt = np.array(gt, dtype=np.float32)
        if len(gt.shape) == 3:
            gt = gt[:, :, 0]
        gt /= 255.  # edge maps are stored as 8-bit images
        img = np.array(img, dtype=np.float32)
        img -= self.mean_bgr  # channel-wise mean subtraction (BGR order)
        img = cv.resize(img, dsize=(self.config.img_width, self.config.img_height))
        gt = cv.resize(gt, dsize=(self.config.img_width, self.config.img_height))
        img = img.transpose((2, 0, 1))  # HWC -> CHW
        img = torch.from_numpy(img.copy()).float()
        gt = torch.from_numpy(gt[np.newaxis, ...].copy()).float()  # add channel dim
        return img, gt
if __name__ == '__main__':
    # Smoke test: iterate the training split and print batch shapes.
    from config import Config

    cfg = Config()
    root = 'data/BIPED/edges/imgs/train/rgbr/real'
    loader = DataLoader(BIPEDDataset(root, config=cfg), batch_size=2, num_workers=0)
    for batch in loader:
        img, label = batch['images'], batch['labels']
        print(img.size(), label.size(), batch['file_name'])
losses.py
import torch
import torch.nn.functional as F
def _weighted_cross_entropy_loss(preds, edges):
""" Calculate sum of weighted cross entropy loss. """
# Reference:
# hed/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
# https://github.com/s9xie/hed/issues/7
mask = (edges > 0.5).float()
b, c, h, w = mask.shape
num_pos = torch.sum(mask, dim=[1, 2, 3]).float() # Shape: [b,].
num_neg = c * h * w - num_pos # Shape: [b,].
weight = torch.zeros_like(mask)
weight[edges > 0.5] = num_neg / (num_pos + num_neg)
weight[edges <= 0.5] = num_pos / (num_pos + num_neg)
# Calculate loss.
losses = F.binary_cross_entropy_with_logits(
preds.float(), edges.float(), weight=weight, reduction='none')
loss = torch.sum(losses) / b
return loss
def weighted_cross_entropy_loss(preds, edges):
    """Class-balanced BCE-with-logits loss, summed and divided by batch size.

    Each pixel is weighted by the fraction of the *opposite* class within
    its own image (HED-style class balancing).

    Args:
        preds: Bx1xHxW logits.
        edges: Bx1xHxW ground truth in [0, 1]; values > 0.5 count as edge.

    Returns:
        Scalar tensor: sum of weighted per-pixel losses divided by B.
    """
    # Reference:
    # hed/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
    # https://github.com/s9xie/hed/issues/7
    mask = (edges > 0.5).float()
    b, c, h, w = mask.shape
    num_pos = torch.sum(mask, dim=[1, 2, 3], keepdim=True).float()  # [b,1,1,1]
    num_neg = c * h * w - num_pos  # [b,1,1,1]
    # BUG FIX: masked_scatter_ copies source elements *sequentially* into the
    # True positions, so per-image weights leaked across batch items whenever
    # the True counts differed. torch.where broadcasts the per-image ratios
    # onto the right pixels.
    total = num_pos + num_neg  # == c*h*w, never zero
    weight = torch.where(edges > 0.5, num_neg / total, num_pos / total)
    # Calculate loss on raw logits (sigmoid is applied internally).
    losses = F.binary_cross_entropy_with_logits(
        preds.float(), edges.float(), weight=weight, reduction='none')
    loss = torch.sum(losses) / b
    return loss
model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class _DenseLayer(nn.Sequential):
def __init__(self, input_features, out_features):
super(_DenseLayer, self).__init__()
# self.add_module('relu2', nn.ReLU(inplace=True)),
self.add_module('conv1', nn.Conv2d(input_features, out_features,
kernel_size=1, stride=1, bias=True)),
self.add_module('norm1', nn.BatchNorm2d(out_features)),
self.add_module('relu1', nn.ReLU(inplace=True)),
self.add_module('conv2', nn.Conv2d(out_features, out_features,
kernel_size=3, stride=1, padding=1, bias=True)),
self.add_module('norm2', nn.BatchNorm2d(out_features))
# double check the norm1 comment if necessary and put norm after conv2
def forward(self, x):
x1, x2 = x
# maybe I should put here a RELU
new_features = super(_DenseLayer, self).forward(x1) # F.relu()
return 0.5 * (new_features + x2), x2
class _DenseBlock(nn.Sequential):
    """A stack of _DenseLayer stages.

    The first stage maps input_features -> out_features; every following
    stage keeps out_features channels.
    """

    def __init__(self, num_layers, input_features, out_features):
        super(_DenseBlock, self).__init__()
        channels = input_features
        for idx in range(1, num_layers + 1):
            self.add_module('denselayer%d' % idx,
                            _DenseLayer(channels, out_features))
            channels = out_features
class UpConvBlock(nn.Module):
    """Upsamples a feature map down to a single-channel edge map.

    Builds `up_scale` stages; each stage reduces channels with a 1x1 conv and
    doubles the spatial resolution (transposed conv or pixel shuffle). The
    final stage emits 1 channel.
    """
    def __init__(self, in_features, up_scale, mode='deconv'):
        super(UpConvBlock, self).__init__()
        self.up_factor = 2           # PixelShuffle upscale factor per stage
        self.constant_features = 16  # channel width of intermediate stages
        layers = None
        if mode == 'deconv':
            layers = self.make_deconv_layers(in_features, up_scale)
        elif mode == 'pixel_shuffle':
            layers = self.make_pixel_shuffle_layers(in_features, up_scale)
        assert layers is not None, layers
        self.features = nn.Sequential(*layers)
    def make_deconv_layers(self, in_features, up_scale):
        """1x1 conv + ReLU + stride-2 ConvTranspose2d per stage."""
        layers = []
        for i in range(up_scale):
            # NOTE(review): kernel_size depends on up_scale, not i, so every
            # stage uses the same kernel; with stride 2 and no padding the
            # output overshoots 2x and the caller crops it afterwards (see
            # DexiNet.slice). Confirm this was intended over 2 ** (i + 1).
            kernel_size = 2 ** up_scale
            out_features = self.compute_out_features(i, up_scale)
            layers.append(nn.Conv2d(in_features, out_features, 1))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.ConvTranspose2d(
                out_features, out_features, kernel_size, stride=2))
            in_features = out_features
        return layers
    def make_pixel_shuffle_layers(self, in_features, up_scale):
        """PixelShuffle(2) + 1x1 conv (+ ReLU) per stage."""
        layers = []
        for i in range(up_scale):
            kernel_size = 2 ** (i + 1)  # NOTE(review): computed but never used
            out_features = self.compute_out_features(i, up_scale)
            # PixelShuffle trades channels for resolution: C/4, 2H, 2W.
            in_features = int(in_features / (self.up_factor ** 2))
            layers.append(nn.PixelShuffle(self.up_factor))
            layers.append(nn.Conv2d(in_features, out_features, 1))
            # NOTE(review): i < up_scale always holds inside range(up_scale),
            # so a ReLU follows every stage, including the last; possibly
            # `i < up_scale - 1` was intended.
            if i < up_scale:
                layers.append(nn.ReLU(inplace=True))
            in_features = out_features
        return layers
    def compute_out_features(self, idx, up_scale):
        """Return 1 for the last stage, constant_features otherwise."""
        return 1 if idx == up_scale - 1 else self.constant_features
    def forward(self, x):
        return self.features(x)
class SingleConvBlock(nn.Module):
    """A 1x1 convolution followed by batch normalization (no activation).

    Used for side connections and channel/resolution matching.
    """

    def __init__(self, in_features, out_features, stride):
        super(SingleConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_features, out_features,
                              kernel_size=1, stride=stride)
        self.bn = nn.BatchNorm2d(out_features)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        return out
class DoubleConvBlock(nn.Module):
    """Two 3x3 conv+BN+ReLU stages.

    The first stage may downsample via `stride`; when `out_features` is not
    given the second stage keeps `mid_features` channels.
    """

    def __init__(self, in_features, mid_features, out_features=None, stride=1):
        super(DoubleConvBlock, self).__init__()
        out_features = mid_features if out_features is None else out_features
        self.conv1 = nn.Conv2d(in_features, mid_features, 3,
                               padding=1, stride=stride)
        self.bn1 = nn.BatchNorm2d(mid_features)
        self.conv2 = nn.Conv2d(mid_features, out_features, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_features)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        return out
class DexiNet(nn.Module):
    """DexiNed: Dense Extreme Inception Network for Edge Detection.

    Produces six intermediate edge maps (one per encoder block, each
    upsampled to the input resolution) plus a fused map obtained by a 1x1
    convolution over their concatenation. forward() returns all seven as a
    list of Bx1xHxW tensors, fused map last.
    """
    def __init__(self):
        super(DexiNet, self).__init__()
        # Encoder: block_1 halves the resolution (stride 2); subsequent
        # downsampling happens via maxpool between blocks.
        self.block_1 = DoubleConvBlock(3, 32, 64, stride=2)
        self.block_2 = DoubleConvBlock(64, 128)
        self.dblock_3 = _DenseBlock(2, 128, 256)
        self.dblock_4 = _DenseBlock(3, 256, 512)
        self.dblock_5 = _DenseBlock(3, 512, 512)
        self.dblock_6 = _DenseBlock(3, 512, 256)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Side connections (1x1 conv + BN) matching channels/resolution so
        # outputs of adjacent blocks can be summed.
        self.side_1 = SingleConvBlock(64, 128, 2)
        self.side_2 = SingleConvBlock(128, 256, 2)
        self.side_3 = SingleConvBlock(256, 512, 2)
        self.side_4 = SingleConvBlock(512, 512, 1)
        self.side_5 = SingleConvBlock(512, 256, 1)
        # 1x1 convs producing the skip tensors fed to the dense blocks.
        self.pre_dense_2 = SingleConvBlock(128, 256, 2)  # by me, for left skip block4
        self.pre_dense_3 = SingleConvBlock(128, 256, 1)
        self.pre_dense_4 = SingleConvBlock(256, 512, 1)
        self.pre_dense_5_0 = SingleConvBlock(256, 512, 2)
        self.pre_dense_5 = SingleConvBlock(512, 512, 1)
        self.pre_dense_6 = SingleConvBlock(512, 256, 1)
        # Upsampling heads: each produces a 1-channel map at (at least) the
        # input resolution; overshoot is cropped by slice() in forward().
        self.up_block_1 = UpConvBlock(64, 1)
        self.up_block_2 = UpConvBlock(128, 1)
        self.up_block_3 = UpConvBlock(256, 2)
        self.up_block_4 = UpConvBlock(512, 3)
        self.up_block_5 = UpConvBlock(512, 4)
        self.up_block_6 = UpConvBlock(256, 4)
        # Fusion layer: merges the 6 side outputs into the final edge map.
        self.block_cat = nn.Conv2d(6, 1, kernel_size=1)
    def slice(self, tensor, slice_shape):
        """Crop an upsampled map back to (height, width); the transposed
        convolutions can overshoot the target size."""
        height, width = slice_shape
        return tensor[..., :height, :width]
    def forward(self, x):
        """Args: x — Bx3xHxW image batch. Returns a list of 7 Bx1xHxW maps."""
        assert len(x.shape) == 4, x.shape
        # Block 1
        block_1 = self.block_1(x)
        block_1_side = self.side_1(block_1)
        # Block 2
        block_2 = self.block_2(block_1)
        block_2_down = self.maxpool(block_2)
        block_2_add = block_2_down + block_1_side
        block_2_side = self.side_2(block_2_add)
        # Block 3: dense block takes (features, skip) and returns (out, skip).
        block_3_pre_dense = self.pre_dense_3(block_2_down)
        block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense])
        block_3_down = self.maxpool(block_3)
        block_3_add = block_3_down + block_2_side
        block_3_side = self.side_3(block_3_add)
        # Block 4: its skip input mixes block-2 and block-3 features.
        block_4_pre_dense_256 = self.pre_dense_2(block_2_down)
        block_4_pre_dense = self.pre_dense_4(block_4_pre_dense_256 + block_3_down)
        block_4, _ = self.dblock_4([block_3_add, block_4_pre_dense])
        block_4_down = self.maxpool(block_4)
        block_4_add = block_4_down + block_3_side
        block_4_side = self.side_4(block_4_add)
        # Block 5 (no maxpool after this point — resolution stays fixed).
        block_5_pre_dense_512 = self.pre_dense_5_0(block_4_pre_dense_256)
        block_5_pre_dense = self.pre_dense_5(block_5_pre_dense_512 + block_4_down )
        block_5, _ = self.dblock_5([block_4_add, block_5_pre_dense])
        block_5_add = block_5 + block_4_side
        # block_5_side = self.side_5(block_5_add)
        # Block 6
        block_6_pre_dense = self.pre_dense_6(block_5)
        # block_5_pre_dense_256 = self.pre_dense_6(block_5_add) # if error uncomment
        block_6, _ = self.dblock_6([block_5_add, block_6_pre_dense])
        # Upsampling heads: each map is cropped back to the input size.
        height, width = x.shape[-2:]
        slice_shape = (height, width)
        out_1 = self.slice(self.up_block_1(block_1), slice_shape)
        out_2 = self.slice(self.up_block_2(block_2), slice_shape)
        out_3 = self.slice(self.up_block_3(block_3), slice_shape)
        out_4 = self.slice(self.up_block_4(block_4), slice_shape)
        out_5 = self.slice(self.up_block_5(block_5), slice_shape)
        out_6 = self.slice(self.up_block_6(block_6), slice_shape)
        results = [out_1, out_2, out_3, out_4, out_5, out_6]
        # concatenate multiscale outputs and fuse with a 1x1 conv
        block_cat = torch.cat(results, dim=1)  # Bx6xHxW
        block_cat = self.block_cat(block_cat)  # Bx1xHxW
        # Return the 6 side outputs plus the fused map (fused last).
        results.append(block_cat)
        return results
main.py
from torch import nn
from torch.utils.data import DataLoader
from dataset import BIPEDDataset
from losses import *
from config import Config
from cyclicLR import CyclicCosAnnealingLR
from model import DexiNet
import torchgeometry as tgm
import numpy as np
import time
import os
import cv2 as cv
import tqdm
def weight_init(m):
    """Custom weight initializer, meant for nn.Module.apply().

    Conv2d/ConvTranspose2d weights get N(0, 0.01); layers with a single
    channel in dim 1 get a wider normal init; the fusion layer (1x1 conv
    merging the 6 side outputs, weight shape [1, 6, 1, 1]) is set to the
    constant 0.2. Biases are zeroed.
    """
    if isinstance(m, (nn.Conv2d, )):
        torch.nn.init.normal_(m.weight, mean=0, std=0.01)
        # BUG FIX: the original compared shape[1] (an int) against
        # torch.Size([1]), which is always False, leaving both special
        # branches dead. Compare the channel count itself.
        if m.weight.data.shape[1] == 1:
            # single-input-channel convs: default std (1.0) normal init
            torch.nn.init.normal_(m.weight, mean=0.0,)
        if m.weight.data.shape == torch.Size([1, 6, 1, 1]):
            # fusion layer: equal constant weight over the 6 side outputs
            torch.nn.init.constant_(m.weight, 0.2)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
    if isinstance(m, (nn.ConvTranspose2d,)):
        torch.nn.init.normal_(m.weight, mean=0, std=0.01)
        # ConvTranspose2d weight is [in, out, kH, kW]; out_channels == 1
        # marks the final upsampling stage, which gets std 0.1.
        if m.weight.data.shape[1] == 1:
            torch.nn.init.normal_(m.weight, std=0.1)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
def mkdir(path):
    """Create *path* (including parents) if it does not already exist."""
    if os.path.exists(path):
        return
    os.makedirs(path)
class Trainer(object):
    """Training/validation driver for DexiNet on the BIPED dataset.

    Builds the model, optimizer and LR scheduler from a Config object, then
    train() runs the epoch loop and checkpoints the model whenever the
    validation loss improves.
    """
    def __init__(self, cfg):
        self.cfg = cfg
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # DexiNed is trained from scratch; weight_init applies the custom init.
        self.model = DexiNet().to(self.device).apply(weight_init)
        self.criterion = weighted_cross_entropy_loss
        self.optimizer = torch.optim.Adam(params=self.model.parameters(), lr=0.003, weight_decay=0.0001)
        # Cosine-annealing restart points: epochs 5, 35, 65, 95, 125.
        milestones = [5 + x * 30 for x in range(5)]
        self.scheduler = CyclicCosAnnealingLR(self.optimizer, milestones=milestones, eta_min=5e-5)
        mkdir(cfg.model_output)
    def build_loader(self):
        """Return (train_loader, valid_loader) built from the cfg paths."""
        train_dataset = BIPEDDataset(self.cfg.train_root, config=self.cfg)
        valid_dataset = BIPEDDataset(self.cfg.valid_root, config=self.cfg)
        train_loader = DataLoader(train_dataset,
                                  batch_size=self.cfg.batch_size,
                                  num_workers=self.cfg.num_workers,
                                  shuffle=True)
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=self.cfg.batch_size,
                                  num_workers=self.cfg.num_workers,
                                  shuffle=False)
        return train_loader, valid_loader
    def train_one_epoch(self, epoch, dataloader):
        """Run one training epoch over *dataloader*.

        NOTE(review): train() below duplicates this loop inline (minus tqdm)
        and never calls this method — consider consolidating.
        """
        self.model.train()
        for batch_id, sample_batched in tqdm.tqdm(enumerate(dataloader)):
            images = sample_batched['images'].to(self.device)  # BxCxHxW
            labels = sample_batched['labels'].to(self.device)  # Bx1xHxW
            # The model returns 7 maps (6 side outputs + fused); the loss is
            # summed over all of them.
            preds_list = self.model(images)
            loss = sum([self.criterion(preds, labels) for preds in preds_list])
            loss /= images.shape[0]  # normalize by the batch size
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            print(time.ctime(), 'training, Epoch: {0} Sample {1}/{2} Loss: {3}' \
                  .format(epoch, batch_id, len(dataloader), loss.item()), end='\r')
    def validation(self, epoch, dataloader):
        """Validate, saving the fused predictions to disk.

        Returns the loss of the LAST batch only (not an epoch average) —
        this is what train() compares against best_loss.
        """
        self.model.eval()
        for batch_id, sample_batched in enumerate(dataloader):
            images = sample_batched['images'].to(self.device)  # BxCxHxW
            labels = sample_batched['labels'].to(self.device)  # Bx1xHxW
            file_name = sample_batched['file_name']
            preds_list = self.model(images)
            loss = sum([self.criterion(preds, labels) for preds in preds_list])
            loss /= images.shape[0]  # normalize by the batch size
            print(time.ctime(), 'validation, Epoch: {0} Sample {1}/{2} Loss: {3}' \
                  .format(epoch, batch_id, len(dataloader), loss.item()), end='\r')
            # preds_list[-1] is the fused output map.
            self.save_image_bacth_to_disk(preds_list[-1], file_name)
        return loss
    def save_image_bacth_to_disk(self, tensor, file_names):
        """Save a batch of Bx1xHxW edge logits as inverted 8-bit PNGs
        (edges dark on a white background) under cfg.valid_output_dir."""
        output_dir = self.cfg.valid_output_dir
        mkdir(output_dir)
        assert len(tensor.shape) == 4, tensor.shape
        for tensor_image, file_name in zip(tensor, file_names):
            # sigmoid -> [0, 1], then invert before the 8-bit conversion
            image_vis = tgm.utils.tensor_to_image(torch.sigmoid(tensor_image))[..., 0]
            image_vis = (255.0 * (1.0 - image_vis)).astype(np.uint8)
            output_file_name = os.path.join(output_dir, f"{file_name}.png")
            cv.imwrite(output_file_name, image_vis)
    def train(self):
        """Full training loop: per-epoch training + validation, keeping the
        checkpoint with the lowest validation loss."""
        train_loader, valid_loader = self.build_loader()
        best_loss = 1000000
        for epoch in range(self.cfg.num_epochs):
            # NOTE(review): calling scheduler.step(epoch) before the epoch is
            # the legacy (pre-1.1) PyTorch convention — confirm it matches the
            # installed torch version and the CyclicCosAnnealingLR API.
            self.scheduler.step(epoch)
            self.model.train()
            for batch_id, sample_batched in enumerate(train_loader):
                images = sample_batched['images'].to(self.device)  # BxCxHxW
                labels = sample_batched['labels'].to(self.device)  # Bx1xHxW
                preds_list = self.model(images)
                loss = sum([self.criterion(preds, labels) for preds in preds_list])
                loss /= images.shape[0]  # normalize by the batch size
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                print(time.ctime(), 'training, Epoch: {0} Sample {1}/{2} Loss: {3}' \
                      .format(epoch, batch_id, len(train_loader), loss.item()), end='\r')
            valid_loss = self.validation(epoch, valid_loader)
            # Whole-model (not state_dict) checkpoint on improvement.
            if valid_loss < best_loss:
                torch.save(self.model, os.path.join(self.cfg.model_output, f'epoch{epoch}_model.pth'))
                print(f'find optimal model, loss {best_loss}==>{valid_loss}')
                best_loss = valid_loss
if __name__ == '__main__':
    # Entry point: build the trainer from the default config and run.
    trainer = Trainer(Config())
    trainer.train()
遥感影像切片
预测结果
上一篇: 工业检测中铁罐焊缝检测
下一篇: 边缘检测在android上的小应用