mmdetection2.0 | train.py相关源码详解(一)
文章目录
本文章从源码入手,详细的剖析 MMDetection 的训练细节。理解它是如何实现通用,灵活的训练的。
MMDetection 支持单机单卡、多卡训练。而且对于训练部分的代码和模型、数据集的代码的耦合性极低。真正实现了支持各种训练模式、各种模型和数据集的通用训练模板。
一、tools/train.py
tools/train.py
是单机单卡训练
时需要运行的文件,其主要使用方法为:
python tools/train.py ${CONFIG_FILE} [optional arguments]
如果为单机多卡
,需要运行 tools/dist_train.sh
,注意此文件不支持多机多卡训练:
./tools/dist_train.sh ${CONFIG_FILE} ${GPU_NUM} [optional arguments]
可选参数如下:
=========== optional arguments ===========
# --work-dir 存储日志和模型的目录
# --resume-from 加载 checkpoint 的目录
# --no-validate 是否在训练的时候进行验证
# 互斥组:
# --gpus 使用的 GPU 数量
# --gpu_ids 使用指定 GPU 的 id
# --seed 随机数种子
# --deterministic 是否设置 cudnn 为确定性行为
# --options 其他参数
# --launcher 分布式训练使用的启动器,可以为:['none', 'pytorch', 'slurm', 'mpi']
# none:不启动分布式训练,dist_train.sh 中默认使用 pytorch 启动。
# --local_rank 本地进程编号,此参数 torch.distributed.launch 会自动传入。
我们来看一下 dist_train.sh 里面究竟是什么?
可以看出 dist_train.sh 的本质就是使用 torch.distributed.launch(这是分布式的辅助启动工具) 运行 tools/train.py。
torch.distributed.launch 需要使用 python -m 来运行,-m 是把一个模块当做脚本来运行的一个参数。一般情况下可以给 torch.distributed.launch 传如下的参数
--nproc_per_node:表示每台机器的 GPU 数量
--nnodes:表示机器的数量
--node_rank:机器的排名,如果为 0 代表是 master 节点(机器)。
--master_addr:master 节点的 IP 地址
--master_port:master 节点开放的端口号
可以看到,因为 dist_train.sh 只支持单机多卡训练,所以只传参了 --nproc_per_node(GPU 个数)和 --master_port(开放的端口号)
#!/usr/bin/env bash
CONFIG=$1
GPUS=$2
PORT=${PORT:-29500}
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
$(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}
对于 tools/train.py 其主要的流程为:
(一)从命令行和配置文件获取参数配置
(二)构建模型
# 构建模型: 需要传入 cfg.model,cfg.train_cfg,cfg.test_cfg
model = build_detector(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
(三)构建数据集
# 构建数据集: 需要传入 cfg.data.train,表明是训练集
datasets = [build_dataset(cfg.data.train)]
(四)训练模型
# 训练检测器:需要传入模型、数据集、配置参数等
train_detector(
model,
datasets,
cfg,
distributed=distributed,
validate=(not args.no_validate),
timestamp=timestamp,
meta=meta)
所以对于 train.py
来说,首先从命令行和配置文件读取配置,然后分别用 build_detector
、build_dataset
构建模型和数据集,最后将模型和数据集传入 train_detector
进行训练。
下面我们来看一下源码:
import argparse
import copy
import os
import os.path as osp
import time
import warnings
import mmcv
import torch
# Config 用于读取配置文件, DictAction 将命令行字典类型参数转化为 key-value 形式
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import get_git_hash
from mmdet import __version__
from mmdet.apis import set_random_seed, train_detector
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.utils import collect_env, get_root_logger
# python tools/train.py ${CONFIG_FILE} [optional arguments]
# =========== optional arguments ===========
# --work-dir 存储日志和模型的目录
# --resume-from 加载 checkpoint 的目录
# --no-validate 是否在训练的时候进行验证
# 互斥组:
# --gpus 使用的 GPU 数量
# --gpu_ids 使用指定 GPU 的 id
# --seed 随机数种子
# --deterministic 是否设置 cudnn 为确定性行为
# --options 其他参数
# --launcher 分布式训练使用的启动器,可以为:['none', 'pytorch', 'slurm', 'mpi']
# none:不启动分布式训练,dist_train.sh 中默认使用 pytorch 启动。
# --local_rank 本地进程编号,此参数 torch.distributed.launch 会自动传入。
def parse_args():
parser = argparse.ArgumentParser(description='Train a detector')
parser.add_argument('config', help='train config file path')
parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument(
'--resume-from', help='the checkpoint file to resume from')
# action: store (默认, 表示保存参数)
# action: store_true, store_false (如果指定参数, 则为 True, False)
parser.add_argument(
'--no-validate',
action='store_true',
help='whether not to evaluate the checkpoint during training')
# --------- 创建一个互斥组. argparse 将会确保互斥组中的参数只能出现一个 ---------
group_gpus = parser.add_mutually_exclusive_group()
group_gpus.add_argument(
'--gpus',
type=int,
help='number of gpus to use '
'(only applicable to non-distributed training)')
# 可以使用 python train.py --gpu-ids 0 1 2 3 指定使用的 GPU id
# 参数结果:[0, 1, 2, 3]
# nargs = '*':参数个数可以设置0个或n个
# nargs = '+':参数个数可以设置1个或n个
# nargs = '?':参数个数可以设置0个或1个
group_gpus.add_argument(
'--gpu-ids',
type=int,
nargs='+',
help='ids of gpus to use '
'(only applicable to non-distributed training)')
# ------------------------------------------------------------------------
parser.add_argument('--seed', type=int, default=None, help='random seed')
parser.add_argument(
'--deterministic',
action='store_true',
help='whether to set deterministic options for CUDNN backend.')
# 其他参数: 可以使用 --options a=1,2,3 指定其他参数
# 参数结果: {'a': [1, 2, 3]}
parser.add_argument(
'--options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file (deprecate), '
'change to --cfg-options instead.')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
# 如果使用 dist_utils.sh 进行分布式训练, launcher 默认为 pytorch
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
# 本地进程编号,此参数 torch.distributed.launch 会自动传入。
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
# 如果环境中没有 LOCAL_RANK,就设置它为当前的 local_rank
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
if args.options and args.cfg_options:
raise ValueError(
'--options and --cfg-options cannot be both '
'specified, --options is deprecated in favor of --cfg-options')
if args.options:
warnings.warn('--options is deprecated in favor of --cfg-options')
args.cfg_options = args.options
return args
def main():
args = parse_args()
cfg = Config.fromfile(args.config) # 从文件读取配置
# 从命令行读取额外的配置
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# import modules from string list.
if cfg.get('custom_imports', None):
from mmcv.utils import import_modules_from_strings
import_modules_from_strings(**cfg['custom_imports'])
# set cudnn_benchmark,设置True 可以加速输入大小固定的模型. 如:SSD300
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
# work_dir is determined in this priority: CLI > segment in file > filename
# work_dir 的优先程度为: 命令行 > 配置文件
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
cfg.work_dir = args.work_dir
# 当 work_dir 为 None 的时候, 使用 ./work_dir/配置文件名 作为默认工作目录
elif cfg.get('work_dir', None) is None:
# use config filename as default work_dir if cfg.work_dir is None
# os.path.basename(path) 返回文件名
# os.path.splitext(path) 分割路径, 返回路径名和文件扩展名的元组
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])
# 是否继续上次的训练
if args.resume_from is not None:
cfg.resume_from = args.resume_from
# gpu id
if args.gpu_ids is not None:
cfg.gpu_ids = args.gpu_ids
else:
cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
# init distributed env first, since logger depends on the dist info.
# 如果 launcher 为 none,不启用分布式训练。不使用 dist_train.sh 默认参数为 none.
if args.launcher == 'none':
distributed = False
# launcher 不为 none,启用分布式训练。使用 dist_train.sh,会传 ‘pytorch’
else:
distributed = True
# 初始化 dist 里面会调用 init_process_group
init_dist(args.launcher, **cfg.dist_params)
# re-set gpu_ids with distributed training mode
_, world_size = get_dist_info()
cfg.gpu_ids = range(world_size)
# create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# dump config
cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
# init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
# init the meta dict to record some important information such as
# environment info and seed, which will be logged
meta = dict()
# log env info
env_info_dict = collect_env()
env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
dash_line = '-' * 60 + '\n'
logger.info('Environment info:\n' + dash_line + env_info + '\n' +
dash_line)
meta['env_info'] = env_info
meta['config'] = cfg.pretty_text
# log some basic info
logger.info(f'Distributed training: {distributed}')
logger.info(f'Config:\n{cfg.pretty_text}')
# 设置随机化种子
if args.seed is not None:
logger.info(f'Set random seed to {args.seed}, '
f'deterministic: {args.deterministic}')
set_random_seed(args.seed, deterministic=args.deterministic)
cfg.seed = args.seed
meta['seed'] = args.seed
meta['exp_name'] = osp.basename(args.config)
# 构建模型: 需要传入 cfg.model, cfg.train_cfg, cfg.test_cfg
model = build_detector(
cfg.model,
train_cfg=cfg.get('train_cfg'),
test_cfg=cfg.get('test_cfg'))
model.init_weights()
# 构建数据集: 需要传入 cfg.data.train
datasets = [build_dataset(cfg.data.train)]
# workflow 代表流程:
# [('train', 2), ('val', 1)] 就代表,训练两个 epoch 验证一个 epoch
if len(cfg.workflow) == 2:
val_dataset = copy.deepcopy(cfg.data.val)
val_dataset.pipeline = cfg.data.train.pipeline
datasets.append(build_dataset(val_dataset))
if cfg.checkpoint_config is not None:
# save mmdet version, config file content and class names in
# checkpoints as meta data
cfg.checkpoint_config.meta = dict(
mmdet_version=__version__ + get_git_hash()[:7],
CLASSES=datasets[0].CLASSES)
# add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES
# 训练检测器, 传入:模型, 数据集, config 等
train_detector(
model,
datasets,
cfg,
distributed=distributed,
validate=(not args.no_validate),
timestamp=timestamp,
meta=meta)
if __name__ == '__main__':
main()
我们分别来看看与 train.py 相关的核心函数:
1、init_dist:
此函数负责调用 init_process_group,完成分布式的初始化
。在运行 dist_train.py
训练时,默认传递的 launcher 是 ‘pytorch’。所以此函数会进一步调用 _init_dist_pytorch 来完成初始化。
因为 torch.distributed 可以采用单进程控制多 GPU,也可以一个进程控制一个 GPU
。一个进程控制一个 GPU 是目前Pytorch中,无论是单节点还是多节点,进行数据并行训练最快的方式。在 mmdet 中也是这么实现的。既然是单个进程控制单个 GPU,那么我么就需要绑定当前进程控制的是哪个 GPU。可以理解为在使用 torch.distributed.launch 运行 py 文件时。 它会多次调用 py 文件,每个 py 文件控制一个 GPU。并向每个 py 文件传参 --local_rank。(local_rank 是在这台机器上的本地进程编号)这样对于每个 py 文件,都能拿到传入的本地进程编号,我们只需要把当前进程绑定到指定的 GPU 即可。
在 _init_dist_pytorch 中就会设置当前进程控制的默认 GPU(torch.cuda.set_device),再使用 dist.init_process_group 初始化,初始化的方式为默认的 env://,即环境变量的方式。使用 env:// 方式初始化就需要用 torch.distributed.launch 运行 py 文件,torch.distributed.launch 会根据传入的参数设置环境变量,并运行 py 文件。
# Copyright (c) Open-MMLab. All rights reserved.
import functools
import os
import subprocess
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from mmcv.utils import TORCH_VERSION
def init_dist(launcher, backend='nccl', **kwargs):
if mp.get_start_method(allow_none=True) is None:
mp.set_start_method('spawn')
# 默认进到这里
if launcher == 'pytorch':
# 调用下面的 _init_dist_pytorch 函数来初始化。
_init_dist_pytorch(backend, **kwargs)
elif launcher == 'mpi':
_init_dist_mpi(backend, **kwargs)
elif launcher == 'slurm':
_init_dist_slurm(backend, **kwargs)
else:
raise ValueError(f'Invalid launcher type: {launcher}')
def _init_dist_pytorch(backend, **kwargs):
# rank 是所有进程的总编号,算上本地进程,从 0 开始。
rank = int(os.environ['RANK'])
num_gpus = torch.cuda.device_count()
# 这里也可以使用命令行传递来的 --local_rank。
torch.cuda.set_device(rank % num_gpus)
dist.init_process_group(backend=backend, **kwargs)
2、set_random_seed:
此函数会对 python、numpy、torch 都设置随机数种子。
保持随机数种子相同时,卷积的结果在CPU上相同,在GPU上仍然不相同。这是因为,cudnn卷积行为的不确定性。使用 torch.backends.cudnn.deterministic = True 可以解决。
cuDNN 使用非确定性算法,并且可以使用 torch.backends.cudnn.enabled = False 来进行禁用。如果设置为 torch.backends.cudnn.enabled = True,说明设置为使用非确定性算法(即会自动寻找最适合当前配置的高效算法,来达到优化运行效率的问题)
一般来讲,应该遵循以下准则:
- 如果网络的输入数据维度或类型上变化不大,设置torch.backends.cudnn.benchmark = true 可以增加运行效率
- 如果网络的输入数据在每次 iteration 都变化的话,会导致 cnDNN 每次都会去寻找一遍最优配置,这样反而会降低运行效率。设置
torch.backends.cudnn.benchmark = False 避免重复搜索。
def set_random_seed(seed, deterministic=False):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# manual_seed_all 是为所有 GPU 都设置随机数种子。
torch.cuda.manual_seed_all(seed)
if deterministic:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
3、get_root_logger:
get_root_logger 调用 get_logger 函数获取 logger 对象。
import logging
from mmcv.utils import get_logger
def get_root_logger(log_file=None, log_level=logging.INFO):
logger = get_logger(name='mmdet', log_file=log_file, log_level=log_level)
return logger
这里实现的 get_logger 函数非常灵活,如果传入相同的 log 的 name,会返回配置相同的 log。传入以点分割的日志名称的子模块,也会返回相同的 log。如:a 和 a.b 会返回相同的 log。
如果传入 log_file 会保存 log 的输出到 log_file 指定的路径,如果不传入 log_file,不保存日志的输出。只在控制台输出。
下面我们来分析一下源码:
import logging
import torch.distributed as dist
# 记录是否创建过 name 对应的 log,如果创建过设置为 True
logger_initialized = {}
def get_logger(name, log_file=None, log_level=logging.INFO):
# 获取 log 对象。
logger = logging.getLogger(name)
# 如果已经创建过,直接返回
if name in logger_initialized:
return logger
# 如果是创建过的以 ‘.’ 分割的子模块,也直接返回
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger
stream_handler = logging.StreamHandler()
handlers = [stream_handler]
# 获取当前的 rank(总进程编号)
if dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
else:
rank = 0
# 只有 rank 0(master 节点的 local_rank 为 0 的进程)的主机才保存日志
if rank == 0 and log_file is not None:
file_handler = logging.FileHandler(log_file, 'w')
handlers.append(file_handler)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
for handler in handlers:
handler.setFormatter(formatter)
handler.setLevel(log_level)
logger.addHandler(handler)
# 对于非 rank 为 0 的进程,只有 error 以上的信息才会显示
if rank == 0:
logger.setLevel(log_level)
else:
logger.setLevel(logging.ERROR)
# 将 log name 对应的值设为 True,表示创建过。
logger_initialized[name] = True
return logger
因为在 train.py 中主要调用:构建模型(build_detector),构建数据集(build_dataset),训练模型(train_detector)的函数,我们下来分别看看这三个函数的源码。
二、build_detector(mmdet/models/builder.py)
build_detector
函数将配置文件config中的:model、train_cfg 和 test_cfg
三部分传入参数。
下面以 faster_rcnn_r50_fpn_1x_coco.py
配置文件来举例:
具体在faster_rcnn_r50_fpn.py
文件中
model
model = dict(
type='FasterRCNN',
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5),
rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
scales=[8],
ratios=[0.5, 1.0, 2.0],
strides=[4, 8, 16, 32, 64]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
roi_head=dict(
type='StandardRoIHead',
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(
type='Shared2FCBBoxHead',
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2]),
reg_class_agnostic=False,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0))))
train_cfg
train_cfg = dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
match_low_quality=True,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=-1,
pos_weight=-1,
debug=False),
rpn_proposal=dict(
nms_across_levels=False,
nms_pre=2000,
nms_post=1000,
max_num=1000,
nms_thr=0.7,
min_bbox_size=0),
rcnn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
match_low_quality=False,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
pos_weight=-1,
debug=False))
test_cfg
test_cfg = dict(
rpn=dict(
nms_across_levels=False,
nms_pre=1000,
nms_post=1000,
max_num=1000,
nms_thr=0.7,
min_bbox_size=0),
rcnn=dict(
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100)
# soft-nms is also supported for rcnn testing
# e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
)
运行时会将上面的三个值作为参数传入 build_detector
函数,build_detector 函数会调用 build
函数,build 函数调用 build_from_cfg
函数构建检测器对象。其中 train_cfg
和 test_cfg
作为默认参数用于构建 detector 对象。
def build(cfg, registry, default_args=None):
if isinstance(cfg, list):
modules = [
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
]
return nn.Sequential(*modules)
else:
# 调用 build_from_cfg 用来根据 config 字典构建 registry 里面的对象
return build_from_cfg(cfg, registry, default_args)
def build_detector(cfg, train_cfg=None, test_cfg=None):
# 调用 build 函数,传入 cfg, registry 对象,
# 把 train_cfg 和 test_cfg 作为默认字典传入
return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
build_from_cfg
在 mmcv/utils/registery.py
中。其中参数 cfg 字典中的 type 键所对应的值表示需要创建的对象的类型。build_from_cfg 会自动在 Registry 注册的类中找到需要创建的类,并传入默认参数实例化。
def build_from_cfg(cfg, registry, default_args=None):
"""Build a module from config dict.
Args:
cfg (dict): Config dict. It should at least contain the key "type".
registry (:obj:`Registry`): The registry to search the type from.
default_args (dict, optional): Default initialization arguments.
Returns:
object: The constructed object.
"""
if not isinstance(cfg, dict):
raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
if 'type' not in cfg:
raise KeyError(
f'the cfg dict must contain the key "type", but got {cfg}')
if not isinstance(registry, Registry):
raise TypeError('registry must be an mmcv.Registry object, '
f'but got {type(registry)}')
if not (isinstance(default_args, dict) or default_args is None):
raise TypeError('default_args must be a dict or None, '
f'but got {type(default_args)}')
args = cfg.copy()
# 获取 type 对应的值
obj_type = args.pop('type')
if is_str(obj_type):
# 获取需要创建的对象
obj_cls = registry.get(obj_type)
if obj_cls is None:
raise KeyError(
f'{obj_type} is not in the {registry.name} registry')
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError(
f'type must be a str or valid type, but got {type(obj_type)}')
# 如果 default_args 不是 None,传入默认值再实例化。
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
return obj_cls(**args)
那么什么是 registry?
registry 就是注册类,将一个字符串和类关联起来。如果索引字符串就会获得类。Registry 是注册所需要的类,可以用它来注册类。我们可以使用如下的方式来注册类。
backbones = Registry('backbone')
@backbones.register_module()
class ResNet:
pass
backbones = Registry('backbone')
@backbones.register_module(name='mnet')
class MobileNet:
pass
backbones = Registry('backbone')
class ResNet:
pass
backbones.register_module(ResNet)
下面是 Registry 类的代码,它的内部维护了一个已经注册的类的字典 ——_module_dict。每当注册一个类就在字典里添加一个字符串(默认为类名)与类的映射。register_module 方法,利用装饰器将类名和类添加到 _module_dict 中。对于注册的模块可以通过 build_from_cfg 来构建。
import inspect
import warnings
from functools import partial
from .misc import is_str
class Registry:
"""A registry to map strings to classes.
Args:
name (str): Registry name.
"""
def __init__(self, name):
self._name = name
# 已经注册的类的字典
self._module_dict = dict()
def __len__(self):
return len(self._module_dict)
def __contains__(self, key):
return self.get(key) is not None
def __repr__(self):
format_str = self.__class__.__name__ + \
f'(name={self._name}, ' \
f'items={self._module_dict})'
return format_str
@property
def name(self):
return self._name
@property
def module_dict(self):
return self._module_dict
def get(self, key):
"""Get the registry record.
Args:
key (str): The class name in string format.
Returns:
class: The corresponding class.
"""
return self._module_dict.get(key, None)
def _register_module(self, module_class, module_name=None, force=False):
if not inspect.isclass(module_class):
raise TypeError('module must be a class, '
f'but got {type(module_class)}')
if module_name is None:
module_name = module_class.__name__
if not force and module_name in self._module_dict:
raise KeyError(f'{module_name} is already registered '
f'in {self.name}')
self._module_dict[module_name] = module_class
def deprecated_register_module(self, cls=None, force=False):
warnings.warn(
'The old API of register_module(module, force=False) '
'is deprecated and will be removed, please use the new API '
'register_module(name=None, force=False, module=None) instead.')
if cls is None:
return partial(self.deprecated_register_module, force=force)
self._register_module(cls, force=force)
return cls
def register_module(self, name=None, force=False, module=None):
"""Register a module.
A record will be added to `self._module_dict`, whose key is the class
name or the specified name, and value is the class itself.
It can be used as a decorator or a normal function.
Example:
>>> backbones = Registry('backbone')
>>> @backbones.register_module()
>>> class ResNet:
>>> pass
>>> backbones = Registry('backbone')
>>> @backbones.register_module(name='mnet')
>>> class MobileNet:
>>> pass
>>> backbones = Registry('backbone')
>>> class ResNet:
>>> pass
>>> backbones.register_module(ResNet)
Args:
name (str | None): The module name to be registered. If not
specified, the class name will be used.
force (bool, optional): Whether to override an existing class with
the same name. Default: False.
module (type): Module class to be registered.
"""
if not isinstance(force, bool):
raise TypeError(f'force must be a boolean, but got {type(force)}')
# NOTE: This is a walkaround to be compatible with the old api,
# while it may introduce unexpected bugs.
if isinstance(name, type):
return self.deprecated_register_module(name, force=force)
# use it as a normal method: x.register_module(module=SomeClass)
if module is not None:
self._register_module(
module_class=module, module_name=name, force=force)
return module
# raise the error ahead of time
if not (name is None or isinstance(name, str)):
raise TypeError(f'name must be a str, but got {type(name)}')
# use it as a decorator: @x.register_module()
def _register(cls):
self._register_module(
module_class=cls, module_name=name, force=force)
return cls
return _register
三、build_dataset(mmdet/datasets/builder)
build_dataset
也类似,通过调用 build_from_cfg
创建。
def build_dataset(cfg, default_args=None):
from .dataset_wrappers import (ConcatDataset, RepeatDataset,
ClassBalancedDataset)
if isinstance(cfg, (list, tuple)):
dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
elif cfg['type'] == 'RepeatDataset':
dataset = RepeatDataset(
build_dataset(cfg['dataset'], default_args), cfg['times'])
elif cfg['type'] == 'ClassBalancedDataset':
dataset = ClassBalancedDataset(
build_dataset(cfg['dataset'], default_args), cfg['oversample_thr'])
elif isinstance(cfg.get('ann_file'), (list, tuple)):
dataset = _concat_dataset(cfg, default_args)
else:
dataset = build_from_cfg(cfg, DATASETS, default_args)
return dataset
四、train_detector(mmdet/apis/train.py)
train_detector
的主要流程为:
(一)构建 data loaders:
data_loaders = [
build_dataloader(
ds,
cfg.data.samples_per_gpu,
cfg.data.workers_per_gpu,
# cfg.gpus will be ignored if distributed
len(cfg.gpu_ids),
dist=distributed,
seed=cfg.seed) for ds in dataset
]
(二)构建分布式处理对象:
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
(三)构建优化器:
optimizer = build_optimizer(model, cfg.optimizer)
(四)创建 EpochBasedRunner 并进行训练:
runner = EpochBasedRunner(
model,
optimizer=optimizer,
work_dir=cfg.work_dir,
logger=logger,
meta=meta)
我们来看一下源码:
def train_detector(model,
dataset,
cfg,
distributed=False,
validate=False,
timestamp=None,
meta=None):
# 获取 logger
logger = get_root_logger(cfg.log_level)
# ==================== 构建 data loaders ====================
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
# 获得 samples_per_gpu
if 'imgs_per_gpu' in cfg.data:
logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
'Please use "samples_per_gpu" instead')
if 'samples_per_gpu' in cfg.data:
logger.warning(
f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
f'={cfg.data.imgs_per_gpu} is used in this experiments')
else:
logger.warning(
'Automatically set "samples_per_gpu"="imgs_per_gpu"='
f'{cfg.data.imgs_per_gpu} in this experiments')
cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
data_loaders = [
build_dataloader(
ds,
cfg.data.samples_per_gpu,
cfg.data.workers_per_gpu,
# cfg.gpus will be ignored if distributed
len(cfg.gpu_ids),
dist=distributed,
seed=cfg.seed) for ds in dataset
]
# ==================== 构建分布式处理对象 =====================
# 如果是多卡会进入此 if
if distributed:
find_unused_parameters = cfg.get('find_unused_parameters', False)
# Sets the `find_unused_parameters` parameter in
# torch.nn.parallel.DistributedDataParallel
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
# 单卡进入
else:
model = MMDataParallel(
model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
# ====================== 构建优化器 ==========================
optimizer = build_optimizer(model, cfg.optimizer)
# ============= 创建 EpochBasedRunner 并进行训练 ==============
runner = EpochBasedRunner(
model,
optimizer=optimizer,
work_dir=cfg.work_dir,
logger=logger,
meta=meta)
# an ugly workaround to make .log and .log.json filenames the same
runner.timestamp = timestamp
# fp16 setting
fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
optimizer_config = Fp16OptimizerHook(
**cfg.optimizer_config, **fp16_cfg, distributed=distributed)
elif distributed and 'type' not in cfg.optimizer_config:
optimizer_config = OptimizerHook(**cfg.optimizer_config)
else:
optimizer_config = cfg.optimizer_config
# register hooks
runner.register_training_hooks(cfg.lr_config, optimizer_config,
cfg.checkpoint_config, cfg.log_config,
cfg.get('momentum_config', None))
if distributed:
runner.register_hook(DistSamplerSeedHook())
# register eval hooks
if validate:
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
val_dataloader = build_dataloader(
val_dataset,
samples_per_gpu=1,
workers_per_gpu=cfg.data.workers_per_gpu,
dist=distributed,
shuffle=False)
eval_cfg = cfg.get('evaluation', {})
eval_hook = DistEvalHook if distributed else EvalHook
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
runner.load_checkpoint(cfg.load_from)
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
在本篇文章(一)中,主要讲解了,train.py 中的主要流程,train.py 中的重要的函数以及函数的具体实现。但是 train_detector 只讲了流程,并没有拆开详细讲解。在下一小结中我们会详细讲解 train_detector 的每一步究竟做了什么。
参考:
https://zhuanlan.zhihu.com/p/163747610