欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

目标检测,关于Gluon的SSD跑自己的数据-只需要lst,不需要rec

程序员文章站 2022-04-05 09:59:06
...

关于SSD代码的有些问题可以看我的上一篇的部分代码解释 关于gluon中Object Detection 中SSD模型的笔记

当然还有一些地方没有写清楚,主要就是MultiBoxTarget和MultiBoxDetection两个函数。但是大致可以理解为第一个函数是在算anchor与label之间的有没有交叉的关系,毕竟是个多尺度的网络。第二个函数就是在得到网络计算出的offset基础上和已知的anchor基础上,生成网络预测的实际bouding box。仅供参考。后续有空会看一下源码或者更多的资料。

这里只要说一下数据准备阶段。

目标检测,关于Gluon的SSD跑自己的数据-只需要lst,不需要rec

可以看到代码主要用了一个函数ImageDetIter,上一篇代码讲解里面,我也提到了这个问题“多目标怎么办”。上面我的代码只跑了单目标的数据集,然后采用的ndarray的iterator。我之前一直也很迷惑为什么后面label要reshap成3*5。


目标检测,关于Gluon的SSD跑自己的数据-只需要lst,不需要rec目标检测,关于Gluon的SSD跑自己的数据-只需要lst,不需要rec

然后我就去看了这个函数,总结来说就是就是你只需要一个lst文件就可以生成一个iterator了,里面还会根据你的不同的目标,去reshape成不同的大小(空的补足-1)。

如何生成不同的list了,上面也给了说明,但是我觉得很不清楚啊,这个是空格呢,还是什么呢。反正就是我经过测试几次之后发现,每个符号之间隔‘\t’就可以了。每个字符的意思就是 xmin, ymin, xmax, ymax,这四个也应该是( xmin\t ymin\t xmax\t ymax

给你们一个样例,自己写这个一个.lst的文件就可以了

'1700\t4\t5\t804\t734\t0\t0.553482587065\t0.0940054495913\t0.813432835821\t0.298365122616\t/root/workspace/data/DeepBC/DeepBC_3D/dataset/processed/train/15_191.jpg'

1700 是编号,第1700张图片

4 据说一定要4, 他的意思是(2 + length of extra header), 2是 4和5,length of extra header是804,734,这里是图片的大小

5 就是每个label的组成,categories,xmin, ymin, xmax, ymax

后面的就是正常的坐标了

最后是这个label对应的图片,请写绝对路径

然后验证一下自己的lst是不是写对了

import mxnet as mx
data_iter = mx.image.ImageDetIter(batch_size=4, data_shape=(3,804, 734),path_imglist='./train.lst')
data_iter.reset()
for data in data_iter:
    d = data.data[0]
    l = data.label[0]
    print(d.shape)
    print(l.shape)
    break # 看一组就可以了

然后也可以看一下自己的图片的label对不对

import numpy as np
import matplotlib.pyplot as plt
idx = 0
img = d[idx].asnumpy()  # grab the first image, convert to numpy array
img = img.transpose((1, 2, 0))  # we want channel to be the last dimension #img += np.array([123, 117, 104])
img = img.astype(np.uint8)  # use uint8 (0-255)
for ll in l[idx].asnumpy():
    if ll[0]<0: #-1是补充的,不是label
        break
    else:
        print(ll)
        xmin= int(ll[1]*734) #因为我的x的shape是734
        ymin= int(ll[2]*804) #因为我的y的shape是804
        xmax= int(ll[3]*734) 
        ymax = int(ll[4]*804) 
        rect = plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, edgecolor=(1, 0, 0), linewidth=1)
        plt.gca().add_patch(rect)
        break

plt.imshow(img,'gray')# 我的是灰度图,如果你们是彩色图也可以做别的操作
plt.show()

你可以plot一下就知道自己传给网络的data到底是不是正确的了

如果有错误就请指出~