欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

【SSD】pytorch版本的SSD训练

程序员文章站 2024-03-16 23:39:58
...

调试的代码源码:https://github.com/amdegroot/ssd.pytorch

环境:

  • python3.7
  • cuda10.0
  • cudnn7
  • pytorch1.2.0
  • torchvision0.4.0

问题1

使用的是VOC,没有COCO数据,那么就需要将COCO的部分注释掉,如果不注释就会报错。

解决方法

将train.py中的COCO的部分直接注释掉,修改如下

    # if args.dataset == 'COCO':
    #     if args.dataset_root == VOC_ROOT:
    #         if not os.path.exists(COCO_ROOT):
    #             parser.error('Must specify dataset_root if specifying dataset')
    #         print("WARNING: Using default COCO dataset_root because " +
    #               "--dataset_root was not specified.")
    #         args.dataset_root = COCO_ROOT
    #     cfg = coco
    #     dataset = COCODetection(root=args.dataset_root,
    #                             transform=SSDAugmentation(cfg['min_dim'], MEANS)
    #                             )
    # elif args.dataset == 'VOC':
    #     if args.dataset_root == COCO_ROOT:
    #         parser.error('Must specify dataset if specifying dataset_root')
    #     cfg = voc
    #     dataset = VOCDetection(root=args.dataset_root,
    #                            transform=SSDAugmentation(cfg['min_dim'], MEANS)
    #                            )

    cfg = voc
    dataset = VOCDetection(root=args.dataset_root,
                           transform=SSDAugmentation(cfg['min_dim'], MEANS)
                           )

 将data\__init__.py中的COCO部分注释掉:

# from .coco import COCODetection, COCOAnnotationTransform, COCO_CLASSES, COCO_ROOT, get_label_map

将data\config.py中的HOME目录修改:

# gets home dir cross platform
# HOME = os.path.expanduser("~")
HOME =  r"E:\standard_data\voc"

将data\voc0712.py的VOC_ROOT目录修改为:

VOC_ROOT = osp.join(HOME, "VOC2007_ORI\\VOCdevkit\\")

问题2

在ssds检测项目中,其中求损失的multibox_loss.py中遇到一个bug,错误为:

IndexError: The shape of the mask [32, 2990] at index 0 does not match the shape of the indexed tensor [95680, 1] at index 0

报错代码出现在:

loss_c[pos] = 0

其中,pos和loss_c的尺寸维度分别是:

loss_c.size torch.Size([95680, 1])
pos.size torch.Size([32, 2990])

解决方法

在multibox_loss.py中的97行左右,将

loss_c[pos] = 0 # filter out pos boxes for now

改为

loss_c[pos.view(-1)] = 0  # filter out pos boxes for now

问题3

 出现问题地点

        loc_loss += loss_l.data[0]
        conf_loss += loss_c.data[0]
        if iteration % 1 == 0:
            print('timer: %.4f sec.' % (t1 - t0))
            print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.data[0]), end=' ')

解决方法

将上面问题地点修改为

        loc_loss += loss_l.detach()
        # conf_loss += loss_c.data[0]
        conf_loss += loss_c.detach()
        if iteration % 1 == 0:
            print('timer: %.4f sec.' % (t1 - t0))
            print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.detach()), end=' ')