欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

深度学习实战 | MMDetection之FCOS代码

程序员文章站 2022-04-07 13:58:55
深度学习实战 | MMDetection之FCOS代码(1)...


1. 简介

本系列将基于 M M D e t e c t i o n {\rm MMDetection} MMDetection介绍 A n c h o r {\rm Anchor} Anchor- F r e e {\rm Free} Free目标检测算法 F C O S {\rm FCOS} FCOS的实现细节。我们直接从配置文件训练文件为入口介绍该算法的整体流程以及实现细节。


2. FCOS简介

深度学习实战 | MMDetection之FCOS代码

图1:FCOS结构

F C O S {\rm FCOS} FCOS大体由图上三部分组成:骨干网络,特征金字塔和预测分支。基于语义分割算法 F C N {\rm FCN} FCN的思想, F C O S {\rm FCOS} FCOS将目标检测看作像素级检测,针对特征图上每个点映射回原图后的信息将其视为训练样本。下面结合这三部分以及配置文件来说明 F C O S {\rm FCOS} FCOS的具体实现。 配置文件链接 训练文件链接


3. 整体训练流程

首先进入训练文件中的主函数,第一个重要的函数是类Config的成员函数fromfile,它以具体的配置文件路径为参数,然后以字典的形式返回配置文件的内容。如配置文件的前面一部分为:

model = dict(
    type='FCOS',
    pretrained='open-mmlab://detectron/resnet50_caffe',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True,
        style='caffe'),
    ...
)

返回的结果为:

{'type': 'FCOS', 'pretrained': 'open-mmlab://detectron/resnet50_caffe', 
'backbone': {'type': 'ResNet', 'depth': 50, 'num_stages': 4, 'out_indices': (0, 1, 2, 3), ...}}

然后是一些训练过程中的配置信息,包括设置工作目录、恢复训练、 G P U {\rm GPU} GPU索引、分布式训练、日志信息等。下面一个重要的函数是build_detector函数,其功能是根据配置文件的信息构建相应的检测器。

model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

接着是build_dataset函数,其功能是根据配置文件的信息构建相应的数据集。

datasets = [build_dataset(cfg.data.train)]

最后是train_detector函数,其功能是根据模型、数据集和配置文件信息等启动训练。

train_detector(model, datasets, cfg, distributed=distributed, 
               validate=(not args.no_validate), timestamp=timestamp, meta=meta)

以上是训练文件中的几个关键函数,后面三个函数其实都是通过build函数实现,所以首先介绍build函数。


4. build函数

首先在文件开头定义了Registry类的对象链接

BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')

而在调用build函数时就是根据上述对象分别实现具体的功能。如build_detector函数:

def build_detector(cfg, train_cfg=None, test_cfg=None):
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))

def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)

def build_from_cfg(cfg, registry, default_args=None):
    # cfg格式必须是字典形式
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    # type必须存在,其指定了具体的检测器
    if 'type' not in cfg:
        raise KeyError(f'the cfg dict must contain the key "type", but got {cfg}')
    # registry必须是已经定义的Registry对象
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, but got {type(registry)}')
    # default_args参数必须为字典或None
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, but got {type(default_args)}')
    # 字典的浅复制
    args = cfg.copy()
    # 根据type字段获得具体的检测器,如FCOS,每个检测器都有与之对应的类,后续通过类的构造函数返回
    obj_type = args.pop('type')
    if is_str(obj_type):
        # 判断obj_type是否存在于注册表中
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError('{obj_type} is not in the {registry.name} registry')
    # 判断obj_type是否为类
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    # 不存在该检测器
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
	# 添加default_args的内容
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    # 根据类名调用相应的构造函数,参数是**args
    return obj_cls(**args)

下面根据配置文件中的type参数来介绍主要的类。

4.1 FCOS类

链接 F C O S {\rm FCOS} FCOS是一种单阶段的检测器,该类的主要继承关系如下:

type='FCOS'

class FCOS(SingleStageDetector):
    def __init__(self, backbone, neck, bbox_head, train_cfg=None, test_cfg=None, pretrained=None):
        super(FCOS, self).__init__(backbone, neck, bbox_head, train_cfg, test_cfg, pretrained)

class SingleStageDetector(BaseDetector):
    def __init__(self, backbone, neck=None, bbox_head=None, train_cfg=None, test_cfg=None, pretrained=None):
        super(SingleStageDetector, self).__init__()

class BaseDetector(nn.Module, metaclass=ABCMeta):
    def __init__(self):
        super(BaseDetector, self).__init__()

最终基类BaseDetector继承自nn.Module和元类ABCMeta,该基类派生出SingleStageDetectorTwoStageDetector两个子类,因此它是所有类的基类。

4.2 ResNet类

链接ResNet类直接继承自nn.Module

class ResNet(nn.Module):
	# 支持ResNet18、ResNet34、ResNet50、ResNet101、ResNet152
    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }
    def __init__(self, depth, in_channels=3, stem_channels=None, base_channels=64, num_stages=4,
                 strides=(1, 2, 2, 2), dilations=(1, 1, 1, 1), out_indices=(0, 1, 2, 3), 
                 style='pytorch', deep_stem=False, avg_down=False, frozen_stages=-1, conv_cfg=None, 
                 norm_cfg=dict(type='BN', requires_grad=True), norm_eval=True, dcn=None,
                 stage_with_dcn=(False, False, False, False), plugins=None, with_cp=False,
                 zero_init_residual=True):
        super(ResNet, self).__init__()

4.4 FPN类

链接FPN类直接继承自nn.Module

class FPN(nn.Module):
    def __init__(self, in_channels, out_channels, num_outs, start_level=0, end_level=-1,
                 add_extra_convs=False, extra_convs_on_inputs=True, relu_before_extra_convs=False,
                 no_norm_on_lateral=False, conv_cfg=None, norm_cfg=None, act_cfg=None,
                 upsample_cfg=dict(mode='nearest')):
        super(FPN, self).__init__()

4.5 FCOSHead类

链接 F C O S {\rm FCOS} FCOS是一种无框检测器,该类的继承关系如下:

class FCOSHead(AnchorFreeHead):
    def __init__(self, num_classes, in_channels,
                 regress_ranges=((-1, 64), (64, 128), (128, 256), (256, 512), (512, INF)),
                 center_sampling=False, center_sample_radius=1.5, norm_on_bbox=False, centerness_on_reg=False,
                 loss_cls=dict(type='FocalLoss', use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=1.0),
                 loss_bbox=dict(type='IoULoss', loss_weight=1.0),
                 loss_centerness=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
                 norm_cfg=dict(type='GN', num_groups=32, requires_grad=True), **kwargs):
        super().__init__( num_classes, in_channels, loss_cls=loss_cls, loss_bbox=loss_bbox,
            norm_cfg=norm_cfg, **kwargs)

class AnchorFreeHead(BaseDenseHead, BBoxTestMixin):
    def __init__(self, num_classes, in_channels, feat_channels=256, stacked_convs=4,
                 strides=(4, 8, 16, 32, 64), dcn_on_last_conv=False, conv_bias='auto',
                 loss_cls=dict(type='FocalLoss', use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=1.0),
                 loss_bbox=dict(type='IoULoss', loss_weight=1.0),
                 conv_cfg=None, norm_cfg=None, train_cfg=None, test_cfg=None):
        super(AnchorFreeHead, self).__init__()

class BaseDenseHead(nn.Module, metaclass=ABCMeta):
    def __init__(self):
        super(BaseDenseHead, self).__init__()

class BBoxTestMixin(object):
    pass

4.6 FocalLoss类

链接FocalLoss类直接继承自nn.Module

class FocalLoss(nn.Module):

    def __init__(self, use_sigmoid=True, gamma=2.0, alpha=0.25, reduction='mean', loss_weight=1.0):
        super(FocalLoss, self).__init__()

4.7 IoULoss类

链接IoULoss类直接继承自nn.Module

class IoULoss(nn.Module):
    def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0):
        super(IoULoss, self).__init__()

4.8 CrossEntropyLoss类

链接CrossEntropyLoss类直接继承自nn.Module

class CrossEntropyLoss(nn.Module):

    def __init__(self, use_sigmoid=False, use_mask=False, reduction='mean',
                 class_weight=None, loss_weight=1.0):
        super(CrossEntropyLoss, self).__init__()

4.9 MaxIoUAssigner类

链接MaxIoUAssigner类的继承关系如下:

class MaxIoUAssigner(BaseAssigner):
    def __init__(self, pos_iou_thr, neg_iou_thr, min_pos_iou=.0, gt_max_assign_all=True,
                 ignore_iof_thr=-1, ignore_wrt_candidates=True, match_low_quality=True, gpu_assign_thr=-1,
                 iou_calculator=dict(type='BboxOverlaps2D')):

class BaseAssigner(metaclass=ABCMeta):
    pass

4.10 CocoDataset类

根据配置文件的名称调用相应的数据处理类,链接CocoDataset类的继承关系如下:

class CocoDataset(CustomDataset):
    pass

class CustomDataset(Dataset):
    def __init__(self, ann_file, pipeline, classes=None, data_root=None, img_prefix='',
                 seg_prefix=None, proposal_file=None, test_mode=False, filter_empty_gt=True):

5. Registry类

M M D e t e c t i o n {\rm MMDetection} MMDetection中使用了注册器来管理模型中的各模块,维护一个模块名称 => 模块类的字典,当需要新添加一个模块时,我们仅需维护注册代码的路径而不需要手动修改字典。关键部分在Registry类中完成,链接。下面是Registry类的关键部分:

class Registry:
    def __init__(self, name):
        self._name = name
        # self._module_dict用于存放表示模块名称的字符串到模块类的映射
        self._module_dict = dict()
        
    @property
    def name(self):
        return self._name

    @property
    def module_dict(self):
        return self._module_dict

    def get(self, key):
    	# 根据字符串取相应的类
        return self._module_dict.get(key, None)

    def _register_module(self, module_class, module_name=None, force=False):
    	# 判断是否为类
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')
		# 开始注册模块
        if module_name is None:
            module_name = module_class.__name__
        if not force and module_name in self._module_dict:
            raise KeyError(f'{module_name} is already registered '
                           f'in {self.name}')
        # 将该模块信息加入字典内完成注册
        self._module_dict[module_name] = module_class

    def register_module(self, name=None, force=False, module=None):
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
            
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # 注册模块,形式为x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module
        # 模块名必须为字符串
        if not (name is None or isinstance(name, str)):
            raise TypeError(f'name must be a str, but got {type(name)}')
        # 注册模块,形式为使用Python的装饰器,@x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls
        return _register

上面定义了两种注册模块的形式,即普通函数调用和使用装饰器,如:

# 函数调用
backbones = Registry("backbone")
class ResNet:
    pass
backbones.register_module(ResNet)
# 语法糖
backbones = Register("backbone")
@backbones.register_module()
class ResNet:
    pass

这里,装饰器的使用就与上面builder.py中的Registry类对象相对应,每个对象完成不同的模块的注册:

BACKBONES = Registry('backbone')	# 骨干网络
NECKS = Registry('neck')	# 颈
ROI_EXTRACTORS = Registry('roi_extractor')	# RoI提取器
SHARED_HEADS = Registry('shared_head')	# 共享头
HEADS = Registry('head')	# 头
LOSSES = Registry('loss')	# 损失函数
DETECTORS = Registry('detector')	# 检测器

6. Config类

链接Config类主要用于解析配置文件,得到训练参数,并以字典的形式返回。Config类的关键部分如下:

class Config:
	# 使用抽象语法树AST判断文件的语法是否正确
    @staticmethod
    def _validate_py_syntax(filename):
        with open(filename, 'r') as f:
            content = f.read()
        try:
            ast.parse(content)
        except SyntaxError as e:
            raise SyntaxError('There are syntax errors in config '
                              f'file {filename}: {e}')
	
    @staticmethod
    def _substitute_predefined_vars(filename, temp_config_name):
        file_dirname = osp.dirname(filename)	# 文件所在目录
        file_basename = osp.basename(filename)	# 文件名(包括扩展名)
        file_basename_no_extension = osp.splitext(file_basename)[0]	# 文件名
        file_extname = osp.splitext(filename)[1]	# 文件扩展名
        support_templates = dict(	# 存入字典
            fileDirname=file_dirname,
            fileBasename=file_basename,
            fileBasenameNoExtension=file_basename_no_extension,
            fileExtname=file_extname)
        with open(filename, 'r') as f:
            config_file = f.read()
        # 将config_file的内容使用正则处理配置文件
        for key, value in support_templates.items():
            regexp = r'\{\{\s*' + str(key) + r'\s*\}\}'
            value = value.replace('\\', '/')
            config_file = re.sub(regexp, value, config_file)
        with open(temp_config_name, 'w') as tmp_config_file:
            tmp_config_file.write(config_file)

    @staticmethod
    def _file2dict(filename, use_predefined_variables=True):
    	# 在Linux系统下将~替换成绝对路径
        filename = osp.abspath(osp.expanduser(filename))
        check_file_exist(filename)
        # 只支持.py、.json、.yaml和.yml文件的解析
        fileExtname = osp.splitext(filename)[1]
        if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
            raise IOError('Only py/yml/yaml/json type are supported now!')
		# 创建临时文件
        with tempfile.TemporaryDirectory() as temp_config_dir:
        	# 给临时文件命名
            temp_config_file = tempfile.NamedTemporaryFile(
                dir=temp_config_dir, suffix=fileExtname)
            if platform.system() == 'Windows':
                temp_config_file.close()
            # 临时配置文件名
            temp_config_name = osp.basename(temp_config_file.name)
            # Substitute predefined variables
            if use_predefined_variables:
                Config._substitute_predefined_vars(filename,
                                                   temp_config_file.name)
            else:
                shutil.copyfile(filename, temp_config_file.name)
			# 处理以.py结尾的配置文件
            if filename.endswith('.py'):
            	# 去除扩展名外的文件名
                temp_module_name = osp.splitext(temp_config_name)[0]
                # 将临时文件目录加入系统路径
                sys.path.insert(0, temp_config_dir)
                # 检验文件内容是否合法
                Config._validate_py_syntax(filename)
                # 以绝对导入的方式导入临时模块
                mod = import_module(temp_module_name)
                # 弹出系统路径的第一项
                sys.path.pop(0)
                # 将配置文件内容转化成字典形式
                cfg_dict = {
                    name: value
                    for name, value in mod.__dict__.items()
                    if not name.startswith('__')
                }
                # 删除导入的临时模块
                del sys.modules[temp_module_name]
            # 处理其他形式结尾的配置文件
            elif filename.endswith(('.yml', '.yaml', '.json')):
                import mmcv
                cfg_dict = mmcv.load(temp_config_file.name)
            # 关闭临时文件
            temp_config_file.close()
		# 路径信息
        cfg_text = filename + '\n'
        # 路径信息 + 配置文件内容
        with open(filename, 'r') as f:
            cfg_text += f.read()
		# BASE_KEY = "_base_",即查看该配置文件是否有继承
		# 如'../_base_/datasets/coco_detection.py',
    	#   '../_base_/schedules/schedule_1x.py', 
    	#   '../_base_/default_runtime.py'
    	# 存在三个基配置文件
        if BASE_KEY in cfg_dict:
            cfg_dir = osp.dirname(filename)
            base_filename = cfg_dict.pop(BASE_KEY)
            base_filename = base_filename if isinstance(
                base_filename, list) else [base_filename]

            cfg_dict_list = list()
            cfg_text_list = list()
            # 以递归的方式将基配置文件的内容转换成字典
            for f in base_filename:
                _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
                cfg_dict_list.append(_cfg_dict)
                cfg_text_list.append(_cfg_text)
            base_cfg_dict = dict()
            for c in cfg_dict_list:
                if len(base_cfg_dict.keys() & c.keys()) > 0:
                    raise KeyError('Duplicate key is not allowed among bases')
                base_cfg_dict.update(c)
			# 将字典cfg_dict和字典base_cfg_dict合并
            base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
            cfg_dict = base_cfg_dict
            # 合并cfg_text内容
            cfg_text_list.append(cfg_text)
            cfg_text = '\n'.join(cfg_text_list)
		# 返回解析内容
        return cfg_dict, cfg_text

最后可以得到解析内容为:

深度学习实战 | MMDetection之FCOS代码

上图包含了训练模型所需要的所有参数。


7. 总结

本文以 F C O S {\rm FCOS} FCOS模型为基础,介绍在 M M D e t e c t i o n {\rm MMDetection} MMDetection中实现 F C O S {\rm FCOS} FCOS所需要实现的类,下文将详细介绍实现 F C O S {\rm FCOS} FCOS的类。


参考

  1. https://github.com/open-mmlab/mmdetection.


本文地址:https://blog.csdn.net/Skies_/article/details/109550241