mmclassification源码阅读(四) 模型加载过程
以训练过程为例,执行以下脚本。
python tools/train.py configs/cifar10/resnet50.py --resume-from=work_dirs/resnet50/epoch_20.pth
1、整体流程
首先加载配置,args为用户输入参数,cfg为配置文件配置参数。只有将参数统一合并至cfg管理。
args = parse_args()
cfg = Config.fromfile(args.config)
... # 参数合并,预处理过程
2、cfg参数解析
经过mmcv.Config解析后的变量以类对象形式存在,变量保存在保护成员变量_cfg_dict中,访问时直接采用cfg.key形式访问。其中:
1、args参数,放在最外层。如:
args = {'gpu_ids': range(0, 1), 'work_dir': './work_dirs/resnet50', ...}
解析至cfg后为:
cfg = {'gpu_ids': range(0, 1), 'work_dir': './work_dirs/resnet50', ...}
2、配置文件解析,configs/cifar10/resnet50.py解析过程:
# file: configs/cifar10/resnet50.py
_base_ = [
'../_base_/models/resnet50_cifar.py', # 1、模型结构配置
'../_base_/datasets/cifar10.py', # 2、数据集配置
'../_base_/schedules/cifar10.py', # 3、训练参数配置
'../_base_/default_runtime.py' # 4、运行时,模型保存,显卡等配置
]
解析__base__内文件配置,对文件内容依次展开,合并放至cfg最外层。如运行时default_runtime.py文件参数为:
# file: ../_base_/models/resnet50_cifar.py, model settings
model = dict( # 1字典
type='ImageClassifier', # 1类型 -注册器:CLASSIFIERS
backbone=dict( # 1.1字典
type='ResNet_CIFAR', # 1.1类型 -注册器:BACKBONES
depth=50, # **kwargs
num_stages=4,# **kwargs
out_indices=(3, ),# **kwargs
style='pytorch'),# **kwargs
neck=dict( # 1.2字典
type='GlobalAveragePooling'), # 1.2类型 -注册器:NECKS
head=dict( # 1.3字典
type='LinearClsHead', # 1.3类型 -注册器:HEADS
num_classes=10, # **kwargs
in_channels=2048, # **kwargs
loss=dict( # 2..1字典
type='CrossEntropyLoss', # 2.1类型 -注册器:LOSSES
loss_weight=1.0), # **kwargs
)
)
解析至cfg后为:
cfg = {..., 'model': xx, ...}
3、构建模型结构
执行代码:
model = build_classifier(cfg.model)
model字典为两级结构,第一层为任务类型type和四个基本结构(backbone、neck、head、loss),第二层为各个结构配置参数,字典型,包含type和额外参数**kwargs。每一个type都对应一种注册器,按层级结构递归调用注册器进行模型结构生成。注册器在mmcls/models/builder.py中创建,models下对应的包中进行注册。
# file: mmcls/models/builder.py
BACKBONES = Registry('backbone')
CLASSIFIERS = Registry('classifier')
HEADS = Registry('head')
NECKS = Registry('neck')
LOSSES = Registry('loss')
step1、build_classifier函数调用
build_classifier函数传参CLASSIFIERS注册器,最后调用mmcv.utils.registry.py中的build_from_cfg(cfg, registry, default_args)函数。
def build_classifier(cfg):
return build(cfg, CLASSIFIERS)
def build(cfg, registry, default_args=None):
if isinstance(cfg, list):
modules = [
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
]
return nn.Sequential(*modules)
else:
return build_from_cfg(cfg, registry, default_args)
build_from_cfg函数实现:
# file: mmcv.utils.registry.py
def build_from_cfg(cfg, registry, default_args=None):
...
args = cfg.copy()
obj_type = args.pop('type') # 1、获取dict中type字段(注册器名称)'ImageClassifier'
if is_str(obj_type):
obj_cls = registry.get(obj_type) # 2、获取注册器内对应的类
...
return obj_cls(**args) # 3、实例化类对象,type以外的参数作为关键字参数传递于类初始化-'ImageClassifier'类
实际加载过程:
调用CLASSIFIERS('classifier')注册器,加载其中的ImageClassifier类,type以外的参数作为关键字参数,实例化ImageClassifier类。
step2、ImageClassifier类实例化
ImageClassifier构造函数如下,接受关键字参数backbone、neck、head,loss未声明此忽略, pretrained=None。继承自BaseClassifier,首先调用基类BaseClassifier构造函数。之后依次加载self.backbone、self.neck、self.head、及权重初始化。
@CLASSIFIERS.register_module()
class ImageClassifier(BaseClassifier):
def __init__(self, backbone, neck=None, head=None, pretrained=None):
super(ImageClassifier, self).__init__()
self.backbone = build_backbone(backbone)
if neck is not None:
self.neck = build_neck(neck)
if head is not None:
self.head = build_head(head)
self.init_weights(pretrained=pretrained)
step3、self.backbone、self.neck、self.head类实例化
backbone构造,调用BACKBONES('backbone')注册器,加载其中ResNet_CIFAR类,type以外的参数作为关键字参数,实例化ResNet_CIFAR类。(继承自基类ResNet,实际上基类中实现初始化,子类中覆盖_make_stem_layer、forward函数)。
neck构造,调用NECKS('backneckbone')注册器,加载其中'GlobalAveragePooling'类,type以外的参数作为关键字参数,实例化GlobalAveragePooling类。
head构造,调用HEADS('head')注册器,加载其中'LinearClsHead'类,type以外的参数作为关键字参数,实例化LinearClsHead类。LinearClsHead类接收loss关键字参数。
# file: mmcls/models/heads/cls_head.py
@HEADS.register_module()
class ClsHead(BaseHead):
def __init__(self,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, )):
super(ClsHead, self).__init__()
...
self.compute_loss = build_loss(loss)
self.compute_accuracy = Accuracy(topk=self.topk)
loss构造,调用LOSSES('loss')注册器,加载其中'CrossEntropyLoss'类,type以外的参数作为关键字参数,实例化CrossEntropyLoss类。loss类实例化对象存在与self.head对象中。
传送门:mmclassification项目阅读系列文章目录
源码阅读:
本文地址:https://blog.csdn.net/weixin_34910922/article/details/107886943