欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

mmdetection2.0 | train.py相关源码详解(一)

程序员文章站 2022-06-12 19:38:08
...

本文章从源码入手,详细的剖析 MMDetection 的训练细节。理解它是如何实现通用,灵活的训练的。
MMDetection 支持单机单卡、多卡训练。而且对于训练部分的代码和模型、数据集的代码的耦合性极低。真正实现了支持各种训练模式、各种模型和数据集的通用训练模板。

一、tools/train.py

tools/train.py单机单卡训练时需要运行的文件,其主要使用方法为:

python tools/train.py ${CONFIG_FILE} [optional arguments]

如果为单机多卡,需要运行 tools/dist_train.sh,注意此文件不支持多机多卡训练:

./tools/dist_train.sh ${CONFIG_FILE} ${GPU_NUM} [optional arguments]

可选参数如下:

 =========== optional arguments ===========
# --work-dir        存储日志和模型的目录
# --resume-from     加载 checkpoint 的目录
# --no-validate     是否在训练的时候进行验证
# 互斥组:
#   --gpus          使用的 GPU 数量
#   --gpu_ids       使用指定 GPU 的 id
# --seed            随机数种子
# --deterministic   是否设置 cudnn 为确定性行为
# --options         其他参数
# --launcher        分布式训练使用的启动器,可以为:['none', 'pytorch', 'slurm', 'mpi']
#                   none:不启动分布式训练,dist_train.sh 中默认使用 pytorch 启动。
# --local_rank      本地进程编号,此参数 torch.distributed.launch 会自动传入。

我们来看一下 dist_train.sh 里面究竟是什么?

可以看出 dist_train.sh 的本质就是使用 torch.distributed.launch(这是分布式的辅助启动工具) 运行 tools/train.py。

torch.distributed.launch 需要使用 python -m 来运行,-m 是把一个模块当做脚本来运行的一个参数。一般情况下可以给 torch.distributed.launch 传如下的参数

--nproc_per_node:表示每台机器的 GPU 数量
--nnodes:表示机器的数量
--node_rank:机器的排名,如果为 0 代表是 master 节点(机器)。
--master_addr:master 节点的 IP 地址
--master_port:master 节点开放的端口号

可以看到,因为 dist_train.sh 只支持单机多卡训练,所以只传参了 --nproc_per_node(GPU 个数)和 --master_port(开放的端口号)

#!/usr/bin/env bash

CONFIG=$1
GPUS=$2
PORT=${PORT:-29500}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
    $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}

对于 tools/train.py 其主要的流程为:

(一)从命令行和配置文件获取参数配置

(二)构建模型

# 构建模型: 需要传入 cfg.model,cfg.train_cfg,cfg.test_cfg
model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

(三)构建数据集

# 构建数据集: 需要传入 cfg.data.train,表明是训练集
datasets = [build_dataset(cfg.data.train)]

(四)训练模型

# 训练检测器:需要传入模型、数据集、配置参数等
train_detector(
    model,
    datasets,
    cfg,
    distributed=distributed,
    validate=(not args.no_validate),
    timestamp=timestamp,
    meta=meta)

所以对于 train.py 来说,首先从命令行和配置文件读取配置,然后分别用 build_detectorbuild_dataset 构建模型和数据集,最后将模型和数据集传入 train_detector 进行训练。

下面我们来看一下源码:

import argparse
import copy
import os
import os.path as osp
import time
import warnings

import mmcv
import torch
# Config 用于读取配置文件, DictAction 将命令行字典类型参数转化为 key-value 形式
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import get_git_hash

from mmdet import __version__
from mmdet.apis import set_random_seed, train_detector
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.utils import collect_env, get_root_logger

# python tools/train.py ${CONFIG_FILE} [optional arguments]

# =========== optional arguments ===========
# --work-dir        存储日志和模型的目录
# --resume-from     加载 checkpoint 的目录
# --no-validate     是否在训练的时候进行验证
# 互斥组:
#   --gpus          使用的 GPU 数量
#   --gpu_ids       使用指定 GPU 的 id
# --seed            随机数种子
# --deterministic   是否设置 cudnn 为确定性行为
# --options         其他参数
# --launcher        分布式训练使用的启动器,可以为:['none', 'pytorch', 'slurm', 'mpi']
#                   none:不启动分布式训练,dist_train.sh 中默认使用 pytorch 启动。
# --local_rank      本地进程编号,此参数 torch.distributed.launch 会自动传入。


def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    # action: store (默认, 表示保存参数)
    # action: store_true, store_false (如果指定参数, 则为 True, False)
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')

    # --------- 创建一个互斥组. argparse 将会确保互斥组中的参数只能出现一个 ---------
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    # 可以使用 python train.py --gpu-ids 0 1 2 3 指定使用的 GPU id
    # 参数结果:[0, 1, 2, 3]
    # nargs = '*':参数个数可以设置0个或n个
    # nargs = '+':参数个数可以设置1个或n个
    # nargs = '?':参数个数可以设置0个或1个
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='ids of gpus to use '
        '(only applicable to non-distributed training)')
    # ------------------------------------------------------------------------

    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    # 其他参数: 可以使用 --options a=1,2,3 指定其他参数
    # 参数结果: {'a': [1, 2, 3]}
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file (deprecate), '
        'change to --cfg-options instead.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    # 如果使用 dist_utils.sh 进行分布式训练, launcher 默认为 pytorch
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # 本地进程编号,此参数 torch.distributed.launch 会自动传入。
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    # 如果环境中没有 LOCAL_RANK,就设置它为当前的 local_rank
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.options and args.cfg_options:
        raise ValueError(
            '--options and --cfg-options cannot be both '
            'specified, --options is deprecated in favor of --cfg-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --cfg-options')
        args.cfg_options = args.options

    return args


def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)  # 从文件读取配置
    # 从命令行读取额外的配置
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark,设置True 可以加速输入大小固定的模型. 如:SSD300
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    # work_dir 的优先程度为: 命令行 > 配置文件
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    # 当 work_dir 为 None 的时候, 使用 ./work_dir/配置文件名 作为默认工作目录
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        # os.path.basename(path)    返回文件名
        # os.path.splitext(path)    分割路径, 返回路径名和文件扩展名的元组
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    # 是否继续上次的训练
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    # gpu id
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    # init distributed env first, since logger depends on the dist info.
    # 如果 launcher 为 none,不启用分布式训练。不使用 dist_train.sh 默认参数为 none.
    if args.launcher == 'none':
        distributed = False
    # launcher 不为 none,启用分布式训练。使用 dist_train.sh,会传 ‘pytorch’
    else:
        distributed = True
        # 初始化 dist 里面会调用 init_process_group
        init_dist(args.launcher, **cfg.dist_params)
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info
    meta['config'] = cfg.pretty_text
    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # 设置随机化种子
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, '
                    f'deterministic: {args.deterministic}')
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta['seed'] = args.seed
    meta['exp_name'] = osp.basename(args.config)

    # 构建模型: 需要传入 cfg.model, cfg.train_cfg, cfg.test_cfg
    model = build_detector(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    model.init_weights()

    # 构建数据集: 需要传入 cfg.data.train
    datasets = [build_dataset(cfg.data.train)]
    # workflow 代表流程:
    # [('train', 2), ('val', 1)] 就代表,训练两个 epoch 验证一个 epoch
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmdet version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmdet_version=__version__ + get_git_hash()[:7],
            CLASSES=datasets[0].CLASSES)
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES

    # 训练检测器, 传入:模型, 数据集, config 等
    train_detector(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)


if __name__ == '__main__':
    main()

我们分别来看看与 train.py 相关的核心函数

1、init_dist

此函数负责调用 init_process_group,完成分布式的初始化。在运行 dist_train.py 训练时,默认传递的 launcher 是 ‘pytorch’。所以此函数会进一步调用 _init_dist_pytorch 来完成初始化。

因为 torch.distributed 可以采用单进程控制多 GPU,也可以一个进程控制一个 GPU。一个进程控制一个 GPU 是目前Pytorch中,无论是单节点还是多节点,进行数据并行训练最快的方式。在 mmdet 中也是这么实现的。既然是单个进程控制单个 GPU,那么我么就需要绑定当前进程控制的是哪个 GPU。可以理解为在使用 torch.distributed.launch 运行 py 文件时。 它会多次调用 py 文件,每个 py 文件控制一个 GPU。并向每个 py 文件传参 --local_rank。(local_rank 是在这台机器上的本地进程编号)这样对于每个 py 文件,都能拿到传入的本地进程编号,我们只需要把当前进程绑定到指定的 GPU 即可。

在 _init_dist_pytorch 中就会设置当前进程控制的默认 GPU(torch.cuda.set_device),再使用 dist.init_process_group 初始化,初始化的方式为默认的 env://,即环境变量的方式。使用 env:// 方式初始化就需要用 torch.distributed.launch 运行 py 文件,torch.distributed.launch 会根据传入的参数设置环境变量,并运行 py 文件。

# Copyright (c) Open-MMLab. All rights reserved.
import functools
import os
import subprocess

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from mmcv.utils import TORCH_VERSION


def init_dist(launcher, backend='nccl', **kwargs):
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    # 默认进到这里
    if launcher == 'pytorch':
        # 调用下面的 _init_dist_pytorch 函数来初始化。
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'mpi':
        _init_dist_mpi(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
    else:
        raise ValueError(f'Invalid launcher type: {launcher}')


def _init_dist_pytorch(backend, **kwargs):
    # rank 是所有进程的总编号,算上本地进程,从 0 开始。
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    # 这里也可以使用命令行传递来的 --local_rank。
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)

2、set_random_seed:

此函数会对 python、numpy、torch 都设置随机数种子。

保持随机数种子相同时,卷积的结果在CPU上相同,在GPU上仍然不相同。这是因为,cudnn卷积行为的不确定性。使用 torch.backends.cudnn.deterministic = True 可以解决。

cuDNN 使用非确定性算法,并且可以使用 torch.backends.cudnn.enabled = False 来进行禁用。如果设置为 torch.backends.cudnn.enabled = True,说明设置为使用非确定性算法(即会自动寻找最适合当前配置的高效算法,来达到优化运行效率的问题)

一般来讲,应该遵循以下准则:

  1. 如果网络的输入数据维度或类型上变化不大,设置torch.backends.cudnn.benchmark = true 可以增加运行效率
  2. 如果网络的输入数据在每次 iteration 都变化的话,会导致 cnDNN 每次都会去寻找一遍最优配置,这样反而会降低运行效率。设置
    torch.backends.cudnn.benchmark = False 避免重复搜索。
def set_random_seed(seed, deterministic=False):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all 是为所有 GPU 都设置随机数种子。
    torch.cuda.manual_seed_all(seed)

    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

3、get_root_logger

get_root_logger 调用 get_logger 函数获取 logger 对象。

import logging

from mmcv.utils import get_logger


def get_root_logger(log_file=None, log_level=logging.INFO):
    logger = get_logger(name='mmdet', log_file=log_file, log_level=log_level)

    return logger

这里实现的 get_logger 函数非常灵活,如果传入相同的 log 的 name,会返回配置相同的 log。传入以点分割的日志名称的子模块,也会返回相同的 log。如:a 和 a.b 会返回相同的 log。

如果传入 log_file 会保存 log 的输出到 log_file 指定的路径,如果不传入 log_file,不保存日志的输出。只在控制台输出。

下面我们来分析一下源码:

import logging

import torch.distributed as dist

# 记录是否创建过 name 对应的 log,如果创建过设置为 True
logger_initialized = {}


def get_logger(name, log_file=None, log_level=logging.INFO):
    # 获取 log 对象。
    logger = logging.getLogger(name)
    # 如果已经创建过,直接返回
    if name in logger_initialized:
        return logger
    # 如果是创建过的以 ‘.’ 分割的子模块,也直接返回
    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]

    # 获取当前的 rank(总进程编号)
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    # 只有 rank 0(master 节点的 local_rank 为 0 的进程)的主机才保存日志
    if rank == 0 and log_file is not None:
        file_handler = logging.FileHandler(log_file, 'w')
        handlers.append(file_handler)

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)
    # 对于非 rank 为 0 的进程,只有 error 以上的信息才会显示
    if rank == 0:
        logger.setLevel(log_level)
    else:
        logger.setLevel(logging.ERROR)
    # 将 log name 对应的值设为 True,表示创建过。
    logger_initialized[name] = True

    return logger

因为在 train.py 中主要调用:构建模型(build_detector),构建数据集(build_dataset),训练模型(train_detector)的函数,我们下来分别看看这三个函数的源码。

二、build_detector(mmdet/models/builder.py)

build_detector 函数将配置文件config中的:model、train_cfg 和 test_cfg 三部分传入参数。

下面以 faster_rcnn_r50_fpn_1x_coco.py 配置文件来举例:

具体在faster_rcnn_r50_fpn.py文件中

model

model = dict(
    type='FasterRCNN',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0))))

train_cfg

train_cfg = dict(
    rpn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.7,
            neg_iou_thr=0.3,
            min_pos_iou=0.3,
            match_low_quality=True,
            ignore_iof_thr=-1),
        sampler=dict(
            type='RandomSampler',
            num=256,
            pos_fraction=0.5,
            neg_pos_ub=-1,
            add_gt_as_proposals=False),
        allowed_border=-1,
        pos_weight=-1,
        debug=False),
    rpn_proposal=dict(
        nms_across_levels=False,
        nms_pre=2000,
        nms_post=1000,
        max_num=1000,
        nms_thr=0.7,
        min_bbox_size=0),
    rcnn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.5,
            neg_iou_thr=0.5,
            min_pos_iou=0.5,
            match_low_quality=False,
            ignore_iof_thr=-1),
        sampler=dict(
            type='RandomSampler',
            num=512,
            pos_fraction=0.25,
            neg_pos_ub=-1,
            add_gt_as_proposals=True),
        pos_weight=-1,
        debug=False))

test_cfg

test_cfg = dict(
    rpn=dict(
        nms_across_levels=False,
        nms_pre=1000,
        nms_post=1000,
        max_num=1000,
        nms_thr=0.7,
        min_bbox_size=0),
    rcnn=dict(
        score_thr=0.05,
        nms=dict(type='nms', iou_threshold=0.5),
        max_per_img=100)
    # soft-nms is also supported for rcnn testing
    # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
)

运行时会将上面的三个值作为参数传入 build_detector 函数,build_detector 函数会调用 build 函数,build 函数调用 build_from_cfg 函数构建检测器对象。其中 train_cfgtest_cfg 作为默认参数用于构建 detector 对象。

def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        # 调用 build_from_cfg 用来根据 config 字典构建 registry 里面的对象
        return build_from_cfg(cfg, registry, default_args)


def build_detector(cfg, train_cfg=None, test_cfg=None):
    # 调用 build 函数,传入 cfg, registry 对象,
    # 把 train_cfg 和 test_cfg 作为默认字典传入
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))

build_from_cfgmmcv/utils/registery.py 中。其中参数 cfg 字典中的 type 键所对应的值表示需要创建的对象的类型。build_from_cfg 会自动在 Registry 注册的类中找到需要创建的类,并传入默认参数实例化。

def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        object: The constructed object.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        raise KeyError(
            f'the cfg dict must contain the key "type", but got {cfg}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    args = cfg.copy()
    # 获取 type 对应的值
    obj_type = args.pop('type')
    if is_str(obj_type):
        # 获取需要创建的对象
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')

    # 如果 default_args 不是 None,传入默认值再实例化。
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_cls(**args)

那么什么是 registry?
registry 就是注册类,将一个字符串和类关联起来。如果索引字符串就会获得类。Registry 是注册所需要的类,可以用它来注册类。我们可以使用如下的方式来注册类。

backbones = Registry('backbone')
@backbones.register_module()
class ResNet:
    pass

backbones = Registry('backbone')
@backbones.register_module(name='mnet')
class MobileNet:
    pass

backbones = Registry('backbone')
class ResNet:
    pass
backbones.register_module(ResNet)

下面是 Registry 类的代码,它的内部维护了一个已经注册的类的字典 ——_module_dict。每当注册一个类就在字典里添加一个字符串(默认为类名)与类的映射。register_module 方法,利用装饰器将类名和类添加到 _module_dict 中。对于注册的模块可以通过 build_from_cfg 来构建。

import inspect
import warnings
from functools import partial

from .misc import is_str


class Registry:
    """A registry to map strings to classes.

    Args:
        name (str): Registry name.
    """

    def __init__(self, name):
        self._name = name
        # 已经注册的类的字典
        self._module_dict = dict()

    def __len__(self):
        return len(self._module_dict)

    def __contains__(self, key):
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + \
                     f'(name={self._name}, ' \
                     f'items={self._module_dict})'
        return format_str

    @property
    def name(self):
        return self._name

    @property
    def module_dict(self):
        return self._module_dict

    def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format.

        Returns:
            class: The corresponding class.
        """
        return self._module_dict.get(key, None)

    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if not force and module_name in self._module_dict:
            raise KeyError(f'{module_name} is already registered '
                           f'in {self.name}')
        self._module_dict[module_name] = module_class

    def deprecated_register_module(self, cls=None, force=False):
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.')
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(ResNet)

        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a walkaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # raise the error ahead of time
        if not (name is None or isinstance(name, str)):
            raise TypeError(f'name must be a str, but got {type(name)}')

        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register

三、build_dataset(mmdet/datasets/builder)

build_dataset 也类似,通过调用 build_from_cfg 创建。

def build_dataset(cfg, default_args=None):
    from .dataset_wrappers import (ConcatDataset, RepeatDataset,
                                   ClassBalancedDataset)
    if isinstance(cfg, (list, tuple)):
        dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
    elif cfg['type'] == 'RepeatDataset':
        dataset = RepeatDataset(
            build_dataset(cfg['dataset'], default_args), cfg['times'])
    elif cfg['type'] == 'ClassBalancedDataset':
        dataset = ClassBalancedDataset(
            build_dataset(cfg['dataset'], default_args), cfg['oversample_thr'])
    elif isinstance(cfg.get('ann_file'), (list, tuple)):
        dataset = _concat_dataset(cfg, default_args)
    else:
        dataset = build_from_cfg(cfg, DATASETS, default_args)

    return dataset

四、train_detector(mmdet/apis/train.py)

train_detector 的主要流程为:

(一)构建 data loaders:

data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed) for ds in dataset
    ]

(二)构建分布式处理对象:

model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)

(三)构建优化器:

optimizer = build_optimizer(model, cfg.optimizer)

(四)创建 EpochBasedRunner 并进行训练:

runner = EpochBasedRunner(
        model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)

我们来看一下源码:

def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    # 获取 logger
    logger = get_root_logger(cfg.log_level)

    # ==================== 构建 data loaders ====================
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]

    # 获得 samples_per_gpu
    if 'imgs_per_gpu' in cfg.data:
        logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
                       'Please use "samples_per_gpu" instead')
        if 'samples_per_gpu' in cfg.data:
            logger.warning(
                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
                f'={cfg.data.imgs_per_gpu} is used in this experiments')
        else:
            logger.warning(
                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
                f'{cfg.data.imgs_per_gpu} in this experiments')
        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu

    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed) for ds in dataset
    ]


    # ==================== 构建分布式处理对象 =====================
    # 如果是多卡会进入此 if
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    # 单卡进入
    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)


    # ====================== 构建优化器 ==========================
    optimizer = build_optimizer(model, cfg.optimizer)

    # ============= 创建 EpochBasedRunner 并进行训练 ==============
    runner = EpochBasedRunner(
        model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=1,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False)
        eval_cfg = cfg.get('evaluation', {})
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)

在本篇文章(一)中,主要讲解了,train.py 中的主要流程,train.py 中的重要的函数以及函数的具体实现。但是 train_detector 只讲了流程,并没有拆开详细讲解。在下一小结中我们会详细讲解 train_detector 的每一步究竟做了什么。

参考:
https://zhuanlan.zhihu.com/p/163747610

相关标签: mmdetection