欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

mmclassification源码阅读(四) 模型加载过程

程序员文章站 2022-06-22 23:06:55
以训练过程为例,执行以下脚本。python tools/train.py configs/cifar10/resnet50.py --resume-from=work_dirs/resnet50/epoch_20.pth1、整体流程首先加载配置,args为用户输入参数,cfg为配置文件配置参数。只有将参数统一合并至cfg管理。args = parse_args()cfg = Config.fromfile(args.config)... # 参数合并,预处理过程2、cfg参数解析...

以训练过程为例,执行以下脚本。

python tools/train.py configs/cifar10/resnet50.py --resume-from=work_dirs/resnet50/epoch_20.pth

1、整体流程

首先加载配置,args为用户输入参数,cfg为配置文件配置参数。只有将参数统一合并至cfg管理。

args = parse_args()
cfg = Config.fromfile(args.config)
... # 参数合并,预处理过程

2、cfg参数解析

经过mmcv.Config解析后的变量以类对象形式存在,变量保存在保护成员变量_cfg_dict中,访问时直接采用cfg.key形式访问。其中:

1、args参数,放在最外层。如:

args = {'gpu_ids': range(0, 1), 'work_dir': './work_dirs/resnet50', ...}

解析至cfg后为:

cfg = {'gpu_ids': range(0, 1), 'work_dir': './work_dirs/resnet50', ...}

2、配置文件解析,configs/cifar10/resnet50.py解析过程:

# file: configs/cifar10/resnet50.py
_base_ = [
    '../_base_/models/resnet50_cifar.py',  # 1、模型结构配置
    '../_base_/datasets/cifar10.py',  # 2、数据集配置
    '../_base_/schedules/cifar10.py',   # 3、训练参数配置
        '../_base_/default_runtime.py'  # 4、运行时,模型保存,显卡等配置
]

解析__base__内文件配置,对文件内容依次展开,合并放至cfg最外层。如运行时default_runtime.py文件参数为:

# file: ../_base_/models/resnet50_cifar.py,  model settings
model = dict(  # 1字典
    type='ImageClassifier',  # 1类型 -注册器:CLASSIFIERS 
    backbone=dict(  # 1.1字典
        type='ResNet_CIFAR',  # 1.1类型  -注册器:BACKBONES 
        depth=50,  # **kwargs
        num_stages=4,# **kwargs
        out_indices=(3, ),# **kwargs
        style='pytorch'),# **kwargs
    neck=dict(  # 1.2字典
        type='GlobalAveragePooling'), # 1.2类型  -注册器:NECKS 
    head=dict(  # 1.3字典
        type='LinearClsHead', # 1.3类型  -注册器:HEADS 
        num_classes=10,  # **kwargs
        in_channels=2048,  # **kwargs
        loss=dict(  # 2..1字典
            type='CrossEntropyLoss', # 2.1类型  -注册器:LOSSES 
            loss_weight=1.0),  # **kwargs
        )
)

解析至cfg后为:

cfg = {..., 'model': xx, ...}

3、构建模型结构

执行代码:

model = build_classifier(cfg.model)

model字典为两级结构,第一层为任务类型type和四个基本结构(backbone、neck、head、loss),第二层为各个结构配置参数,字典型,包含type和额外参数**kwargs。每一个type都对应一种注册器,按层级结构递归调用注册器进行模型结构生成。注册器在mmcls/models/builder.py中创建,models下对应的包中进行注册。

# file: mmcls/models/builder.py
BACKBONES = Registry('backbone')
CLASSIFIERS = Registry('classifier')
HEADS = Registry('head')
NECKS = Registry('neck')
LOSSES = Registry('loss')

step1、build_classifier函数调用

build_classifier函数传参CLASSIFIERS注册器,最后调用mmcv.utils.registry.py中的build_from_cfg(cfg, registry, default_args)函数。

def build_classifier(cfg):
    return build(cfg, CLASSIFIERS)
    
def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)    

build_from_cfg函数实现:

# file: mmcv.utils.registry.py
def build_from_cfg(cfg, registry, default_args=None):
     ...     
    args = cfg.copy()
    obj_type = args.pop('type')  # 1、获取dict中type字段(注册器名称)'ImageClassifier'
    if is_str(obj_type):
        obj_cls = registry.get(obj_type)  # 2、获取注册器内对应的类
        ...
 return obj_cls(**args)  # 3、实例化类对象,type以外的参数作为关键字参数传递于类初始化-'ImageClassifier'类

实际加载过程:

调用CLASSIFIERS('classifier')注册器,加载其中的ImageClassifier类,type以外的参数作为关键字参数,实例化ImageClassifier类。

step2、ImageClassifier类实例化

ImageClassifier构造函数如下,接受关键字参数backbone、neck、head,loss未声明此忽略, pretrained=None。继承自BaseClassifier,首先调用基类BaseClassifier构造函数。之后依次加载self.backbone、self.neck、self.head、及权重初始化。

@CLASSIFIERS.register_module()
class ImageClassifier(BaseClassifier):

    def __init__(self, backbone, neck=None, head=None, pretrained=None):
        super(ImageClassifier, self).__init__()
        self.backbone = build_backbone(backbone)

        if neck is not None:
            self.neck = build_neck(neck)

        if head is not None:
            self.head = build_head(head)

        self.init_weights(pretrained=pretrained)

step3、self.backbone、self.neck、self.head类实例化

backbone构造,调用BACKBONES('backbone')注册器,加载其中ResNet_CIFAR类,type以外的参数作为关键字参数,实例化ResNet_CIFAR类。(继承自基类ResNet,实际上基类中实现初始化,子类中覆盖_make_stem_layer、forward函数)。

neck构造,调用NECKS('backneckbone')注册器,加载其中'GlobalAveragePooling'类,type以外的参数作为关键字参数,实例化GlobalAveragePooling类。

head构造,调用HEADS('head')注册器,加载其中'LinearClsHead'类,type以外的参数作为关键字参数,实例化LinearClsHead类。LinearClsHead类接收loss关键字参数。

# file: mmcls/models/heads/cls_head.py
@HEADS.register_module()
class ClsHead(BaseHead):
    def __init__(self,
                 loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
                 topk=(1, )):
        super(ClsHead, self).__init__()
        ...
        
        self.compute_loss = build_loss(loss)
        self.compute_accuracy = Accuracy(topk=self.topk)

loss构造,调用LOSSES('loss')注册器,加载其中'CrossEntropyLoss'类,type以外的参数作为关键字参数,实例化CrossEntropyLoss类。loss类实例化对象存在与self.head对象中。

传送门:mmclassification项目阅读系列文章目录

源码阅读:

1、setup.py工程环境配置(一)

2、mmcls库组织结构说明(二)

3、registry类注册机制(三)

4、模型加载过程(四)

5、数据加载过程(五)

6、train_model执行过程(六)

本文地址:https://blog.csdn.net/weixin_34910922/article/details/107886943