欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

mmaction2 数据相关源码概览

程序员文章站 2023-12-11 21:10:10
文章目录0. 前言1. Dataset构建过程2. 数据预处理模块化实现3. 视频采样方式实现4. 数据增强方式实现5. DataLoader 的实现0. 前言github: open-mmlab/mmaction2从宏观角度记录一下与 mmaction 还是非常类似的,做了一定的优化,暂时不支持 ava。Dataset 类型包括了 RawframeDataset, VideoDataset, ActivityNetDataset 三个基本类型,前两者是行为识别数据集,最后一个是时序行为检测...


0. 前言

  • github: open-mmlab/mmaction2
  • 从宏观角度记录一下
    • 与 mmaction 还是非常类似的,做了一定的优化,暂时不支持 ava。
    • Dataset 类型包括了 RawframeDataset, VideoDataset, ActivityNetDataset 三个基本类型,前两者是行为识别数据集,最后一个是时序行为检测数据集。
    • 对于每个具体数据集(如Kinetics、SomethingSomething等),先通过预处理转换为统一的格式,然后再通过上面三种数据集基本类型,用于后续的模型训练/预测/评估。
    • 数据预处理(包括读取图片帧、视频解码、帧采样、数据增强、均值/方差处理等)都模块化了,源码看起来非常舒服,增加新功能也更方便了。

1. Dataset构建过程

  • 基本机制:通过配置文件选择基本数据集(RawframeDataset, VideoDataset, ActivityNetDataset等)的类型与参数。
  • 入口函数:mmaction.datasets.builder.py 中的 build_dataset 方法。
  • 相关配置文件:cfg.data.train/val/test
  • 配置文件举例
data = dict(
    videos_per_gpu=8,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        ann_file=ann_file_train,
        data_prefix=data_root,
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=ann_file_val,
        data_prefix=data_root_val,
        pipeline=val_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=ann_file_val,
        data_prefix=data_root_val,
        pipeline=test_pipeline))
  • 构建过程:
    • 以训练集为例,首先通过cfg.data.train.type选择基本数据集类型(假设是RawframeDataset类型),cfg.data.train中的其他参数就是基本数据集构造函数的参数。
    • 如何通过 cfg.data.train.type 选择基本数据集类型:
      • 其实就是一个Register机制。
      • 所有基本数据集都注册到一个名为DATASETmmcv.utils.Registry对象中,该对象中维护了一个字典,该字典 key 为基本数据集类别名称,value 为基本数据集类型。
      • 上述注册过程通过注解实现。

2. 数据预处理模块化实现

  • 所有基本数据集类型都继承自 mmaction.datasets.base.BaseDataset
    • 该类中定义了一个 self.pipeline 成员变量。
    • 该成员变量就是所有模块化数据预处理的集合。
    • 该成员变量用于后续数据读取过程。
  • pipeline的构建过程:将配置文件中定义的内容转换为一个list,并将该list传入mmaction.datasets.pipelines.compose.Compose 中。
  • pipeline 配置文件概述:
    • 对于train/val/test的pipeline各自都是一个list。
    • list中每个成员就是一个pipeline配置。
    • 每个pipeline配置都包含type用于指定pipeline类型,其他参数就是该类型的初始化参数。
    • 配置文件举例
train_pipeline = [
    dict(type='SampleFrames', clip_len=32, frame_interval=2, num_clips=1),
    dict(type='FrameSelector', decoding_backend='turbojpeg'),
    dict(type='Resize', scale=(-1, 256), lazy=True),
    dict(
        type='MultiScaleCrop',
        input_size=224,
        scales=(1, 0.8),
        random_crop=False,
        max_wh_scale_gap=0,
        lazy=True),
    dict(type='Resize', scale=(224, 224), keep_ratio=False, lazy=True),
    dict(type='Flip', flip_ratio=0.5, lazy=True),
    dict(type='Fuse'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCTHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs', 'label'])
]
val_pipeline = [
    dict(
        type='SampleFrames',
        clip_len=32,
        frame_interval=2,
        num_clips=1,
        test_mode=True),
    dict(type='FrameSelector', decoding_backend='turbojpeg'),
    dict(type='Resize', scale=(-1, 256), lazy=True),
    dict(type='CenterCrop', crop_size=224, lazy=True),
    dict(type='Flip', flip_ratio=0, lazy=True),
    dict(type='Fuse'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCTHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs'])
]
test_pipeline = [
    dict(
        type='SampleFrames',
        clip_len=32,
        frame_interval=2,
        num_clips=10,
        test_mode=True),
    dict(type='FrameSelector', decoding_backend='turbojpeg'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='ThreeCrop', crop_size=256),
    dict(type='Flip', flip_ratio=0),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCTHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs'])
]
  • mmaction.datasets.pipelines.compose.Compose 概述
    • 作用:将上述配置文件,转换为具体组件列表。
    • 配置文件 -> 具体组件,就是一个Registry机制实现。
      • 即每个具体组件通过 PIPELINES 注册。
      • 在构建对象是,通过配置文件中的 typePIPELINES 中选择组件类型,并通过配置文件中其他参数初始化选中的组件。
    • 调用 Compose 类实现的功能就是将数据依次通过所有组件,得到结果。
  • pipeline的分类
    • 数据导入
      • 目标:解析各种类型的数据(视频、图像帧),并进行采样。
      • 源码位于 mmaction.datasets.pipelines.loading.py 中。
    • 数据增强
      • 目标:实现常见的数据增强方式。
      • 源码位于 mmaction.datasets.pipelines.augmentations.py
    • 数据类型/格式转换
      • 目标:将数据转换格式,一般就是转换为torch.Tensor。也包括了Transpose。
      • 源码位于 mmaction.datasets.pipelines.formating.py 中。
    • 上面介绍过的 Compose
      • 作用:将配置文件转换为具体组件实例,融合所有组件,方便其他对象的调用。
      • 源码位于 mmaction.datasets.pipelines.compose.py 中。

3. 视频采样方式实现

  • 属于数据增强组件的一部分,源码位于 mmaction.datasets.pipelines.loading.py 中。
  • 支持的采样方式包括 SampleFramesDenseSampleFrames 两种。
  • SampleFrames
    • 主要参数包括:clip_len, frame_interval, num_clips
    • 两种基本采样方式:
      • TSN形式:将视频分为x个部分,每个部分随机取一帧。
        • clip_len=1, num_clips=x,另外一个参数取值无所谓
      • 普通形式:在连续的帧中,间隔x帧提取帧,一共获取y帧。
        • num_clips=1, clip_len=y, frames_interval=x
  • DenseSampleFrames
    • 继承了 SampleFrames
    • 主要就是使用了 dense sample strategy,即在一个限定范围内获取每个clip的起始帧id。
    • 新增的参数包括 sample_range(期望在 [0, sample_range) 中获取所有clip的起始位置) 和 num_sample_positions(test时使用,用于获取多个 clip 的起始位置)。

4. 数据增强方式实现

  • 训练时支持的主要就是 resize + random crop + flip 这类。
  • 测试时除了上面的之外,还支持 ThreeCrop(resize+上中下/左中右crop)和 TenCrop(左上、左下、右上、右下、正中,以及flip)

5. DataLoader 的实现

  • 主要就是设置了 batch size,分布式训练中使用的 Sampler,以及数据 collate(整理)方式。

本文地址:https://blog.csdn.net/irving512/article/details/107434686

相关标签: PyTorch

上一篇:

下一篇: