欢迎您访问程序员文章站!本站旨在为大家分享程序员计算机编程知识。
您现在的位置是: 首页

目标检测之数据集增强(旋转)

程序员文章站 2024-03-19 23:02:28
...

项目背景:

最近在做国外身份证的项目,样本的拍摄风格多样。需求:检测出身份证的位置并判断其旋转角度,再根据角度和位置对身份证进行校正。

因此,可以将每张标注样本分别(逆时针)旋转三个角度(90°、180°、270°)来做数据增广。相应地,xml文件中的坐标和label也要与旋转后的图像样本一一对应。

目标检测之数据集增强(旋转)

xml文件如下。角度为0的label为TYPE:B_2D_ANGLE0;相应地,若旋转90度,label为TYPE:B_2D_ANGLE90,180度与270度依此类推。若原始样本的角度不为0,则新label中的角度为原角度加上旋转角度(对360取模)。

[图:原始xml标注文件示例]

 

# -*- coding: utf-8 -*-
"""
Created on Fri Sep 27 13:53:47 2019

@author: mandy
"""
import os
import cv2
import time
import numpy as np
import xml.dom.minidom as xmldom

def parse_xml(fn):
    """Read the label and bounding box of one annotated object from an XML file.

    Only the first ``<name>``, ``<xmin>``, ``<ymin>``, ``<xmax>`` and ``<ymax>``
    elements found anywhere in the document are read, so a file with several
    objects yields only the first one.

    Args:
        fn: path (or file object) accepted by ``xml.dom.minidom.parse``.

    Returns:
        Tuple ``(label, xmin, ymin, xmax, ymax)``; every value is the raw text
        string from the XML, not a number.

    Raises:
        IndexError: if one of the tags is absent.
    """
    root = xmldom.parse(fn).documentElement

    def first_text(tag):
        # Text of the first <tag> element under the document root.
        return root.getElementsByTagName(tag)[0].firstChild.data

    label = first_text("name")
    xmin, ymin, xmax, ymax = (
        first_text(tag) for tag in ("xmin", "ymin", "xmax", "ymax")
    )
    return label, xmin, ymin, xmax, ymax

def rotate_img(image, angle):
    """Rotate ``image`` by ``angle`` degrees on a canvas resized to fit the result.

    Author's note: this is a custom rotation, because rotating with OpenCV's
    built-in helper leaves black borders. The rotation matrix comes from
    ``cv2.getRotationMatrix2D`` about the integer centre ``(w // 2, h // 2)``.
    The output size is the extent of the rotated rectangle, truncated to int.
    The matrix translation is then shifted so the rotated content is centred
    on that new canvas.

    Args:
        image: array whose first two dimensions are (height, width).
        angle: rotation angle in degrees, passed straight to OpenCV.

    Returns:
        The warped image produced by ``cv2.warpAffine``.
    """
    height, width = image.shape[:2]
    cx, cy = width // 2, height // 2
    matrix = cv2.getRotationMatrix2D((cx, cy), angle, 1.0)

    abs_cos = np.abs(matrix[0, 0])
    abs_sin = np.abs(matrix[0, 1])
    # Size of the axis-aligned box enclosing the rotated image.
    out_w = int((height * abs_sin) + (width * abs_cos))
    out_h = int((height * abs_cos) + (width * abs_sin))

    # Move the rotation centre to the centre of the enlarged canvas.
    matrix[0, 2] += (out_w / 2) - cx
    matrix[1, 2] += (out_h / 2) - cy
    return cv2.warpAffine(image, matrix, (out_w, out_h))

def convert_coo(j, loc_re, initial_angle=None, new_label=None, w_src=None, h_src=None):
    """Map box corners and the angle label into an image rotated by ``j`` degrees.

    Point mapping for a ``w_src`` x ``h_src`` source image, matching the
    counter-clockwise rotation produced by ``rotate_img``:
      0   -> (x, y)
      90  -> (y, w_src - x)
      180 -> (w_src - x, h_src - y)
      270 -> (h_src - y, x)
    The mapped points are the images of the input points. They are not
    necessarily the top-left/bottom-right corners of the rotated box.

    Args:
        j: rotation angle in degrees; must be a multiple of 90 (taken mod 360).
        loc_re: sequence of (x, y) points, e.g. a 2x2 array of the original
            top-left and bottom-right corners.
        initial_angle: angle suffix of the original label (e.g. "0", "90").
            Defaults to the module global ``initial_angle`` for compatibility.
        new_label: label prefix (e.g. "TYPE:B_2D_ANGLE"); defaults to the
            module global ``new_label``.
        w_src, h_src: source image width and height; default to the module
            globals of the same names.

    Returns:
        ``[label, x1, y1, x2, y2, ...]``, where ``label`` is ``new_label`` plus
        ``(initial_angle + j) % 360``.

    Raises:
        ValueError: if ``initial_angle`` is not one of 0/90/180/270, or ``j``
            is not a multiple of 90.
    """
    # Backward compatibility: the original read these as module globals.
    if initial_angle is None:
        initial_angle = globals()['initial_angle']
    if new_label is None:
        new_label = globals()['new_label']
    if w_src is None:
        w_src = globals()['w_src']
    if h_src is None:
        h_src = globals()['h_src']

    base = int(initial_angle)
    if base not in (0, 90, 180, 270):
        # The original left its list index unbound here and died with
        # UnboundLocalError; fail with a clear message instead.
        raise ValueError('unsupported initial angle: %r' % (initial_angle,))
    if j % 90 != 0:
        raise ValueError('rotation angle must be a multiple of 90, got %r' % (j,))
    turn = int(j) % 360

    w_list = [new_label + str((base + turn) % 360)]
    for pt in loc_re:
        x, y = pt[0], pt[1]
        if turn == 0:
            x_r, y_r = x, y
        elif turn == 90:
            x_r, y_r = y, w_src - x
        elif turn == 180:
            x_r, y_r = w_src - x, h_src - y
        else:  # turn == 270
            x_r, y_r = h_src - y, x
        w_list.append(x_r)
        w_list.append(y_r)
    return w_list
# Rotate every annotated image in ``p`` by 90/180/270 degrees. For each
# rotation, write the rotated jpg and a one-line "label xmin ymin xmax ymax"
# txt file.
# NOTE: convert_coo() may read new_label, initial_angle, w_src and h_src as
# module globals, so these names are still assigned at module level below.
p = '../all_img'
out_img_dir = 'rotate_img'
out_txt_dir = 'rotate_txt'
# cv2.imwrite returns False (no exception) when the target directory is
# missing, so create the output directories up front.
os.makedirs(out_img_dir, exist_ok=True)
os.makedirs(out_txt_dir, exist_ok=True)

# Base names of the annotation files. endswith() avoids matching names such
# as "a.xml.bak", which a plain substring test would accept.
files = [os.path.splitext(name)[0] for name in os.listdir(p) if name.endswith('.xml')]
angle_list = [90, 180, 270]
for stem in files:
    print('-' * 50)
    print(stem + '.jpg')
    print(stem + '.xml')
    img = cv2.imread(os.path.join(p, stem + '.jpg'))
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        print('skip: cannot read ' + stem + '.jpg')
        continue
    h_src, w_src = img.shape[:2]
    label, xmin, ymin, xmax, ymax = parse_xml(os.path.join(p, stem + '.xml'))
    # Labels look like "TYPE:B_2D_ANGLE<deg>": a 15-character prefix followed
    # by the angle.
    new_label = label[:15]
    initial_angle = label[15:]
    print(new_label, initial_angle)

    # Rows are the (x, y) of the original top-left and bottom-right corners.
    loc_re = np.array([float(xmin), float(ymin), float(xmax), float(ymax)]).reshape((2, 2))
    for j in angle_list:
        rotated = rotate_img(img, j)
        # convert_coo returns [label, x1', y1', x2', y2']: the images of the
        # two original corners. After rotation they are no longer the
        # top-left/bottom-right corners, so normalise them with min/max.
        w_list = convert_coo(j, loc_re)
        xs = (w_list[1], w_list[3])
        ys = (w_list[2], w_list[4])
        res = [w_list[0], min(xs), min(ys), max(xs), max(ys)]
        res = [str(k) for k in res]

        out_name = stem + '_' + str(j)
        cv2.imwrite(os.path.join(out_img_dir, out_name + '.jpg'), rotated)
        # The context manager guarantees the label file is flushed and closed.
        with open(os.path.join(out_txt_dir, out_name + '.txt'), 'w') as f:
            f.write(' '.join(res))
    

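旋转后的label和坐标信息我先保存在txt文件中,然后再将txt转为xml文件;当然,也可以在旋转的同时直接生成xml。注意:上面的代码把结果写到rotate_img和rotate_txt目录,而下面的代码从./addRotate/rotate_img和./addRotate/rotate_txt读取,运行前请确认两处目录一致。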

txt转xml代码:


from xml.dom import minidom
import os
import cv2
# Convert each rotated image's txt annotation (one "label xmin ymin xmax ymax"
# line per object) into a Pascal VOC-style XML file written to Annotations/.
img_dir = r'./addRotate/rotate_img'
txt_dir = r'./addRotate/rotate_txt'
ann_dir = r'Annotations'
# open(..., 'w') fails if the output directory does not exist.
os.makedirs(ann_dir, exist_ok=True)


def _append_text_element(doc, parent, tag, text):
    """Create ``<tag>text</tag>``, append it to ``parent`` and return it."""
    element = doc.createElement(tag)
    element.appendChild(doc.createTextNode(text))
    parent.appendChild(element)
    return element


jpg_list = os.listdir(img_dir)
for filename0 in jpg_list:
    print(filename0)
    xml_filename = os.path.splitext(filename0)[0]
    jpg_dirtory = os.path.join(img_dir, filename0)  # image path
    txt_dirtory = os.path.join(txt_dir, xml_filename + '.txt')  # label txt path
    # Use the listed file name directly. Splitting the joined path on '/'
    # breaks on Windows, where os.path.join inserts a backslash.
    img_name = filename0
    # NOTE(review): <folder> receives the full image path, as in the original
    # script; VOC tools usually expect only the directory name. Confirm.
    floder = jpg_dirtory
    im = cv2.imread(jpg_dirtory)
    if im is None:
        # cv2.imread returns None for unreadable/non-image files.
        print('skip: cannot read image ' + jpg_dirtory)
        continue
    if not os.path.isfile(txt_dirtory):
        print('skip: missing label file ' + txt_dirtory)
        continue
    h, w = im.shape[:2]
    d = im.shape[2] if im.ndim == 3 else 1

    doc = minidom.Document()  # DOM tree
    annotation = doc.createElement('annotation')
    doc.appendChild(annotation)
    _append_text_element(doc, annotation, 'folder', floder)
    _append_text_element(doc, annotation, 'filename', img_name)
    source = doc.createElement('source')
    _append_text_element(doc, source, 'database', 'Unknown')
    annotation.appendChild(source)
    size = doc.createElement('size')
    _append_text_element(doc, size, 'width', '%d' % w)
    _append_text_element(doc, size, 'height', '%d' % h)
    # The original created <depth> but never attached it to <size>.
    _append_text_element(doc, size, 'depth', '%d' % d)
    annotation.appendChild(size)
    _append_text_element(doc, annotation, 'segmented', '0')

    with open(txt_dirtory, 'r') as txt_label:
        boxes = txt_label.read().splitlines()  # splitlines drops the newlines
    for box in boxes:
        fields = box.split()
        if not fields:
            continue  # tolerate blank lines
        if len(fields) < 5:
            print('skip malformed line in ' + txt_dirtory + ': ' + box)
            continue
        obj = doc.createElement('object')  # renamed: don't shadow builtin object
        _append_text_element(doc, obj, 'name', fields[0])
        _append_text_element(doc, obj, 'pose', 'Unspecified')
        _append_text_element(doc, obj, 'truncated', '0')
        _append_text_element(doc, obj, 'difficult', '0')
        bndbox = doc.createElement('bndbox')
        for tag, value in zip(('xmin', 'ymin', 'xmax', 'ymax'), fields[1:5]):
            _append_text_element(doc, bndbox, tag, value)
        obj.appendChild(bndbox)
        annotation.appendChild(obj)

    # Write once per image, after all objects are added. The original rewrote
    # the file inside the per-box loop and wrote nothing for an empty txt.
    with open(os.path.join(ann_dir, xml_filename + '.xml'), 'w') as savefile:
        savefile.write(doc.toprettyxml())

旋转90度后的样本:

[图:旋转90度后的样本图像]

txt文件:

[图:对应的txt标注文件内容]

xml文件:

[图:由txt转换得到的xml标注文件]