欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

YOLO代码解析(2)

程序员文章站 2022-06-21 16:26:00
...

下面介绍数据集处理部分代码。训练数据的处理主要包括两部分,一是在模型训练前使用数据预处理脚本preprocess_pascal_voc.py 对下载得到的数据集处理得到一个训练样本信息列表文件。二是在模型训练时根据训练样本的信息列表文件将数据读入到队列中,供模型训练时读取batch数据使用。

其他相关的部分请见:
YOLO代码解析(1) 代码总览与使用
YOLO代码解析(2) 数据处理
YOLO代码解析(3) 模型和损失函数
YOLO代码解析(4) 训练和测试代码

1.preprocess_pascal_voc.py :数据预处理

pascal_voc数据集的标注数据保存在xml中,每张图片对应一个单独的xml文件,文件内容如:

<annotation>
	<folder>VOC2007</folder>
	<filename>000001.jpg</filename>
	<source>
		<database>The VOC2007 Database</database>
		<annotation>PASCAL VOC2007</annotation>
		<image>flickr</image>
		<flickrid>341012865</flickrid>
	</source>
	<owner>
		<flickrid>Fried Camels</flickrid>
		<name>Jinky the Fruit Bat</name>
	</owner>
	<size>
		<width>353</width>
		<height>500</height>
		<depth>3</depth>
	</size>
	<segmented>0</segmented>
	<object>
		<name>dog</name>
		<pose>Left</pose>
		<truncated>1</truncated>
		<difficult>0</difficult>
		<bndbox>
			<xmin>48</xmin>
			<ymin>240</ymin>
			<xmax>195</xmax>
			<ymax>371</ymax>
		</bndbox>
	</object>
	<object>
		<name>person</name>
		<pose>Left</pose>
		<truncated>1</truncated>
		<difficult>0</difficult>
		<bndbox>
			<xmin>8</xmin>
			<ymin>12</ymin>
			<xmax>352</xmax>
			<ymax>498</ymax>
		</bndbox>
	</object>
</annotation>

本脚本功能是解析xml文件,对每一张图片,得到一条形如[image_path xmin1 ymin1 xmax1 ymax1 class_id1 xmin2 ymin2 xmax2 ymax2 class_id2] (如:/home/jerry/tensorflow-yolo/data/VOCdevkit/VOC2007/JPEGImages/009960.jpg 26 140 318 318 13 92 46 312 267 14)的记录,并写入到文件中。

xml解析代码:

def parse_xml(xml_file):
  """
  解析 xml_文件
  输入:xml文件路径
  返回:图像路径和对应的label信息
  """
  # 使用ElementTree解析xml文件
  tree = ET.parse(xml_file)
  root = tree.getroot()
  image_path = ''
  labels = []

  for item in root:
    if item.tag == 'filename':
      image_path = os.path.join(DATA_PATH, 'VOC2007/JPEGImages', item.text)
    elif item.tag == 'object':
      obj_name = item[0].text
      # 将objetc的名称转换为ID
      obj_num = classes_num[obj_name]
      # 依次得到Bbox的左上和右下点的坐标
      xmin = int(item[4][0].text)
      ymin = int(item[4][1].text)
      xmax = int(item[4][2].text)
      ymax = int(item[4][3].text)
      labels.append([xmin, ymin, xmax, ymax, obj_num])

  # 返回图像的路径和label信息(Bbox坐标和类别ID)
  return image_path, labels

def convert_to_string(image_path, labels):
  """
     将图像的路径和lable信息转为string
  """
  out_string = ''
  out_string += image_path
  for label in labels:
    for i in label:
      out_string += ' ' + str(i)
  out_string += '\n'
  return out_string

def main():
  out_file = open(OUTPUT_PATH, 'w')

  # 获取所有的xml标注文件的路径
  xml_dir = DATA_PATH + '/VOC2007/Annotations/'
  xml_list = os.listdir(xml_dir)
  xml_list = [xml_dir + temp for temp in xml_list]

  # 解析xml文件,得到图片名称和lables,并转换得到图片的路径
  for xml in xml_list:
    try:
      image_path, labels = parse_xml(xml)
      # 将解析得到的结果转为string并写入文件
      record = convert_to_string(image_path, labels)
      out_file.write(record)
    except Exception:
      pass

  out_file.close()

2. text_dataset.py:准备训练用batch数据

主要将在训练过程中将训练数据读入到队列中,起到缓存的作用。

class TextDataSet(DataSet):
  """TextDataSet
     对数据预处理中得到的data list文件进行处理
     text file format:
     image_path xmin1 ymin1 xmax1 ymax1 class1 xmin2 ymin2 xmax2 ymax2 class2
  """

  def __init__(self, common_params, dataset_params):
    """
    Args:
      common_params: A dict
      dataset_params: A dict
    """
    #process params
    self.data_path = str(dataset_params['path'])
    self.width = int(common_params['image_size'])
    self.height = int(common_params['image_size'])
    self.batch_size = int(common_params['batch_size'])
    self.thread_num = int(dataset_params['thread_num'])
    self.max_objects = int(common_params['max_objects_per_image'])

    #定义两个队列,一个存放训练样本的list,另个存放训练样本的数据(image & label)
    self.record_queue = Queue(maxsize=10000)
    self.image_label_queue = Queue(maxsize=512)

    self.record_list = []  

    # 读取经过数据预处理得到的 pascal_voc.txt
    input_file = open(self.data_path, 'r')

    for line in input_file:
      line = line.strip()
      ss = line.split(' ')
      ss[1:] = [float(num) for num in ss[1:]]  # 将坐标和类别ID转为float
      self.record_list.append(ss)

    self.record_point = 0
    self.record_number = len(self.record_list)

    # 计算每个epoch的batch数目
    self.num_batch_per_epoch = int(self.record_number / self.batch_size)

    # 启动record_processor进程
    t_record_producer = Thread(target=self.record_producer)
    t_record_producer.daemon = True 
    t_record_producer.start()

    # 启动record_customer进程
    for i in range(self.thread_num):
      t = Thread(target=self.record_customer)
      t.daemon = True
      t.start() 

  def record_producer(self):
    """record_queue 的processor
    """
    while True:
      if self.record_point % self.record_number == 0:
        random.shuffle(self.record_list)
        self.record_point = 0
      # 从record_list读取一条训练样本信息到record_queue
      self.record_queue.put(self.record_list[self.record_point])
      self.record_point += 1

  def record_process(self, record):
    """record 处理过程
    Args: record 
    Returns:
      image: 3-D ndarray
      labels: 2-D list [self.max_objects, 5] (xcenter, ycenter, w, h, class_num)
      object_num:  total object number  int 
    """
    image = cv2.imread(record[0])  # record[0]是image 的路径

    # 对图像做色彩空间变换和尺寸缩放
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    h = image.shape[0]
    w = image.shape[1]

    width_rate = self.width * 1.0 / w 
    height_rate = self.height * 1.0 / h

    # 尺寸调整到 (448,448)
    image = cv2.resize(image, (self.height, self.width))

    labels = [[0, 0, 0, 0, 0]] * self.max_objects

    i = 1
    object_num = 0

    while i < len(record):
      xmin = record[i]
      ymin = record[i + 1]
      xmax = record[i + 2]
      ymax = record[i + 3]
      class_num = record[i + 4]
     
      # 由于图片缩放过,对label坐标做同样处理
      xcenter = (xmin + xmax) * 1.0 / 2 * width_rate
      ycenter = (ymin + ymax) * 1.0 / 2 * height_rate

      box_w = (xmax - xmin) * width_rate
      box_h = (ymax - ymin) * height_rate

      labels[object_num] = [xcenter, ycenter, box_w, box_h, class_num]
      object_num += 1
      i += 5
      if object_num >= self.max_objects:
        break
    return [image, labels, object_num]

  def record_customer(self):
    """record queue的使用者
       取record queue中数据,经过处理后,送到image_label_queue中
    """
    while True:
      item = self.record_queue.get()
      out = self.record_process(item)
      self.image_label_queue.put(out)

下一篇:YOLO代码解析(3) 模型和损失函数