欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

TSN源码分析

程序员文章站 2022-03-19 11:48:43
...

TSN源码分析
论文链接:https://arxiv.org/abs/1608.00859
代码链接:https://github.com/yjxiong/tsn-pytorch

1 源码准备

在指定文件夹下,输入命令:

git clone --recursive https://github.com/yjxiong/tsn-pytorch 

下载完成后,得到tsn-pytorch源码。

2 源码结构

下表列出tsn-pytorch中比较重要的文件:

文件名称 功能
main.py 训练脚本
test_models.py 测试脚本
opts.py 参数配置脚本
dataset.py 数据读取脚本
models.py 网络结构构建脚本
transforms.py 数据预处理相关的脚本
tf_model_zoo文件夹 导入模型结构的脚本

接下来对一些重要文件,将一一讲解,并且说清数据流的走向和函数调用关系。

3. 源码分析

3.1 数据准备

dataset.py的主要功能就是对数据集进行读取,并且对其稀疏采样,返回稀疏采样后得到的数据集。

它首先定义了一个类TSNDataSet,用来处理最原始的数据。该类返回的是torch.utils.data.Dataset类型,(:一般而言在pytorch中自定义的数据读取类都要继承torch.utils.DataSet这个基类),然后通过重写_init_和_getitem_方法来读取函数。


1._init_函数
_init_函数的功能在于初始化TSNDataSet,并设置一些参数和参数默认值。

def __init__(self, root_path, list_file,
                 num_segments=3, new_length=1, modality='RGB',
                 image_tmpl='img_{:05d}.jpg', transform=None,
                 force_grayscale=False, random_shift=True, test_mode=False):
        self.root_path = root_path
        self.list_file = list_file
        self.num_segments = num_segments
        self.new_length = new_length
        self.modality = modality
        self.image_tmpl = image_tmpl
        self.transform = transform
        self.random_shift = random_shift
        self.test_mode = test_mode

        if self.modality == 'RGBDiff':
            self.new_length += 1# Diff needs one more image to calculate diff

        self._parse_list()

TSNDataSet类的初始化方法_init_需要如下参数:

  • root_path : 项目的根目录地址,如果其他文件地址使用绝对地址,则可以写成" "
  • list_file : 训练或测试的列表文件(.txt文件)地址
  • num_segments : 视频分割的段数
  • new_length : 根据输入数据集类型的不同,new_length取不同的值
  • modality : 输入数据集类型(RGB、光流、RGB差异)
  • image_tmpl : 图片的名称
  • transform : 数据集是否进行变换操作
  • random_shift : 稀疏采样时是否增加一个随机数
  • test_mode : 是否是测试时的数据集输入

2._parse_list函数
_parse_list函数功能在于读取list文件,储存在video_list中

    def _parse_list(self):
        self.video_list = [VideoRecord(x.strip().split(' ')) for x in open(self.list_file)]

self.video_list是一个长度为训练数据数量的列表。每个值都是VIDEORecord对象,包含一个列表和3个属性,列表长度为3,分别为帧路径、该视频含有多少帧和帧标签。

3._sample_indices函数
_sample_indices函数功能在于实现TSN的稀疏采样,返回的是稀疏采样的帧数列表

    def _sample_indices(self, record):
        average_duration = (record.num_frames - self.new_length + 1) // self.num_segments
        if average_duration > 0:
            offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration, size=self.num_segments)
        elif record.num_frames > self.num_segments:
            offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments))
        else:
            offsets = np.zeros((self.num_segments,))
        return offsets + 1
相关标签: tsn 神经网络