TSN源码分析
论文链接:https://arxiv.org/abs/1608.00859
代码链接:https://github.com/yjxiong/tsn-pytorch
1 源码准备
在指定文件夹下,输入命令:
git clone --recursive https://github.com/yjxiong/tsn-pytorch
下载完成后,得到tsn-pytorch源码。
2 源码结构
下表列出tsn-pytorch中比较重要的文件:
文件名称 | 功能 |
---|---|
main.py | 训练脚本 |
test_models.py | 测试脚本 |
opts.py | 参数配置脚本 |
dataset.py | 数据读取脚本 |
models.py | 网络结构构建脚本 |
transforms.py | 数据预处理相关的脚本 |
tf_model_zoo文件夹 | 导入模型结构的脚本 |
接下来对一些重要文件,将一一讲解,并且说清数据流的走向和函数调用关系。
3. 源码分析
3.1 数据准备
dataset.py的主要功能就是对数据集进行读取,并且对其稀疏采样,返回稀疏采样后得到的数据集。
它首先定义了一个类TSNDataSet,用来处理最原始的数据。该类返回的是torch.utils.data.Dataset类型,(注:一般而言在pytorch中自定义的数据读取类都要继承torch.utils.DataSet这个基类),然后通过重写_init_和_getitem_方法来读取函数。
1._init_函数
_init_函数的功能在于初始化TSNDataSet,并设置一些参数和参数默认值。
def __init__(self, root_path, list_file,
num_segments=3, new_length=1, modality='RGB',
image_tmpl='img_{:05d}.jpg', transform=None,
force_grayscale=False, random_shift=True, test_mode=False):
self.root_path = root_path
self.list_file = list_file
self.num_segments = num_segments
self.new_length = new_length
self.modality = modality
self.image_tmpl = image_tmpl
self.transform = transform
self.random_shift = random_shift
self.test_mode = test_mode
if self.modality == 'RGBDiff':
self.new_length += 1# Diff needs one more image to calculate diff
self._parse_list()
TSNDataSet类的初始化方法_init_需要如下参数:
- root_path : 项目的根目录地址,如果其他文件地址使用绝对地址,则可以写成" "
- list_file : 训练或测试的列表文件(.txt文件)地址
- num_segments : 视频分割的段数
- new_length : 根据输入数据集类型的不同,new_length取不同的值
- modality : 输入数据集类型(RGB、光流、RGB差异)
- image_tmpl : 图片的名称
- transform : 数据集是否进行变换操作
- random_shift : 稀疏采样时是否增加一个随机数
- test_mode : 是否是测试时的数据集输入
2._parse_list函数
_parse_list函数功能在于读取list文件,储存在video_list中
def _parse_list(self):
self.video_list = [VideoRecord(x.strip().split(' ')) for x in open(self.list_file)]
self.video_list是一个长度为训练数据数量的列表。每个值都是VIDEORecord对象,包含一个列表和3个属性,列表长度为3,分别为帧路径、该视频含有多少帧和帧标签。
3._sample_indices函数
_sample_indices函数功能在于实现TSN的稀疏采样,返回的是稀疏采样的帧数列表
def _sample_indices(self, record):
average_duration = (record.num_frames - self.new_length + 1) // self.num_segments
if average_duration > 0:
offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration, size=self.num_segments)
elif record.num_frames > self.num_segments:
offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments))
else:
offsets = np.zeros((self.num_segments,))
return offsets + 1
上一篇: HTML和CSS重难点知识点总结
下一篇: JAVA第三章习题代码之总结