欢迎您访问程序员文章站!本站旨在为大家提供并分享程序员计算机编程知识。
您现在的位置是: 首页  >  IT编程

数据转为YOLO的txt数据格式

程序员文章站 2022-07-01 23:38:09
通过两个类来转换import osfrom xml.etree.ElementTree import dumpimport jsonimport pprintimport sysimport argparseimport xml.etree.ElementTree as Etfrom xml.etree.ElementTree import Element, ElementTreeimport cv2class VOC: """ Handler Class for VO...

通过两个类(VOC 和 YOLO)来完成转换:

import os
from xml.etree.ElementTree import dump
import json
import pprint
import sys
import argparse
import xml.etree.ElementTree as Et
from xml.etree.ElementTree import Element, ElementTree
import cv2

class VOC:
    """
    Handler class for the PASCAL VOC XML annotation format.

    Every public method reports its outcome as an ``(ok, payload)`` tuple
    instead of raising: ``ok`` is a bool and ``payload`` is the result on
    success or an error-message string on failure.
    """

    @staticmethod
    def _format_error(exc):
        """Return the failure message (exception, its type, file name, line number)."""
        exc_tb = exc.__traceback__
        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
        return "ERROR : {}, moreInfo : {}\t{}\t{}".format(exc, type(exc), fname, exc_tb.tb_lineno)

    @staticmethod
    def _add_text_child(parent, tag, value):
        """Append ``<tag>value</tag>`` to *parent*, converting *value* to ``str``.

        ElementTree can only serialize string text, while the parsed data
        carries floats (bounding boxes) and possibly ints (sizes).
        """
        child = Element(tag)
        child.text = str(value)
        parent.append(child)
        return child

    def xml_indent(self, elem, level=0):
        """Indent *elem* and its descendants in place with tabs and newlines.

        :param elem: element whose ``text``/``tail`` whitespace is rewritten.
        :param level: nesting depth of *elem*; 0 for the root.
        """
        i = "\n" + level * "\t"
        if len(elem):
            if not elem.text or not elem.text.strip():
                elem.text = i + "\t"
            if not elem.tail or not elem.tail.strip():
                elem.tail = i
            child = None
            for child in elem:
                self.xml_indent(child, level + 1)
            # The last child's tail is dedented so the parent's closing tag
            # lines up with its opening tag.
            if not child.tail or not child.tail.strip():
                child.tail = i
        else:
            if level and (not elem.tail or not elem.tail.strip()):
                elem.tail = i

    def generate(self, data):
        """Build one ``<annotation>`` element per entry of *data*.

        :param data: mapping ``key -> {"size": {"width", "height", "depth"},
            "objects": {"num_obj": n, "0": {"name", "bndbox": {...}}, ...}}``.
        :return: ``(True, {key_without_extension: Element})`` on success,
            ``(False, message)`` if an entry has no objects or an error occurs.
        """
        try:
            xml_list = {}

            for key in data:
                element = data[key]

                xml_annotation = Element("annotation")

                xml_size = Element("size")
                for tag in ("width", "height", "depth"):
                    self._add_text_child(xml_size, tag, element["size"][tag])
                xml_annotation.append(xml_size)

                self._add_text_child(xml_annotation, "segmented", "0")

                num_obj = int(element["objects"]["num_obj"])
                if num_obj < 1:
                    return False, "number of Object less than 1"

                for i in range(num_obj):
                    obj = element["objects"][str(i)]

                    xml_object = Element("object")
                    self._add_text_child(xml_object, "name", obj["name"])
                    self._add_text_child(xml_object, "pose", "Unspecified")
                    self._add_text_child(xml_object, "truncated", "0")
                    self._add_text_child(xml_object, "difficult", "0")

                    xml_bndbox = Element("bndbox")
                    for tag in ("xmin", "ymin", "xmax", "ymax"):
                        self._add_text_child(xml_bndbox, tag, obj["bndbox"][tag])
                    xml_object.append(xml_bndbox)

                    xml_annotation.append(xml_object)

                self.xml_indent(xml_annotation)

                # Keys may carry a file extension; only the part before the
                # first dot is used as the output name.
                xml_list[key.split(".")[0]] = xml_annotation

            return True, xml_list

        except Exception as e:
            return False, self._format_error(e)

    @staticmethod
    def save(xml_list, path):
        """Write every element of *xml_list* to ``<path>/<key>.xml``.

        :param xml_list: mapping ``name -> Element`` as returned by ``generate``.
        :param path: output directory (must already exist).
        :return: ``(True, None)`` on success, ``(False, message)`` on failure.
        """
        try:
            path = os.path.abspath(path)

            for key in xml_list:
                filepath = os.path.join(path, "".join([key, ".xml"]))
                ElementTree(xml_list[key]).write(filepath)

            return True, None

        except Exception as e:
            return False, VOC._format_error(e)

    @staticmethod
    def parse(path, img_path):
        """Read all ``.xml`` files directly inside *path*.

        :param path: directory holding the VOC XML files (not searched recursively).
        :param img_path: directory expected to hold the matching ``.jpg`` images;
            a missing image only produces a printed warning.
        :return: ``(True, {stem: {"img_path", "size", "objects"}})`` on success,
            ``(False, message)`` if a file has no ``<object>`` or an error occurs.
        """
        try:
            (dir_path, dir_names, filenames) = next(os.walk(os.path.abspath(path)))

            data = {}

            for filename in filenames:
                stem, ext = os.path.splitext(filename)
                if ext.lower() != ".xml":
                    # Ignore stray non-annotation files instead of failing on them.
                    continue

                # Passing the path lets ElementTree open/close the file itself
                # and honour the encoding declared in the XML prolog.
                root = Et.parse(os.path.join(dir_path, filename)).getroot()

                xml_size = root.find("size")
                size = {
                    "width": xml_size.find("width").text,
                    "height": xml_size.find("height").text,
                    "depth": xml_size.find("depth").text
                }

                objects = root.findall("object")
                if len(objects) == 0:
                    return False, "number object zero"

                obj = {
                    "num_obj": len(objects)
                }

                for obj_index, _object in enumerate(objects):
                    xml_bndbox = _object.find("bndbox")
                    obj[str(obj_index)] = {
                        "name": _object.find("name").text,
                        "bndbox": {
                            "xmin": float(xml_bndbox.find("xmin").text),
                            "ymin": float(xml_bndbox.find("ymin").text),
                            "xmax": float(xml_bndbox.find("xmax").text),
                            "ymax": float(xml_bndbox.find("ymax").text)
                        }
                    }

                image_file = os.path.join(img_path, stem + '.jpg')
                if not os.path.exists(image_file):
                    print('img not exists : {}'.format(image_file))

                data[stem] = {
                    "img_path": image_file,
                    "size": size,
                    "objects": obj
                }

            return True, data

        except Exception as e:
            return False, VOC._format_error(e)
class YOLO:
    """
    Handler class for the YOLO txt annotation format.

    Each label line is ``<class_id> <cx> <cy> <w> <h>`` with the box centre and
    size normalised by the image width/height. ``parse``, ``generate`` and
    ``save`` return an ``(ok, payload)`` tuple instead of raising.
    """

    def __init__(self, cls_list_path):
        """Load the class names, one per line, from *cls_list_path*.

        The line index of a name is the class id written to the label files.
        """
        with open(cls_list_path, 'r') as file:
            self.cls_list = file.read().splitlines()

    @staticmethod
    def _format_error(exc):
        """Return the failure message (exception, its type, file name, line number)."""
        exc_tb = exc.__traceback__
        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
        return "ERROR : {}, moreInfo : {}\t{}\t{}".format(exc, type(exc), fname, exc_tb.tb_lineno)

    def coordinateCvt2YOLO(self, size, box):
        """Convert a pixel box to normalised YOLO ``(cx, cy, w, h)``.

        :param size: ``(image_width, image_height)`` in pixels.
        :param box: ``(xmin, xmax, ymin, ymax)`` in pixels (note the order).
        :return: tuple ``(cx, cy, w, h)``, each rounded to 10 decimals.
        """
        dw = 1. / size[0]
        dh = 1. / size[1]

        # Box centre: (xmin + xmax) / 2 and (ymin + ymax) / 2.
        x = (box[0] + box[1]) / 2.0
        y = (box[2] + box[3]) / 2.0

        # Box extent: xmax - xmin and ymax - ymin.
        w = box[1] - box[0]
        h = box[3] - box[2]

        x = x * dw
        w = w * dw
        y = y * dh
        h = h * dh
        return (round(x, 10), round(y, 10), round(w, 10), round(h, 10))

    def parse(self, label_path, img_path, img_type=".png"):
        """Read every label file directly inside *label_path*.

        The matching image ``<img_path>/<stem><img_type>`` is loaded with
        ``cv2.imread`` to obtain the pixel size needed to de-normalise boxes.

        :param label_path: directory holding the YOLO txt files.
        :param img_path: directory holding the images.
        :param img_type: image file extension, including the dot.
        :return: ``(True, {stem: {"size", "objects"}})`` on success,
            ``(False, message)`` on any error (e.g. unreadable image).
        """
        try:
            (dir_path, dir_names, filenames) = next(os.walk(os.path.abspath(label_path)))

            data = {}

            for filename in filenames:
                stem = filename.split(".")[0]

                image_file = os.path.join(img_path, "".join([stem, img_type]))
                img = cv2.imread(image_file)
                if img is None:
                    # cv2.imread signals failure by returning None, not by raising.
                    raise IOError("cannot read image: {}".format(image_file))

                img_width = str(img.shape[1])
                img_height = str(img.shape[0])

                size = {
                    "width": img_width,
                    "height": img_height,
                    # Kept as text like width/height so it can be written to XML.
                    "depth": "3"
                }

                obj = {}
                obj_cnt = 0

                with open(os.path.join(dir_path, filename), "r") as txt:
                    for line in txt:
                        elements = line.split()
                        if not elements:
                            # Tolerate blank lines (e.g. a trailing newline).
                            continue

                        # NOTE(review): this is the numeric class id as text,
                        # not a class name from cls_list - confirm with callers.
                        name_id = elements[0]

                        xminAddxmax = float(elements[1]) * (2.0 * float(img_width))
                        yminAddymax = float(elements[2]) * (2.0 * float(img_height))

                        w = float(elements[3]) * float(img_width)
                        h = float(elements[4]) * float(img_height)

                        xmin = (xminAddxmax - w) / 2
                        ymin = (yminAddymax - h) / 2

                        obj[str(obj_cnt)] = {
                            "name": name_id,
                            "bndbox": {
                                "xmin": float(xmin),
                                "ymin": float(ymin),
                                "xmax": float(xmin + w),
                                "ymax": float(ymin + h)
                            }
                        }
                        obj_cnt += 1

                obj["num_obj"] = obj_cnt

                data[stem] = {
                    "size": size,
                    "objects": obj
                }

            return True, data

        except Exception as e:
            return False, self._format_error(e)

    def generate(self, data):
        """Render the label-file text for every entry of *data*.

        Names missing from the class list are remapped by substring
        ('limit' -> 'limit', 'van' -> 'car', 'rider' -> 'person'); the remapped
        name is written back into *data*. Objects whose name still cannot be
        resolved are skipped. Names containing 'wan' are additionally shown on
        screen for two seconds before being skipped.

        :param data: mapping in the layout produced by ``VOC.parse``.
        :return: ``(True, {key: label_text})`` or ``(False, message)``.
        """
        try:
            result = {}

            for key in data:
                img_width = int(data[key]["size"]["width"])
                img_height = int(data[key]["size"]["height"])

                lines = []

                for idx in range(int(data[key]["objects"]["num_obj"])):
                    obj = data[key]["objects"][str(idx)]
                    box = obj["bndbox"]
                    xmin, ymin = box["xmin"], box["ymin"]
                    xmax, ymax = box["xmax"], box["ymax"]

                    b = (float(xmin), float(xmax), float(ymin), float(ymax))
                    bb = self.coordinateCvt2YOLO((img_width, img_height), b)

                    name = obj["name"]
                    if name not in self.cls_list:
                        if 'limit' in name:
                            name = 'limit'
                        elif 'van' in name:
                            name = 'car'
                        elif 'rider' in name:
                            name = 'person'
                        elif 'wan' in name:
                            print(key, '-----------------------------------------------------')
                            image_file = data[key].get('img_path')
                            img = cv2.imread(image_file) if image_file else None
                            if img is not None:
                                cv2.rectangle(img, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (255, 0, 255), 3)
                                cv2.imshow('test', img)
                                cv2.waitKey(2000)
                            continue
                        else:
                            # Both spellings are dropped silently; 'trafiic' is
                            # a typo that is matched deliberately.
                            if 'traffic light' in name or 'trafiic light' in name:
                                continue
                            print(name, 'not in cls list!------------')
                            continue

                        obj["name"] = name
                        if name not in self.cls_list:
                            # The alias itself may be absent from the class
                            # list; skip instead of failing the whole run.
                            print(name, 'not in cls list!------------')
                            continue

                    cls_id = self.cls_list.index(name)
                    lines.append(" ".join([str(cls_id)] + [str(e) for e in bb]))

                result[key] = "".join(line + "\n" for line in lines)

            return True, result

        except Exception as e:
            return False, self._format_error(e)

    def save(self, data, save_path, img_path, img_type, manipast_path):
        """Write one ``<key>.txt`` per entry plus a manifest of image paths.

        :param data: mapping ``key -> label_text`` as returned by ``generate``.
        :param save_path: directory receiving the label files.
        :param img_path: image directory used to build the manifest entries.
        :param img_type: image file extension, including the dot.
        :param manipast_path: directory receiving ``manifast.txt``.
        :return: ``(True, None)`` on success, ``(False, message)`` on failure.
        """
        try:
            manifest = os.path.abspath(os.path.join(manipast_path, "manifast.txt"))
            with open(manifest, "w") as manipast_file:

                for key in data:
                    image_file = os.path.abspath(os.path.join(img_path, "".join([key, img_type])))
                    # The line terminator is appended after normalising the
                    # path so it never becomes part of the file name.
                    manipast_file.write(image_file + "\n")

                    label_file = os.path.abspath(os.path.join(save_path, "".join([key, ".txt"])))
                    with open(label_file, "w") as output_txt_file:
                        output_txt_file.write(data[key])

            return True, None

        except Exception as e:
            return False, self._format_error(e)

查看标注文件在图片上的显示效果:

def view_yolo_txt(txt_dir, img_dir, class_names=None):
    """Draw the YOLO boxes of every label file onto its image and display it.

    For each ``.txt`` file in *txt_dir* the image with the same stem and a
    ``.jpg`` extension is read from *img_dir*. Files with at least one box are
    shown at half size, and the function waits for a key press per image.

    :param txt_dir: directory holding YOLO label files.
    :param img_dir: directory holding the ``.jpg`` images.
    :param class_names: optional sequence mapping class id -> label text. When
        omitted, a module-level ``class_idx_list`` is used if one exists;
        otherwise the numeric class id is drawn.
    """
    if class_names is None:
        # Backward compatibility with scripts that define the global list.
        class_names = globals().get("class_idx_list")

    for txtf in sorted(os.listdir(txt_dir)):
        if not txtf.endswith('.txt'):
            continue

        with open(os.path.join(txt_dir, txtf), 'r') as f:
            rows = [[float(v) for v in line.split()]
                    for line in f.read().splitlines() if line.strip()]

        img_file = os.path.join(img_dir, os.path.splitext(txtf)[0] + '.jpg')
        img = cv2.imread(img_file)
        if img is None:
            # cv2.imread returns None for a missing or unreadable image.
            print('img not exists : {}'.format(img_file))
            continue
        if not rows:
            continue

        h, w = img.shape[:2]
        for cls, cx, cy, cw, ch in rows:
            # Values are parsed as floats; the class id must be an int to
            # index the name list.
            cls_id = int(cls)
            if class_names is not None and 0 <= cls_id < len(class_names):
                label = str(class_names[cls_id])
            else:
                label = str(cls_id)

            xmin = int(w * (cx - 0.5 * cw))
            ymin = int(h * (cy - 0.5 * ch))
            xmax = int(w * (cx + 0.5 * cw))
            ymax = int(h * (cy + 0.5 * ch))
            cv2.putText(img, label, (xmin, ymin - 2), cv2.FONT_ITALIC, 1, (255, 255, 0), 1)
            cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2)

        img = cv2.resize(img, None, fx=0.5, fy=0.5)
        print(img_file)
        cv2.imshow('tt', img)
        cv2.waitKey(0)

本文地址:https://blog.csdn.net/weixin_44347020/article/details/107669064

相关标签: 深度学习