TSN实验过程
论文链接:https://arxiv.org/abs/1608.00859
代码链接:https://github.com/yjxiong/tsn-pytorch
1.数据准备阶段
1.1 数据集介绍
在视频分类项目中,有很多经典的公开数据集,目前主要的数据集如列表所示:
数据集 | 视频数 | 分类数 | 发布年 | 背景 |
---|---|---|---|---|
KTH | 600 | 6 | 2004 | 静态 |
HMDB51 | 6766 | 51 | 2011 | 动态 |
UCF101 | 13320 | 101 | 2012 | 动态 |
THUMOS-2014 | 18394 | 101 | 2014 | 动态 |
ActivityNet | 27901 | 203 | 2015 | 动态 |
YouTube-8M | 8264650 | 4800 | 2016 | 动态 |
本次TSN实验复现使用的是UCF101数据集。
UCF101是动作识别数据集,从Youtube收集而得,共包含101类动作。其中每类动作由25个人做动作,每人做4-7组,共13320个视频,分辨率为320*240。UCF101在动作的采集上具有非常大的多样性,包括相机运行、外观变化、姿态变化、物体比例变化、背景变化、光纤变化等。
101类动作可以分为5类:人与物体互动、人体动作、人与人互动、乐器演奏、体育运动。
1.2 下载数据集
下载网址:http://crcv.ucf.edu/data/UCF101/UCF101.rar
下载成功后的UCF文件夹如下所示:
该文件夹下是各种动作的视频文件,共有101种类别
下图是UCF101在进行训练和测试时,分割的依据文件
1.3 下载源码
在实验过程中,我们需要使用tsn-pytorch和mmaction的一些代码文件,所以我们提前从Git上获得存储在本地。
- 下载mmaction:
git clone --recursive https://github.com/open-mmlab/mmaction.git
- 下载tsn-pytorch
git clone --recursive https://github.com/yjxiong/tsn-pytorch
2. 数据处理
2.1 提帧
在我们下载好的UCF101数据集中,视频大多是长时间的,很难对其进行动作识别,所以需要进行提帧操作。
首先在mmaction的data/ucf101中创建rawframes、videos、annotations文件夹。
- rawframes:视频提帧后存放的文件目录
- videos:拷贝ucf101数据集中的101个文件目录,放置其中
- annotations:ucf101之后进行分割训练集、测试集的依据文件
然后在mmaction/data_tools/build_rawframes.py 就是进行视频提帧的代码文件,输入命令如下所示:
python build_rawframes.py ../data/ucf101/videos ../data/ucf101/rawframes/ --level 2 --ext avi
命令行窗口:
生成的文件目录形式如下所示:
运行完成后,将每一个视频的每一帧提取出来,放在特定名称的文件夹中。
2.2 生成file_list
在tsn-pytorch的readme文件中可以看到,训练过程中需要<ucf101_rgb_train_list>和<ucf101_rgb_val_list>,所以生成这两个list文件是必需的。使用mmaction/data_tools/buid_file_list.py即可对ucf101生成的帧进行训练集和测试集的划分。输入命令如下所示:
python data_tools/build_file_list.py ucf101 data/ucf101/rawframes/ --level 2 --format rawframes --shuffle
也可在mmaction/data_tools/ucf101/中输入:
bash generate_filelist.sh
命令行运行截图如下所示:
生成的filelist在data/ucf101目录下,形式如下:
file_list的内容如下所示:
file_list中有三列,第一列代表文件的地址,第二列代表视频的帧数,第三列代表视频的类别。这里仅仅使用ucf101的3个文件夹,所以类别只有0 1 2。
3. 训练部分
3.1 修改代码
根据tsn-pytorch的readme文件可知,训练部分需要对main.py文件进行运行,但需要针对自己的情况进行修改才可以成功运行。
- 在ucf101类别中,原本代码是101,我们这里复现只使用ucf101三个类型,所以将代码修改为
if args.dataset == 'ucf101'
num_class = 3
- 在TSNDataSet中,为了更好的找到对应文件的位置,建议将args.train_list和args.val_list(这两个输入字符串就是之前生成的file_list的绝对路径)写成指定字符串的形式,所以将代码修改为
TSNDataSet("", "/home/ty/mmaction/data/ucf101/ucf101_train_split1_rawframes", num_segments=args.num_segments,
new_length=data_length,
modality=args.modality,
image_tmpl="img_{:05d}.jpg" if args.modality in ["RGB", "RGBDiff"] else args.flow_prefix+"{}_{:05d}.jpg",
transform=torchvision.transforms.Compose([
train_augmentation,
Stack(roll=args.arch == 'BNInception'),
ToTorchFormatTensor(div=args.arch != 'BNInception'),
normalize,
]))
TSNDataSet("", "/home/ty/mmaction/data/ucf101/ucf101_val_split_1_rawframes.txt", num_segments=args.num_segments,
new_length=data_length,
modality=args.modality,
image_tmpl="img_{:05d}.jpg" if args.modality in ["RGB", "RGBDiff"] else args.flow_prefix+"{}_{:05d}.jpg",
random_shift=False,
transform=torchvision.transforms.Compose([
GroupScale(int(scale_size)),
GroupCenterCrop(crop_size),
Stack(roll=args.arch == 'BNInception'),
ToTorchFormatTensor(div=args.arch != 'BNInception'),
normalize,
]))
经过运行main.py可知,还需要对datase.py进行修改
在get函数中调用_load_image方法,直接运行main.py的话,会找不到image的路径,所以需要在之前写上image的根目录,然后和之后的路径拼接在一起即可。所以将代码修改为
def get(self, record, indices):
images = list()
for seg_ind in indices:
p = int(seg_ind)
for i in range(self.new_length):
seg_imgs = self._load_image('home/ty/mmaction/data/ucf101/rawframes' + record.path, p)
images.extend(seg_imgs)
if p < record.num_frames:
p += 1
process_data = self.transform(images)
return process_data, record.label
3.2 TSN训练
阅读tsn-pytorch的readme文件可知,在tsn-pytorch/打开命令行,输入命令:
python main.py ucf101 RGB /home/ty/mmaction/data/ucf101/ucf101_val_split_1_rawframes.txt /home/ty/mmaction/data/ucf101/ucf101_val_split_1_rawframes.txt \
--num_segments 3 \
--gd 20 --lr 0.001 --lr_steps 30 60 --epochs 5 \
-b 16 -j 8 --dropout 0.8 \
--snapshot_pref ucf101_bninception_
运行命令,命令行打印训练过程和结果并且保存训练好的模型文件:
观察命令行可以看到,训练了5个周期,最后得到的准确率为100%,loss为0.0072
上一篇: C语言---函数
下一篇: 在Filezilla下用sftp上传文件