
Training SSD on Your Own Dataset (PyTorch Version)

  • Environment: Win10 + Anaconda3 + Python 3.8.8 + PyTorch 1.8.1
  • ssd.pytorch source code: https://github.com/amdegroot/ssd.pytorch
  • VGG16_reducedfc.pth pretrained model: https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth
  • Updated code (needed for PyTorch above 1.3.0): https://github.com/sayakbanerjee1999/Single-Shot-Object-Detection-Updated

I. Dataset Preparation (Pascal VOC Format)

  1. Place the dataset under ssd.pytorch-master/data, or point to another location by editing VOC_ROOT = osp.join(HOME, "data/VOCdevkit/") in voc0712.py. The folder layout is as follows (a sketch for generating the ImageSets split files comes after the tree):

    VOCdevkit
    --VOC2020
    ----Annotations
    ----ImageSets
    ----JPEGImages
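
    The image_sets used in step 2 of the code changes below expect split files train.txt, val.txt, and test.txt under ImageSets/Main, each listing image ids one per line without the extension. If your dataset lacks them, a minimal sketch like the following can generate a random split (the script and the 8:1:1 ratio are my own convention, not part of the repo):

    import os
    import random

    root = 'VOCdevkit/VOC2020'
    # collect image ids from the annotation file names
    ids = [f[:-4] for f in os.listdir(os.path.join(root, 'Annotations')) if f.endswith('.xml')]
    random.shuffle(ids)

    n = len(ids)
    splits = {
        'train': ids[:int(0.8 * n)],
        'val':   ids[int(0.8 * n):int(0.9 * n)],
        'test':  ids[int(0.9 * n):],
    }

    out_dir = os.path.join(root, 'ImageSets', 'Main')
    os.makedirs(out_dir, exist_ok=True)
    for name, subset in splits.items():
        with open(os.path.join(out_dir, name + '.txt'), 'w') as f:
            f.write('\n'.join(subset) + '\n')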
    
  2. Some XML files may contain no annotation data at all, i.e. no <object> element, which later triggers IndexError: too many indices for array:... when the annotation array is indexed. The script below (usage shown after it) flags the offending XML files so they can be fixed or deleted by hand.

    import argparse
    import sys
    import cv2
    import os
    import os.path as osp
    import numpy as np
    
    if sys.version_info[0] == 2:
        import xml.etree.cElementTree as ET
    else:
        import xml.etree.ElementTree as ET
    
    parser = argparse.ArgumentParser(
        description='Check VOC annotations for images with no <object> entries')
    parser.add_argument('--root', default='VOCdevkit/VOC2020', help='Dataset root directory path')
    args = parser.parse_args()
    
    CLASSES = [('person')]
    annopath = osp.join('%s', 'Annotations', '%s.{}'.format("xml"))
    imgpath  = osp.join('%s', 'JPEGImages',  '%s.{}'.format("jpg"))
    
    def vocChecker(image_id, width, height, keep_difficult = False):
        # image_id is a (root, image id) tuple that fills the annopath/imgpath templates
        target   = ET.parse(annopath % image_id).getroot()
        res      = []
        for obj in target.iter('object'):
            difficult = int(obj.find('difficult').text) == 1
            if not keep_difficult and difficult:
                continue
            name = obj.find('name').text.lower().strip()
            bbox = obj.find('bndbox')
            pts    = ['xmin', 'ymin', 'xmax', 'ymax']
            bndbox = []
            for i, pt in enumerate(pts):
                cur_pt = int(bbox.find(pt).text) - 1
                # normalize x-coordinates by width, y-coordinates by height
                cur_pt = float(cur_pt) / width if i % 2 == 0 else float(cur_pt) / height
                bndbox.append(cur_pt)
            label_idx =  dict(zip(CLASSES, range(len(CLASSES))))[name]
            bndbox.append(label_idx)
            res += [bndbox]  # [xmin, ymin, xmax, ymax, label_ind]
            # img_id = target.find('filename').text[:-4]
        try:
            np.array(res)[:, 4]
            np.array(res)[:, :4]
        except IndexError:
            # res is empty when the file has no <object> entries, so 2-D indexing fails
            print(annopath % image_id + ' has an indexing error')
        return res  # [[xmin, ymin, xmax, ymax, label_ind], ... ]
    
    if __name__ == '__main__':
        i = 0
        # there is exactly one annotation file per image
        for name in sorted(os.listdir(osp.join(args.root, 'Annotations'))):
            i += 1
            img = cv2.imread(imgpath % (args.root, name.split('.')[0]))
            height, width, channels = img.shape
            # note the argument order: vocChecker expects (image_id, width, height)
            res = vocChecker((args.root, name.split('.')[0]), width, height)
        print('Total annotations: {}'.format(i))
    

II. Code Modifications

  1. Modify the VOC_CLASSES variable in data/voc0712.py. For example, reduce it to a single person class; note that a lone class still needs the enclosing square brackets:

    VOC_CLASSES = [('person')]
    
  2. In voc0712.py, change image_sets in the __init__ of the VOCDetection class to [('2020', 'train'), ('2020', 'val'), ('2020', 'test')]:

    def __init__(self, root,
                 image_sets=[('2020', 'train'), ('2020', 'val'), ('2020', 'test')],
                 transform=None, target_transform=VOCAnnotationTransform(),
                 dataset_name='VOC0712'):
    
  3. Modify the voc dictionary in config.py: set num_classes to 2 (background plus person). For a first debugging run you can also lower max_iter to 1000. The result is below (a quick consistency check on the prior-box settings follows the dictionary).

    voc = {
        'num_classes': 2,
        'lr_steps': (80000, 100000, 120000),
        'max_iter': 1000,
        'feature_maps': [38, 19, 10, 5, 3, 1],
        'min_dim': 300,
        'steps': [8, 16, 32, 64, 100, 300],
        'min_sizes': [30, 60, 111, 162, 213, 264],
        'max_sizes': [60, 111, 162, 213, 264, 315],
        'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
        'variance': [0.1, 0.2],
        'clip': True,
        'name': 'VOC',
    }
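
    As a rough sanity check, the remaining keys must stay mutually consistent: feature_maps and aspect_ratios together determine SSD300's 8732 default boxes. A minimal count (each location gets two "square" boxes plus two boxes per listed aspect ratio, ar and 1/ar):

    feature_maps  = [38, 19, 10, 5, 3, 1]
    aspect_ratios = [[2], [2, 3], [2, 3], [2, 3], [2], [2]]
    total = sum(f * f * (2 + 2 * len(ars))
                for f, ars in zip(feature_maps, aspect_ratios))
    print(total)  # 8732, matching the [14, 8732] shapes in later error messages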
    
  4. Place coco_labels.txt under ssd.pytorch-master/data/coco/, or specify another location by editing COCO_ROOT = osp.join(HOME, 'data/coco/') in coco.py.

  5. On PyTorch 1.3 and above, running the code raises RuntimeError: Legacy autograd function with non-static forward method is deprecated, because newer versions require an autograd Function's forward to be static. Two changes fix this (a toy illustration of the static-forward pattern follows the code below):
    Replace layers/functions/detection.py with the detection.py from Single-Shot-Object-Detection-Updated-master.
    Modify the __init__ and forward functions of the SSD class in ssd.py as follows.

    if phase == 'test':
        self.softmax = nn.Softmax(dim=-1)
        self.detect = Detect(num_classes, 0, 200, 0.01, 0.45)
    becomes:
    if phase == 'test':
        self.softmax = nn.Softmax()
        self.detect = Detect()

    if self.phase == "test":
        output = self.detect(
            loc.view(loc.size(0), -1, 4),                   # loc preds
            self.softmax(conf.view(conf.size(0), -1,
                         self.num_classes)),                # conf preds
            self.priors.type(type(x.data))                  # default boxes
        )
    becomes:
    if self.phase == "test":
        output = self.detect.apply(self.num_classes, 0, 200, 0.01, 0.45,
            loc.view(loc.size(0), -1, 4),                   # loc preds
            self.softmax(conf.view(-1, self.num_classes)),  # conf preds
            self.priors.type(type(x.data))                  # default boxes
        )
    (The original article hard-codes 21 here, i.e. VOC's 20 classes plus background; using self.num_classes keeps the call consistent with the 2-class config above.)
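
    For context, the error comes from calling an autograd Function instance directly; from PyTorch 1.3 on, forward and backward must be static methods invoked via .apply. A minimal toy Function illustrating the pattern (only a sketch, not the repo's Detect class):

    import torch

    class Scale(torch.autograd.Function):
        # static-forward style required by PyTorch >= 1.3
        @staticmethod
        def forward(ctx, x, factor):
            ctx.factor = factor
            return x * factor

        @staticmethod
        def backward(ctx, grad_output):
            # gradient of x * factor w.r.t. x; factor itself gets no gradient
            return grad_output * ctx.factor, None

    x = torch.ones(3, requires_grad=True)
    y = Scale.apply(x, 2.0)   # Scale()(x, 2.0) would raise the legacy-autograd error
    y.sum().backward()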
    
  6. Modify lines 187-189 of train.py: the .data[0] idiom only works on old PyTorch versions and otherwise raises IndexError: invalid index of a 0-dim tensor.... The fixed code (a standalone illustration follows it):

    loc_loss += loss_l.item()
    conf_loss += loss_c.item()
    
    if iteration % 10 == 0:
        print('timer: %.4f sec.' % (t1 - t0))
        print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.item()), end=' ')
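
    For reference, the losses are 0-dim tensors on modern PyTorch, which is exactly why .data[0] fails; a standalone illustration:

    import torch

    loss = torch.tensor(1.5)   # 0-dim tensor, like the value a criterion returns
    print(loss.item())         # 1.5 -- the supported way to extract a Python scalar
    # loss.data[0]             # IndexError: invalid index of a 0-dim tensor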
    
  7. Modify line 165 of train.py, otherwise training aborts with a StopIteration error once the data iterator is exhausted at the end of an epoch; re-creating the iterator fixes it:

    try:
        images, targets = next(batch_iterator)
    except StopIteration:
        batch_iterator = iter(data_loader)
        images, targets = next(batch_iterator)
    
  8. Swap lines 97 and 98 of layers/modules/multibox_loss.py, otherwise you get IndexError: The shape of the mask [14, 8732] at index 0 does... (a standalone illustration of the mismatch follows the snippet). After the swap:

    loss_c = loss_c.view(num, -1)
    loss_c[pos] = 0  # filter out pos boxes for now
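
    The swap matters because pos has shape [batch, 8732] while loss_c is still flattened to [batch*8732, 1] at that point, and boolean mask indexing requires matching shapes. A standalone illustration using the shapes from the error message:

    import torch

    num, priors = 14, 8732
    pos    = torch.zeros(num, priors, dtype=torch.bool)
    loss_c = torch.rand(num * priors, 1)
    # loss_c[pos] = 0          # IndexError: the shape of the mask [14, 8732] at index 0 ...
    loss_c = loss_c.view(num, -1)
    loss_c[pos] = 0            # fine once both are [14, 8732]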
    
  9. (Optional) Adjust the pretrained model, batch_size, learning rate, model name, and checkpoint interval in train.py as needed. A learning rate of 1e-4 is recommended, since the original 1e-3 can drive the loss to NaN; for a first debugging run it is convenient to save every 100 iterations (the modified condition is shown after the snippet).

    parser.add_argument('--basenet', default='vgg16_reducedfc.pth',
                        help='Pretrained base model')
    parser.add_argument('--batch_size', default=32, type=int,
                        help='Batch size for training')
    parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float,
                        help='initial learning rate')
    if iteration != 0 and iteration % 5000 == 0:
        print('Saving state, iter:', iteration)
        torch.save(ssd_net.state_dict(), 'weights/ssd300_VOC_' + repr(iteration) + '.pth')
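
    For example, to save every 100 iterations during the first debugging run, the condition above becomes:

    if iteration != 0 and iteration % 100 == 0:
        print('Saving state, iter:', iteration)
        torch.save(ssd_net.state_dict(), 'weights/ssd300_VOC_' + repr(iteration) + '.pth')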
    

III. Model Training

  1. Start training with the command below. Typical output follows; some UserWarning... messages may appear and can safely be ignored.
    python train.py
    
    Loading base network...
    Initializing weights...
    Loading the dataset...
    Training SSD on: VOC0712
    Using the specified args:
    Namespace(basenet='vgg16_reducedfc.pth', batch_size=32, cuda=True, dataset='VOC', dataset_root='/home/featurize/data/VOCdevkit/', gamma=0.1, lr=0.0001, momentum=0.9, num_workers=4, resume=None, save_folder='weights/', start_iter=0, visdom=False, weight_decay=0.0005)
    timer: 3.1155 sec.
    iter 0 || Loss: 21.9156 || timer: 0.2166 sec.
    iter 10 || Loss: 12.6344 || timer: 0.1833 sec.
    iter 20 || Loss: 9.5942 || timer: 0.1928 sec.
    iter 30 || Loss: 8.4343 || timer: 0.1805 sec.
    iter 40 || Loss: 7.2879 || timer: 0.1652 sec.
    iter 50 || Loss: 6.3154 || timer: 0.1877 sec.
    iter 60 || Loss: 6.3842 || timer: 0.2185 sec.
    iter 70 || Loss: 6.2072 || timer: 0.1802 sec.
    iter 80 || Loss: 6.4077 || timer: 0.1642 sec.
    iter 90 || Loss: 5.8568 || timer: 0.1848 sec.
    iter 100 || Loss: 5.7892 || timer: 0.1976 sec.
    iter 110 || Loss: 5.7187 || timer: 0.2656 sec.
    iter 120 || Loss: 5.3061 || timer: 0.1737 sec.
    ...
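
  2. If training is interrupted, the --resume argument shown in the Namespace above should let you continue from a saved checkpoint (check the resume handling in train.py; the weight file name follows the saving code in step 9 above), e.g.:

    python train.py --resume weights/ssd300_VOC_1000.pth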