数据转为YOLO的txt数据格式
程序员文章站
2022-07-01 23:38:09
通过两个类来转换：import os; from xml.etree.ElementTree import dump; import json; import pprint; import sys; import argparse; import xml.etree.ElementTree as Et; from xml.etree.ElementTree import Element, ElementTree; import cv2; class VOC: """ Handler Class for VO...
通过 VOC 和 YOLO 两个类来完成转换：
import os
from xml.etree.ElementTree import dump
import json
import pprint
import sys
import argparse
import xml.etree.ElementTree as Et
from xml.etree.ElementTree import Element, ElementTree
import cv2
class VOC:
    """Handler class for the PASCAL VOC XML annotation format.

    ``parse`` reads a directory of VOC ``.xml`` files into an intermediate
    annotation dict, ``generate`` builds VOC ``<annotation>`` elements from such
    a dict, and ``save`` writes those elements to disk.  The dict shape is::

        {key: {"img_path": str,                      # only produced by parse
               "size": {"width", "height", "depth"},
               "objects": {"num_obj": n,
                           "0": {"name": str,
                                 "bndbox": {"xmin", "ymin", "xmax", "ymax"}},
                           ...}}}

    Public methods return ``(success, payload)`` instead of raising; on failure
    ``payload`` is an error-message string.
    """

    @staticmethod
    def _error_message(e):
        """Format ``e`` with its type, source file name and line number."""
        tb = e.__traceback__
        fname = os.path.split(tb.tb_frame.f_code.co_filename)[1]
        return "ERROR : {}, moreInfo : {}\t{}\t{}".format(e, type(e), fname, tb.tb_lineno)

    @staticmethod
    def _add_text(parent, tag, value):
        """Append ``<tag>value</tag>`` to ``parent`` and return the new element.

        ``value`` is stringified because ElementTree can only serialise ``str``
        text; parsed coordinates are floats and some sizes are ints.
        """
        child = Et.SubElement(parent, tag)
        child.text = None if value is None else str(value)
        return child

    def xml_indent(self, elem, level=0):
        """Pretty-print ``elem`` in place by filling ``text``/``tail`` with newlines and tabs."""
        i = "\n" + level * "\t"
        if len(elem):
            if not elem.text or not elem.text.strip():
                elem.text = i + "\t"
            if not elem.tail or not elem.tail.strip():
                elem.tail = i
            for child in elem:
                self.xml_indent(child, level + 1)
            # ``child`` is now the last child: de-indent its tail so the parent's
            # closing tag lines up with its opening tag.
            if not child.tail or not child.tail.strip():
                child.tail = i
        else:
            if level and (not elem.tail or not elem.tail.strip()):
                elem.tail = i

    def generate(self, data):
        """Build one VOC ``<annotation>`` element per entry of ``data``.

        Args:
            data: intermediate annotation dict (see the class docstring).

        Returns:
            ``(True, {stem: Element})`` where ``stem`` is the key up to its first
            ``.``; ``(False, "number of Object less than 1")`` as soon as an entry
            has no objects; ``(False, message)`` on any other error.
        """
        try:
            xml_list = {}
            for key, element in data.items():
                xml_annotation = Element("annotation")

                size = element["size"]
                xml_size = Et.SubElement(xml_annotation, "size")
                for tag in ("width", "height", "depth"):
                    self._add_text(xml_size, tag, size[tag])
                self._add_text(xml_annotation, "segmented", "0")

                objects = element["objects"]
                num_obj = int(objects["num_obj"])
                if num_obj < 1:
                    return False, "number of Object less than 1"
                for i in range(num_obj):
                    obj = objects[str(i)]
                    xml_object = Et.SubElement(xml_annotation, "object")
                    self._add_text(xml_object, "name", obj["name"])
                    self._add_text(xml_object, "pose", "Unspecified")
                    self._add_text(xml_object, "truncated", "0")
                    self._add_text(xml_object, "difficult", "0")
                    xml_bndbox = Et.SubElement(xml_object, "bndbox")
                    for tag in ("xmin", "ymin", "xmax", "ymax"):
                        self._add_text(xml_bndbox, tag, obj["bndbox"][tag])

                self.xml_indent(xml_annotation)
                xml_list[key.split(".")[0]] = xml_annotation
            return True, xml_list
        except Exception as e:
            return False, VOC._error_message(e)

    @staticmethod
    def save(xml_list, path):
        """Write each element of ``xml_list`` to ``<path>/<key>.xml``.

        Returns:
            ``(True, None)`` on success, ``(False, message)`` on error.
        """
        try:
            path = os.path.abspath(path)
            for key, xml in xml_list.items():
                ElementTree(xml).write(os.path.join(path, key + ".xml"))
            return True, None
        except Exception as e:
            return False, VOC._error_message(e)

    @staticmethod
    def parse(path, img_path):
        """Read every ``.xml`` file directly inside ``path`` into the intermediate dict.

        Non-``.xml`` files are ignored.  Files with no ``<object>`` are reported
        and skipped.  A missing ``<img_path>/<stem>.jpg`` is reported but the
        entry is still kept.

        Args:
            path: directory of VOC annotation files (not searched recursively).
            img_path: directory where the matching ``.jpg`` images are expected.

        Returns:
            ``(True, {stem: annotation})`` on success, ``(False, message)`` on error.
        """
        try:
            dir_path, _, filenames = next(os.walk(os.path.abspath(path)))
            data = {}
            for filename in filenames:
                stem, ext = os.path.splitext(filename)
                if ext.lower() != ".xml":
                    continue
                xml_file = os.path.join(dir_path, filename)
                # Parsing from the path lets ElementTree open *and close* the file
                # and honour the XML declaration's encoding.
                root = Et.parse(xml_file).getroot()

                xml_size = root.find("size")
                size = {tag: xml_size.find(tag).text for tag in ("width", "height", "depth")}

                objects = root.findall("object")
                if not objects:
                    print('xml has no obj: {}'.format(xml_file))
                    continue

                obj = {"num_obj": len(objects)}
                for index, _object in enumerate(objects):
                    xml_bndbox = _object.find("bndbox")
                    obj[str(index)] = {
                        "name": _object.find("name").text,
                        "bndbox": {
                            tag: float(xml_bndbox.find(tag).text)
                            for tag in ("xmin", "ymin", "xmax", "ymax")
                        },
                    }

                image_file = os.path.join(img_path, stem + ".jpg")
                if not os.path.exists(image_file):
                    print('img not exists : {}'.format(image_file))
                data[stem] = {
                    "img_path": image_file,
                    "size": size,
                    "objects": obj,
                }
            return True, data
        except Exception as e:
            return False, VOC._error_message(e)
class YOLO:
    """Handler class for the YOLO txt annotation format.

    A YOLO label file holds one ``<class_id> <cx> <cy> <w> <h>`` line per
    object, with coordinates normalised by the image width/height.  Class ids
    are positions in ``self.cls_list``, which is loaded from a class-list file
    (one name per line).  Public methods return ``(success, payload)`` instead
    of raising; on failure ``payload`` is an error-message string.
    """

    def __init__(self, cls_list_path):
        """Load the class names (one per line) from ``cls_list_path``."""
        with open(cls_list_path, 'r') as file:
            self.cls_list = file.read().splitlines()

    @staticmethod
    def _error_message(e):
        """Format ``e`` with its type, source file name and line number."""
        tb = e.__traceback__
        fname = os.path.split(tb.tb_frame.f_code.co_filename)[1]
        return "ERROR : {}, moreInfo : {}\t{}\t{}".format(e, type(e), fname, tb.tb_lineno)

    def coordinateCvt2YOLO(self, size, box):
        """Convert a pixel bounding box to normalised YOLO ``(cx, cy, w, h)``.

        Args:
            size: ``(img_width, img_height)`` in pixels.
            box: ``(xmin, xmax, ymin, ymax)`` in pixels -- note the x, x, y, y order.

        Returns:
            Tuple of four floats, each rounded to 10 decimal places.
        """
        dw = 1. / size[0]
        dh = 1. / size[1]
        x = (box[0] + box[1]) / 2.0  # centre x = (xmin + xmax) / 2
        y = (box[2] + box[3]) / 2.0  # centre y = (ymin + ymax) / 2
        w = box[1] - box[0]          # xmax - xmin
        h = box[3] - box[2]          # ymax - ymin
        x = x * dw
        w = w * dw
        y = y * dh
        h = h * dh
        return (round(x, 10), round(y, 10), round(w, 10), round(h, 10))

    def parse(self, label_path, img_path, img_type=".png"):
        """Read every file directly inside ``label_path`` as a YOLO label file.

        The image size is taken from ``<img_path>/<stem><img_type>``.  Object
        ``name`` values are the raw class-id strings from the label file, not
        class names.  Blank lines are ignored.

        Returns:
            ``(True, {stem: {"size": ..., "objects": ...}})`` on success;
            ``(False, message)`` if any image or label cannot be read.
        """
        try:
            dir_path, _, filenames = next(os.walk(os.path.abspath(label_path)))
            data = {}
            for filename in filenames:
                stem = os.path.splitext(filename)[0]
                image_file = os.path.join(img_path, stem + img_type)
                img = cv2.imread(image_file)
                if img is None:  # cv2.imread signals failure by returning None
                    raise IOError("cannot read image: {}".format(image_file))
                img_height, img_width = img.shape[:2]
                size = {
                    "width": str(img_width),
                    "height": str(img_height),
                    "depth": "3",
                }

                obj = {}
                obj_cnt = 0
                with open(os.path.join(dir_path, filename), "r") as txt:
                    for line in txt:
                        elements = line.split()
                        if not elements:
                            continue
                        name_id = elements[0]
                        # cx * 2W == xmin + xmax, cy * 2H == ymin + ymax
                        xminAddxmax = float(elements[1]) * (2.0 * float(img_width))
                        yminAddymax = float(elements[2]) * (2.0 * float(img_height))
                        w = float(elements[3]) * float(img_width)
                        h = float(elements[4]) * float(img_height)
                        xmin = (xminAddxmax - w) / 2
                        ymin = (yminAddymax - h) / 2
                        obj[str(obj_cnt)] = {
                            "name": name_id,
                            "bndbox": {
                                "xmin": float(xmin),
                                "ymin": float(ymin),
                                "xmax": float(xmin + w),
                                "ymax": float(ymin + h),
                            },
                        }
                        obj_cnt += 1
                obj["num_obj"] = obj_cnt
                data[stem] = {
                    "size": size,
                    "objects": obj,
                }
            return True, data
        except Exception as e:
            return False, YOLO._error_message(e)

    @staticmethod
    def _remap_name(key, entry, name, box):
        """Apply the dataset-specific fallbacks for a label missing from ``cls_list``.

        Args:
            key: image key, used only in the debug print.
            entry: the image's annotation dict; its optional ``img_path`` is used
                for the debug display.
            name: the unknown label.
            box: ``(xmin, ymin, xmax, ymax)`` in pixels.

        Returns:
            The replacement class name, or None when the object must be dropped.
        """
        if 'limit' in name:
            return 'limit'
        if 'van' in name:
            return 'car'
        if 'rider' in name:
            return 'person'
        if 'wan' in name:
            # Debug aid: show the offending box for 2 s, then drop the object.
            print(key, '-----------------------------------------------------')
            img = cv2.imread(entry['img_path']) if entry.get('img_path') else None
            if img is not None:
                xmin, ymin, xmax, ymax = box
                cv2.rectangle(img, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (255, 0, 255), 3)
                cv2.imshow('test', img)
                cv2.waitKey(2000)
            return None
        # Traffic lights (including the misspelt 'trafiic' label) are dropped silently.
        if 'traffic light' not in name and 'trafiic light' not in name:
            print(name, 'not in cls list!------------')
        return None

    def generate(self, data):
        """Convert the intermediate annotation dict to YOLO label text.

        Labels missing from ``self.cls_list`` go through ``_remap_name``
        (``*limit*`` -> ``limit``, ``*van*`` -> ``car``, ``*rider*`` -> ``person``,
        others dropped).  A remapped name that is still not in ``cls_list`` is
        reported and dropped instead of aborting the whole batch.  ``data`` is
        not modified.

        Returns:
            ``(True, {key: "cls cx cy w h\\n" ...})`` on success, ``(False, message)`` on error.
        """
        try:
            result = {}
            for key, entry in data.items():
                # float() also accepts sizes written as "1920.0".
                img_width = float(entry["size"]["width"])
                img_height = float(entry["size"]["height"])
                objects = entry["objects"]
                lines = []
                for idx in range(int(objects["num_obj"])):
                    obj = objects[str(idx)]
                    bndbox = obj["bndbox"]
                    xmin, ymin, xmax, ymax = (float(bndbox[k]) for k in ("xmin", "ymin", "xmax", "ymax"))

                    name = obj["name"]
                    if name not in self.cls_list:
                        name = self._remap_name(key, entry, name, (xmin, ymin, xmax, ymax))
                        if name is None:
                            continue
                        if name not in self.cls_list:
                            print(name, 'not in cls list!------------')
                            continue

                    bb = self.coordinateCvt2YOLO((img_width, img_height), (xmin, xmax, ymin, ymax))
                    cls_id = self.cls_list.index(name)
                    lines.append("{} {}\n".format(cls_id, " ".join(str(v) for v in bb)))
                result[key] = "".join(lines)
            return True, result
        except Exception as e:
            return False, YOLO._error_message(e)

    def save(self, data, save_path, img_path, img_type, manipast_path):
        """Write ``<save_path>/<key>.txt`` for each entry, plus a manifest.

        The manifest ``<manipast_path>/manifast.txt`` lists the absolute path of
        ``<img_path>/<key><img_type>`` for each entry, one per line.

        Returns:
            ``(True, None)`` on success, ``(False, message)`` on error.
        """
        try:
            manifest_file = os.path.abspath(os.path.join(manipast_path, "manifast.txt"))
            with open(manifest_file, "w") as manipast_file:
                for key, contents in data.items():
                    manipast_file.write(os.path.abspath(os.path.join(img_path, key + img_type)) + "\n")
                    label_file = os.path.abspath(os.path.join(save_path, key + ".txt"))
                    with open(label_file, "w") as output_txt_file:
                        output_txt_file.write(contents)
            return True, None
        except Exception as e:
            return False, YOLO._error_message(e)
查看标注文件（YOLO txt）在图片上的显示效果：
def view_yolo_txt(txt_dir, img_dir, class_names=None):
    """Visualise YOLO txt labels by drawing them on the matching ``.jpg`` images.

    For every ``*.txt`` file in ``txt_dir``, the image with the same stem is
    loaded from ``img_dir``.  Each ``<cls> <cx> <cy> <w> <h>`` line (coordinates
    normalised by image width/height) is drawn as a green box with its class
    label.  The image is then shrunk to half size and shown in an OpenCV window
    until a key is pressed.  Label files without boxes, and boxes whose image
    cannot be read, are skipped.

    Args:
        txt_dir: directory of YOLO label files.
        img_dir: directory holding the matching ``.jpg`` images.
        class_names: optional sequence mapping class id -> display name.  When
            omitted, a module-level ``class_idx_list`` is used if one exists;
            otherwise the numeric class id is drawn.
    """
    names = class_names if class_names is not None else globals().get("class_idx_list")

    def _label(cls_id):
        # Fall back to the bare id when there is no name table or the id is out of range.
        try:
            return str(names[cls_id])
        except (TypeError, IndexError, KeyError):
            return str(cls_id)

    for txtf in sorted(os.listdir(txt_dir)):
        stem, ext = os.path.splitext(txtf)
        if ext != '.txt':
            continue
        with open(os.path.join(txt_dir, txtf), 'r') as f:
            # Ignore blank lines (e.g. trailing newlines) instead of failing to unpack them.
            boxes = [[float(v) for v in line.split()] for line in f if line.strip()]
        if not boxes:
            continue

        img_file = os.path.join(img_dir, stem + '.jpg')
        img = cv2.imread(img_file)
        if img is None:  # cv2.imread returns None instead of raising on a missing/unreadable file
            print('img not exists : {}'.format(img_file))
            continue
        h, w = img.shape[:2]

        for cls, cx, cy, cw, ch in (box[:5] for box in boxes):
            xmin = int(w * (cx - 0.5 * cw))
            ymin = int(h * (cy - 0.5 * ch))
            xmax = int(w * (cx + 0.5 * cw))
            ymax = int(h * (cy + 0.5 * ch))
            # Class ids are parsed as floats; sequence indexing needs an int.
            cv2.putText(img, _label(int(cls)), (xmin, ymin - 2), cv2.FONT_ITALIC, 1, (255, 255, 0), 1)
            cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2)

        img = cv2.resize(img, None, fx=0.5, fy=0.5)
        print(img_file)
        cv2.imshow('tt', img)
        cv2.waitKey(0)
本文地址:https://blog.csdn.net/weixin_44347020/article/details/107669064
上一篇: AJAX简单应用实例-弹出层
下一篇: asp两组字符串数据比较合并相同数据
推荐阅读
-
android 解析json数据格式的方法
-
详解Objective-C中将JSON数据转为模型的方法
-
Excel2010自定义数据格式让数据以不同的形态进行显示
-
详解javascript中对数据格式化的思考
-
sql2005 数据库转为sql2000数据库的方法(数据导出导入)
-
Python导入txt数据到mysql的方法
-
实现纯真IP txt转mdb数据库的方法
-
python3 json数据格式的转换(dumps/loads的使用、dict to str/str to dict、json字符串/字典的相互转换)
-
在Python的struct模块中进行数据格式转换的方法
-
python读取csv和txt数据转换成向量的实例