欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

在谷歌目标检测(Google object_detection) API 上训练自己的数据集

程序员文章站 2024-03-14 10:09:58
...

本文未经同意禁止转载,谢谢配合!

知乎链接:https://zhuanlan.zhihu.com/p/28218410

应公司要求,利用谷歌最近开源的Google object_detection API对公司收集的数据集进行训练,并检测训练效果。通过一两天的研究以及维持四天的训练(GTX 1060  6GB),终于成功的在自己数据集上训练的任务。测试效果感觉还行,虽没有达到谷歌官方公布的数据集上跑的识别效果,但是识别率也还过得去,这主要是因为数据集没有官方做的那么规范。下图为本人挑选的一张识别率较好的图片(识别哈尔滨啤酒):

在谷歌目标检测(Google object_detection) API 上训练自己的数据集

下面把本人如何一步步在自己的数据集上训练的详细步骤做个总结,一是方便自己以后操作起来更快的再次上手训练,二是方便大家能好的实现该API的一些需求。

需要说明的:

1:本教程用的模型权重参数为faster_rcnn_resnet101_coco  ,可点击进行模型的下载。

2:数据集格式需要为转换成tensorflow要求的tfrecord的形式。

3:本文在GTX 1060  6GB的显卡上训练了四天

4:如何安装tensorflow等一些依赖库,本文不再赘述,请参考:安装依赖库教程链接


过程:

1:下载Google object_detection API

下载地址

2:数据集准备:

数据集需要符合API所需的TFRecord格式,官方提供的数据集格式为PASCAL VOC格式,API已经为我们提供了将此格式转为TFRecord的代码. 但是这里我们需要注意一个细节:create_pascal_tf_record.py中的

examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main',
                             'aeroplane_' + FLAGS.set + '.txt')
去掉'aeroplane_'。

同时,将文件中的


flags.DEFINE_string('label_map_path', 'data/pascal_label_map.pbtxt',
                    'Path to label map proto')
 

data/pascal_labe_map.pbtxt改为自己的数据集label

然后在“tensorflow/models/object_detection/”目录下运行以下命令

#生成训练集record
python create_pascal_tf_record.py --data_dir=`自己的训练数据集路径` \
    --year=VOC2007 --set=train --output_path=`你想保存的训练集的record路径`

#生成验证集record
 python create_pascal_tf_record.py --data_dir=`自己的验证数据集路径` \

    --year=VOC2007 --set=val --output_path=`你想保存的验证集的record路径`

注意,在data目录下选择一个.pbtxt文件,将该文件改为自己数据集的label。
执行上述两个命令后会在data文件夹下生成两个record文件。 

3:下载预训练模型

按照上文“需要说明的”第一条下载预训练模型,将下载好的模型进行解压,并将.ckpt的三个文件拷贝到models目录下。将object_detection/samples/configs/faster_rcnn_resnet101_voc07.config复制到models目录下并做如下修改:
1)num_classes:修改为之前修的的.pbtxt文件中的类别数目
2)将所有'PATH_TO_BE_CONFIGURED'修改为自己之前设置的路径

4:开始训练

执行上述三步之后我们可以开始训练了,此处需要注意两点,不然会出现模块导出错误,在tensorflow/models分别运行:


protoc object_detection/protos/*.proto --python_out=.

export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim

然后进入到obeject_detection目录下,运行一下命令:


python train.py --train_dir='想要保存训练模型的路径' --pipeline_config_path='你采用的.config文件路径'

5:模型可视化

运行上述四步之后您基本上只需等着模型运行完成即可,如果您想要可视化您的模型,可以运行:

tensorboard --logdir=’上面第4点提到的train_dir路径‘

然后在你的浏览器输入0.0.0.0:6006就能看到模型一些相关参数的可视化结果了。

训练完成后会生成三个.cpkt的文件,将这三个文件复制到tensorflow/models下,可利用这三个文件生成一个.pb文件,生成代码如下:

python object_detection/export_inference_graph.py \
    --input_type image_tensor \
    --pipeline_config_path ’你的.config文件路径’ \
    --checkpoint_path model.ckpt-‘CHECKPOINT_NUMBER’ \
    --inference_graph_path output_inference_graph.pb
这样你就可以利用.pb文件进行目标检测了,具体步骤请参考:https://github.com/tensorflow/models/blob/master/object_detection/object_detection_tutorial.ipynb

6:参考

https://zhuanlan.zhihu.com/p/27469690

https://github.com/tensorflow/models/blob/master/object_detection/g3doc/installation.md

https://github.com/tensorflow/models/blob/master/object_detection/g3doc/running_pets.md


如您觉得本文对你有帮助,请酌情赞赏。同时本文如有不完善的地方欢迎指正!谢谢!

在谷歌目标检测(Google object_detection) API 上训练自己的数据集