欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

利用TensorFlow Object Detection API的预训练模型训练自己的数据

程序员文章站 2022-06-15 11:52:53
利用TensorFlow Object Detection API的预训练模型训练自己的数据文章目录利用TensorFlow Object Detection API的预训练模型训练自己的数据1.前言介绍2.前期准备2.1环境搭建2.2数据准备2.3模型准备3.训练过程3.1修改配置文件(config文件)3.2开始训练3.3保存模型3.4Tensorboard实时查看训练效果4.测试结果1.前言介绍pb文件为训练好的模型,可以直接拿来使用ckpt文件就是预训练模型,用来训练自己的数据2.前期...

利用TensorFlow Object Detection API的预训练模型训练自己的数据

1.前言介绍

  • pb文件为训练好的模型,可以直接拿来使用
  • ckpt文件就是预训练模型,用来训练自己的数据

2.前期准备

  • 准备一个保存收集图片的文件夹,包含Image和Annotations,分别用来保存图片和标注后的xml文件
  • 另外准备一个文件夹放训练有关的数据,里面包含三个下属文件data,export,model,分别用来存放训练可用的数据,生成的最终模型,训练产生的文件

2.1环境搭建

  • 配置Tensorflow环境,Windows或Ubuntu都可

2.2数据准备

  1. 根据自己训练需要收集所需要的图片

  2. 将所收集的图片进行排序后进行筛选然后再排序

    如果是处理自己采集的数据集,一定要先排序再筛选!!否则可能会遗漏掉一些本该筛选的图片在标注时增加自己的工作量

    我用的方法是按顺序对所有文件进行重命名

    import os
    i = 1
    for filename in os.listdir('D:/DataCollection/hand_data/Image/test/'):
    	newname = str(i) + '.jpg'
    	print(newname)
    	os.rename('D:/DataCollection/hand_data/Image/test/'+filename, 'D:/DataCollection/hand_data/Image/test/'+newname)
    	i += 1
    
  3. 对排序后的图片进行标注

    标注图片用的软件是labelImg,可以选择标注的图片位置(Image),以及生成的xml文件保存的位置即Annotations文件夹

    W是标注 D是下一张 A是上一张 空格保存

  4. 格式转换

    这里的生成的csv以及tfrecord文件都放在data文件夹下

    图片需转换成tensorflow可以识别的格式

    • 先由xml转为csv

      """
      将文件夹内所有XML文件的信息记录到CSV文件中
      """
      
      import os
      import glob
      import pandas as pd
      import xml.etree.ElementTree as ET
       
      os.chdir('E:/tensorflow/hand_data_new/hand_data/test')  #xml文件保存路径 使用时需改为自己的路径
      path = 'E:/tensorflow/hand_data_new/hand_data/test'
      
      
      def xml_to_csv(path):
          xml_list = []
          for xml_file in glob.glob(path + '/*.xml'):
              tree = ET.parse(xml_file)
              root = tree.getroot()
              print('test')
              for member in root.findall('object'):
                  value = (root.find('filename').text,
                           int(root.find('size')[0].text),
                           int(root.find('size')[1].text),
                           member[0].text,
                           int(member[4][0].text),
                           int(member[4][1].text),
                           int(member[4][2].text),
                           int(member[4][3].text)
                           )
                  xml_list.append(value)
          column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
          xml_df = pd.DataFrame(xml_list, columns=column_name)
          return xml_df
      
      
      def main():
          image_path = path
          xml_df = xml_to_csv(image_path)
          xml_df.to_csv('E:/tensorflow/hand_set/data/eval.csv', index=None)  #得到的csv文件保存路径
          print('Successfully converted xml to csv.')
      
      main()
      
    • 然后将csv文件转为tfrecord

      from __future__ import division
      from __future__ import print_function
      from __future__ import absolute_import
      
      import os
      import io
      import pandas as pd
      import tensorflow as tf
      
      from PIL import Image
      from object_detection.utils import dataset_util
      from collections import namedtuple, OrderedDict
      
      flags = tf.app.flags
      
      flags.DEFINE_string('csv_input', 'E:/tensorflow/hand_set/data/eval.csv', 'Path to the CSV input')#csv文件
      flags.DEFINE_string('output_path', 'E:/tensorflow/hand_set/data/eval.record', 'Path to output TFRecord')#TFRecord文件
      flags.DEFINE_string('image_dir', 'E:/tensorflow/hand_data_new/hand_data/Image/TEST', 'Path to images')#对应的图片位置
      
      FLAGS = flags.FLAGS
      
      # TO-DO replace this with label map
      #从1开始根据自己训练的类别数和标签来写
      def class_text_to_int(row_label):
          if row_label == 'DOWN':
              return 1
          elif row_label == 'FIVE':
              return 2
          else:
              None
      
      def split(df, group):
          data = namedtuple('data', ['filename', 'object'])
          gb = df.groupby(group)
          return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]
      
      
      def create_tf_example(group, path):
          with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
      
              encoded_jpg = fid.read()
      
          encoded_jpg_io = io.BytesIO(encoded_jpg)
          image = Image.open(encoded_jpg_io)
          width, height = image.size
      
          filename = group.filename.encode('utf8')
          image_format = b'jpg'
          xmins = []
          xmaxs = []
          ymins = []
          ymaxs = []
          classes_text = []
          classes = []
      
          for index, row in group.object.iterrows():
              xmins.append(row['xmin'] / width)
              xmaxs.append(row['xmax'] / width)
              ymins.append(row['ymin'] / height)
              ymaxs.append(row['ymax'] / height)
              classes_text.append(row['class'].encode('utf8'))
              classes.append(class_text_to_int(row['class']))
      
          tf_example = tf.train.Example(features=tf.train.Features(feature={
      
              'image/height': dataset_util.int64_feature(height),
      
              'image/width': dataset_util.int64_feature(width),
      
              'image/filename': dataset_util.bytes_feature(filename),
      
              'image/source_id': dataset_util.bytes_feature(filename),
      
              'image/encoded': dataset_util.bytes_feature(encoded_jpg),
      
              'image/format': dataset_util.bytes_feature(image_format),
      
              'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
      
              'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
      
              'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
      
              'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
      
              'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
      
              'image/object/class/label': dataset_util.int64_list_feature(classes),
      
          }))
      
          return tf_example
      
      
      def main(_):
      
          writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
      
          path = os.path.join(FLAGS.image_dir)
      
          examples = pd.read_csv(FLAGS.csv_input)
      
          grouped = split(examples, 'filename')
      
          for group in grouped:
      
              tf_example = create_tf_example(group, path)
      
              writer.write(tf_example.SerializeToString())
      
      
          writer.close()
      
          output_path = os.path.join(os.getcwd(), FLAGS.output_path)
      
          print('Successfully created the TFRecords: {}'.format(output_path))
      
      
      
      if __name__ == '__main__':
      
          tf.app.run()
  5. 训练数据准备完以后还需要准备一个pbtxt文件

    例如hand.pbtxt,放在data文件夹里

    内容如下,根据自己的类别数而定

    item {
      id: 1
      name: 'DOWN'
    }
    
    item{
      id: 2
      name: 'FIVE'
    }
    

2.3模型准备

下载Tensorflow模型

下载地址:https://github.com/tensorflow/models

下载protoc

下载地址:https://github.com/protocolbuffers/protobuf/releases

下载预训练模型

下载地址:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1_detection_zoo.md

  1. 利用protoc编译TensorFlow Object Detection API,转换为py文件
  2. 建立一个专门的文件夹来保存预训练模型,记住下载路径,之后会用到里面的ckpt文件
  3. 在下载的Tensorflow模型的文件下找到models\research\object_detection\samples\configs,在里面找到自己所用的预训练模型对应的config文件,拷贝一份放在最初建立的model文件夹下

3.训练过程

3.1修改配置文件(config文件)

​ 以下仅仅是列出最主要的修改,其他有关训练配置可根据实际情况再做调整

  • 改成自己训练的类别数量

    num_classes: 4
    
  • 根据自己的机器性能适当修改也可以不改

    batch_size: 24
    
  • 改成所用的预训练模型路径

    fine_tune_checkpoint: "E:/tensorflow/pretrained_models/ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/model.ckpt"
    
  • 训练所需的tfrecord文件路径,测试的则改为测试集的路径

      tf_record_input_reader {
        input_path: "E:/tensorflow/hand_set/data/train.record"
      }
    
  • 标签映射文件,即pbtxt文件位置,训练与测试共用一个

     label_map_path: "E:/tensorflow/hand_set/data/object_detection.pbtxt"
    

3.2开始训练

执行语句

python E:/tensorflow/models/research/object_detection/legacy/train.py --train_dir=E:/tensorflow/hand_set/model/model5 --pipeline_config_path=E:/tensorflow/hand_set/model/ssd_mobilenet_v2_quantized_300x300_coco.config --logtostderr
  • train.py在下载的Tensorflow模型文件夹下
  • train_dir是训练时的数据保存位置,放在最初建立的model文件夹下,因为我训练了多个模型,因此我在model文件夹下建立了多个子文件夹命名为model1等等,例如此例子我将该模型保存在model/model5中
  • pipeline是config文件的位置,我放在model文件下

3.3保存模型

最初建立的三个文件夹data是用来存放数据集的,而model是训练时的数据,主要包括各个检查点对应的能够生成模型的ckpt文件,以及训练过程中的信息,而export就是保存我们导出的模型

执行语句

python E:/tensorflow/models/research/object_detection/export_inference_graph.py --input_type image_tensor --pipeline_config_path E:/tensorflow/hand_set/model/ssd_mobilenet_v2_coco.config  --trained_checkpoint_prefix E:/tensorflow/hand_set/model/model4/model.ckpt-5997  --output_directory E:/tensorflow/hand_set/export/model4/
  • export_inference_graph.py在下载的Tensorflow模型文件夹下
  • pipeline_config_path位置同上
  • trained_checkpoint_prefix选择效果最好的检查点来生成模型,一般选择最新的
  • output_directory模型保存路径
  • 最后生成的pb文件就是我们可以用的模型

3.4Tensorboard实时查看训练效果

win+r,输入cmd,执行语句,输入刚刚训练保存模型的绝对路径,记得输入绝对路径不容易出错

tensorboard --logdir=E:\tensorflow\hand_set\model\model5

然后在浏览器里输入https://localhost:6006 就可以查看训练效果了

4.测试结果

  • 在Tensorflow模型文件夹下tensorflow\models\research\object_detection 找到object_detection_tutorial.ipynb文件,将代码复制出来

  • 将模型修改为我们自己训练的模型地址,即pb文件的地址

    # Path to frozen detection graph. This is the actual model that is used for the object detection.
    PATH_TO_FROZEN_GRAPH =
    'E:/tensorflow/hand_set/export/model4/frozen_inference_graph.pb'
  • pbtxt文件地址也改为我们自己的文件地址

    PATH_TO_LABELS = os.path.join('E:/tensorflow/hand_set/data', 'object_detection.pbtxt')
  • 设置测试图片路径

    PATH_TO_TEST_IMAGES_DIR = 'test_images'
    TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ]
  • 也可以改为摄像头实时测试

    with detection_graph.as_default():
        with tf.Session(graph=detection_graph) as sess:
            while  True:
                ret, image = capture.read()
                if ret is True:
                    image_np_expanded = np.expand_dims(image, axis=0)
                    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
                    boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
                    scores = detection_graph.get_tensor_by_name('detection_scores:0')
                    classes = detection_graph.get_tensor_by_name('detection_classes:0')
                    num_detections = detection_graph.get_tensor_by_name('num_detections:0')
    
                    (boxes,scores,classes,num_detections)=sess.run([boxes, scores, classes, num_detections],
                                                                    feed_dict={image_tensor: image_np_expanded})
                    vis_util.visualize_boxes_and_labels_on_image_array(
                        image,
                        np.squeeze(boxes),
                        np.squeeze(classes).astype(np.int32),
                        np.squeeze(scores),
                        category_index,
                        min_score_thresh=0.6, #置信度
                        use_normalized_coordinates=True,
                        line_thickness=4
                    )
                    c = cv.waitKey(5)
                    if c == 27:  # ESC
                        break
                    cv.imshow("Demo", image)
                else:
                    break
            cv.waitKey(0)
            cv.destoryAllWindows()
  • 运行代码

本文地址:https://blog.csdn.net/Lianhaiyan_zero/article/details/107638516