欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

TSN实验过程

程序员文章站 2022-03-19 11:56:08
...

TSN实验过程
论文链接:https://arxiv.org/abs/1608.00859
代码链接:https://github.com/yjxiong/tsn-pytorch

1.数据准备阶段

1.1 数据集介绍

在视频分类项目中,有很多经典的公开数据集,目前主要的数据集如列表所示:

数据集 视频数 分类数 发布年 背景
KTH 600 6 2004 静态
HMDB51 6766 51 2011 动态
UCF101 13320 101 2012 动态
THUMOS-2014 18394 101 2014 动态
ActivityNet 27901 203 2015 动态
YouTube-8M 8264650 4800 2016 动态

本次TSN实验复现使用的是UCF101数据集。

UCF101是动作识别数据集,从Youtube收集而得,共包含101类动作。其中每类动作由25个人做动作,每人做4-7组,共13320个视频,分辨率为320*240。UCF101在动作的采集上具有非常大的多样性,包括相机运行、外观变化、姿态变化、物体比例变化、背景变化、光纤变化等。

101类动作可以分为5类:人与物体互动、人体动作、人与人互动、乐器演奏、体育运动。

TSN实验过程

1.2 下载数据集

下载网址:http://crcv.ucf.edu/data/UCF101/UCF101.rar

下载成功后的UCF文件夹如下所示:
该文件夹下是各种动作的视频文件,共有101种类别
TSN实验过程
下图是UCF101在进行训练和测试时,分割的依据文件
TSN实验过程

1.3 下载源码

在实验过程中,我们需要使用tsn-pytorch和mmaction的一些代码文件,所以我们提前从Git上获得存储在本地。

  • 下载mmaction:
git clone --recursive https://github.com/open-mmlab/mmaction.git
  • 下载tsn-pytorch
git clone --recursive https://github.com/yjxiong/tsn-pytorch

2. 数据处理

2.1 提帧

在我们下载好的UCF101数据集中,视频大多是长时间的,很难对其进行动作识别,所以需要进行提帧操作。
首先在mmaction的data/ucf101中创建rawframes、videos、annotations文件夹。

  • rawframes:视频提帧后存放的文件目录
  • videos:拷贝ucf101数据集中的101个文件目录,放置其中
  • annotations:ucf101之后进行分割训练集、测试集的依据文件
    TSN实验过程

然后在mmaction/data_tools/build_rawframes.py 就是进行视频提帧的代码文件,输入命令如下所示:

python build_rawframes.py ../data/ucf101/videos ../data/ucf101/rawframes/ --level 2  --ext avi

命令行窗口:
TSN实验过程

生成的文件目录形式如下所示:
TSN实验过程
TSN实验过程
TSN实验过程
运行完成后,将每一个视频的每一帧提取出来,放在特定名称的文件夹中。

2.2 生成file_list

在tsn-pytorch的readme文件中可以看到,训练过程中需要<ucf101_rgb_train_list>和<ucf101_rgb_val_list>,所以生成这两个list文件是必需的。使用mmaction/data_tools/buid_file_list.py即可对ucf101生成的帧进行训练集和测试集的划分。输入命令如下所示:

python data_tools/build_file_list.py ucf101 data/ucf101/rawframes/ --level 2 --format rawframes --shuffle

也可在mmaction/data_tools/ucf101/中输入:

bash generate_filelist.sh

命令行运行截图如下所示:
TSN实验过程
生成的filelist在data/ucf101目录下,形式如下:
TSN实验过程
file_list的内容如下所示:
TSN实验过程
file_list中有三列,第一列代表文件的地址,第二列代表视频的帧数,第三列代表视频的类别。这里仅仅使用ucf101的3个文件夹,所以类别只有0 1 2。

3. 训练部分

3.1 修改代码

根据tsn-pytorch的readme文件可知,训练部分需要对main.py文件进行运行,但需要针对自己的情况进行修改才可以成功运行。

  • 在ucf101类别中,原本代码是101,我们这里复现只使用ucf101三个类型,所以将代码修改为
if args.dataset == 'ucf101'
	num_class = 3
  • 在TSNDataSet中,为了更好的找到对应文件的位置,建议将args.train_list和args.val_list(这两个输入字符串就是之前生成的file_list的绝对路径)写成指定字符串的形式,所以将代码修改为
TSNDataSet("", "/home/ty/mmaction/data/ucf101/ucf101_train_split1_rawframes", num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl="img_{:05d}.jpg" if args.modality in ["RGB", "RGBDiff"] else args.flow_prefix+"{}_{:05d}.jpg",
                   transform=torchvision.transforms.Compose([
                       train_augmentation,
                       Stack(roll=args.arch == 'BNInception'),
                       ToTorchFormatTensor(div=args.arch != 'BNInception'),
                       normalize,
                   ]))
TSNDataSet("", "/home/ty/mmaction/data/ucf101/ucf101_val_split_1_rawframes.txt", num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl="img_{:05d}.jpg" if args.modality in ["RGB", "RGBDiff"] else args.flow_prefix+"{}_{:05d}.jpg",
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=args.arch == 'BNInception'),
                       ToTorchFormatTensor(div=args.arch != 'BNInception'),
                       normalize,
                   ]))

经过运行main.py可知,还需要对datase.py进行修改
在get函数中调用_load_image方法,直接运行main.py的话,会找不到image的路径,所以需要在之前写上image的根目录,然后和之后的路径拼接在一起即可。所以将代码修改为

def get(self, record, indices):

        images = list()
        for seg_ind in indices:
            p = int(seg_ind)
            for i in range(self.new_length):
                seg_imgs = self._load_image('home/ty/mmaction/data/ucf101/rawframes' + record.path, p)
                images.extend(seg_imgs)
                if p < record.num_frames:
                    p += 1

        process_data = self.transform(images)
        return process_data, record.label

3.2 TSN训练

阅读tsn-pytorch的readme文件可知,在tsn-pytorch/打开命令行,输入命令:

python main.py ucf101 RGB /home/ty/mmaction/data/ucf101/ucf101_val_split_1_rawframes.txt /home/ty/mmaction/data/ucf101/ucf101_val_split_1_rawframes.txt \
   --num_segments 3 \
   --gd 20 --lr 0.001 --lr_steps 30 60 --epochs 5 \
   -b 16 -j 8 --dropout 0.8 \
   --snapshot_pref ucf101_bninception_ 

运行命令,命令行打印训练过程和结果并且保存训练好的模型文件:
TSN实验过程
TSN实验过程
观察命令行可以看到,训练了5个周期,最后得到的准确率为100%,loss为0.0072