【SSD】pytorch版本的SSD训练
程序员文章站
2024-03-16 23:39:58
...
调试的代码源码:https://github.com/amdegroot/ssd.pytorch
环境:
- python3.7
- cuda10.0
- cudnn7
- pytorch1.2.0
- torchvision0.4.0
问题1
使用的是VOC,没有COCO数据,那么就需要将COCO的部分注释掉,如果不注释就会报错。
解决方法
将train.py中的COCO的部分直接注释掉,修改如下
# if args.dataset == 'COCO':
# if args.dataset_root == VOC_ROOT:
# if not os.path.exists(COCO_ROOT):
# parser.error('Must specify dataset_root if specifying dataset')
# print("WARNING: Using default COCO dataset_root because " +
# "--dataset_root was not specified.")
# args.dataset_root = COCO_ROOT
# cfg = coco
# dataset = COCODetection(root=args.dataset_root,
# transform=SSDAugmentation(cfg['min_dim'], MEANS)
# )
# elif args.dataset == 'VOC':
# if args.dataset_root == COCO_ROOT:
# parser.error('Must specify dataset if specifying dataset_root')
# cfg = voc
# dataset = VOCDetection(root=args.dataset_root,
# transform=SSDAugmentation(cfg['min_dim'], MEANS)
# )
cfg = voc
dataset = VOCDetection(root=args.dataset_root,
transform=SSDAugmentation(cfg['min_dim'], MEANS)
)
将data\__init__.py中的COCO部分注释掉:
# from .coco import COCODetection, COCOAnnotationTransform, COCO_CLASSES, COCO_ROOT, get_label_map
将data\config.py中的HOME目录修改:
# gets home dir cross platform
# HOME = os.path.expanduser("~")
HOME = r"E:\standard_data\voc"
将data\voc0712.py的VOC_ROOT目录修改为:
VOC_ROOT = osp.join(HOME, "VOC2007_ORI\\VOCdevkit\\")
问题2
在ssds检测项目中,其中求损失的multibox_loss.py中遇到一个bug,错误为:
IndexError: The shape of the mask [32, 2990] at index 0 does not match the shape of the indexed tensor [95680, 1] at index 0
报错代码出现在:
loss_c[pos] = 0
其中,pos和loss_c的尺寸维度分别是:
loss_c.size torch.Size([95680, 1])
pos.size torch.Size([32, 2990])
解决方法
在multibox_loss.py中的97行左右,将
loss_c[pos] = 0 # filter out pos boxes for now
改为
loss_c[pos.view(-1)] = 0 # filter out pos boxes for now
问题3
出现问题地点
loc_loss += loss_l.data[0]
conf_loss += loss_c.data[0]
if iteration % 1 == 0:
print('timer: %.4f sec.' % (t1 - t0))
print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.data[0]), end=' ')
解决方法
将上面问题地点修改为
loc_loss += loss_l.detach()
# conf_loss += loss_c.data[0]
conf_loss += loss_c.detach()
if iteration % 1 == 0:
print('timer: %.4f sec.' % (t1 - t0))
print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.detach()), end=' ')
上一篇: 设计模式六大原则,你真的懂了吗?
下一篇: Volatile你真的懂了吗?