在谷歌目标检测(Google object_detection) API 上训练自己的数据集
本文未经同意禁止转载,谢谢配合!
知乎链接:https://zhuanlan.zhihu.com/p/28218410
应公司要求,利用谷歌最近开源的Google object_detection API对公司收集的数据集进行训练,并检测训练效果。通过一两天的研究以及维持四天的训练(GTX 1060 6GB),终于成功的在自己数据集上训练的任务。测试效果感觉还行,虽没有达到谷歌官方公布的数据集上跑的识别效果,但是识别率也还过得去,这主要是因为数据集没有官方做的那么规范。下图为本人挑选的一张识别率较好的图片(识别哈尔滨啤酒):
下面把本人如何一步步在自己的数据集上训练的详细步骤做个总结,一是方便自己以后操作起来更快的再次上手训练,二是方便大家能好的实现该API的一些需求。
需要说明的:
1:本教程用的模型权重参数为faster_rcnn_resnet101_coco ,可点击进行模型的下载。
2:数据集格式需要为转换成tensorflow要求的tfrecord的形式。
3:本文在GTX 1060 6GB的显卡上训练了四天
4:如何安装tensorflow等一些依赖库,本文不再赘述,请参考:安装依赖库教程链接
过程:
1:下载Google object_detection API:
下载地址
2:数据集准备:
数据集需要符合API所需的TFRecord格式,官方提供的数据集格式为PASCAL VOC格式,API已经为我们提供了将此格式转为TFRecord的代码. 但是这里我们需要注意一个细节:create_pascal_tf_record.py中的
examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main', 'aeroplane_' + FLAGS.set + '.txt')去掉'aeroplane_'。
同时,将文件中的
flags.DEFINE_string('label_map_path', 'data/pascal_label_map.pbtxt', 'Path to label map proto')
data/pascal_labe_map.pbtxt改为自己的数据集label
然后在“tensorflow/models/object_detection/”目录下运行以下命令
#生成训练集record
python create_pascal_tf_record.py --data_dir=`自己的训练数据集路径` \
--year=VOC2007 --set=train --output_path=`你想保存的训练集的record路径`
#生成验证集record
python create_pascal_tf_record.py --data_dir=`自己的验证数据集路径` \
--year=VOC2007 --set=val --output_path=`你想保存的验证集的record路径`
注意,在data目录下选择一个.pbtxt文件,将该文件改为自己数据集的label。
执行上述两个命令后会在data文件夹下生成两个record文件。
3:下载预训练模型
按照上文“需要说明的”第一条下载预训练模型,将下载好的模型进行解压,并将.ckpt的三个文件拷贝到models目录下。将object_detection/samples/configs/faster_rcnn_resnet101_voc07.config复制到models目录下并做如下修改:
1)num_classes:修改为之前修的的.pbtxt文件中的类别数目
2)将所有'PATH_TO_BE_CONFIGURED'修改为自己之前设置的路径
4:开始训练
执行上述三步之后我们可以开始训练了,此处需要注意两点,不然会出现模块导出错误,在tensorflow/models
分别运行:
protoc object_detection/protos/*.proto --python_out=.
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
然后进入到obeject_detection目录下,运行一下命令:
python train.py --train_dir='想要保存训练模型的路径' --pipeline_config_path='你采用的.config文件路径'
5:模型可视化
运行上述四步之后您基本上只需等着模型运行完成即可,如果您想要可视化您的模型,可以运行:
tensorboard --logdir=’上面第4点提到的train_dir路径‘
然后在你的浏览器输入0.0.0.0:6006就能看到模型一些相关参数的可视化结果了。
训练完成后会生成三个.cpkt的文件,将这三个文件复制到tensorflow/models下,可利用这三个文件生成一个.pb文件,生成代码如下:
python object_detection/export_inference_graph.py \
--input_type image_tensor \
--pipeline_config_path ’你的.config文件路径’ \
--checkpoint_path model.ckpt-‘CHECKPOINT_NUMBER’ \
--inference_graph_path output_inference_graph.pb
这样你就可以利用.pb文件进行目标检测了,具体步骤请参考:https://github.com/tensorflow/models/blob/master/object_detection/object_detection_tutorial.ipynb6:参考
https://zhuanlan.zhihu.com/p/27469690
https://github.com/tensorflow/models/blob/master/object_detection/g3doc/installation.md
https://github.com/tensorflow/models/blob/master/object_detection/g3doc/running_pets.md
如您觉得本文对你有帮助,请酌情赞赏。同时本文如有不完善的地方欢迎指正!谢谢!
上一篇: vue移动端项目接入vconsole(移动端调试)
下一篇: vue移动端项目适配