欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

对于标签在xml文件里面的数据集如何从xml里面提取标注框用于数据增强的笔记

程序员文章站 2024-03-19 23:11:16
...

比较丧………………自认为很好很棒的模型效果却不好~写了半天的数据增强代码还是不如开源的好用,哎~
写一篇博客记录一下吧~
我写的:从xml文件里面提取标注框,将原图中包含标注框在内的图像区域随机扩大,并resize到需要训练的size(目前效果不好,后面心情好了再改,下面会放代码)

# -*- coding: utf-8 -*-

import xml.etree.cElementTree as et
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import random
dir='xml/'
def crop_img(dir,target_size,num):
    i=0
    for file in os.listdir(dir):
        if('xml'in file):
            s=dir+file
            tree=et.parse(s)
            root=tree.getroot()
            filename=root.find('filename').text
            for Object in root.findall('object'):
                bndbox=Object.find('bndbox')
                xmin=bndbox.find('xmin').text
                ymin=bndbox.find('ymin').text
                xmax=bndbox.find('xmax').text
                ymax=bndbox.find('ymax').text
                img=Image.open(dir+filename)
                #img=img0.rotate(random.randint(-180,180))
                size=img.size
                for count in range(0,num):
                    amin=random.randint(0,int(xmin))
                    bmin=random.randint(0,int(ymin))
                    cmax=random.randint(int(xmax),size[0])
                    dmax=random.randint(int(ymax),size[1])
                    box=(int(xmin),int(ymin),int(xmax),int(ymax))
                    box2=(amin,bmin,cmax,dmax)
                    print(box)
                    print(box2)
                    image=img.crop(box)
                    image2=img.crop(box2)
                    image2=image.resize((target_size,target_size),Image.ANTIALIAS)
                   # image1.save(dir+"chuli222_"+str(i)+".jpg")
                    image2.save(dir+"chuli_"+str(i)+".jpg")
                    i=i+1




crop_img(dir,target_size=599,num=1)

xml文件夹下的文件是xml和jpg同名且一一对应的。

对于标签在xml文件里面的数据集如何从xml里面提取标注框用于数据增强的笔记

生成的图片贼丑……

结合tensorflow开源的方案之后就很nice了:

# -*- coding: utf-8 -*-

import xml.etree.cElementTree as et
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import random
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import scipy.misc

def distort_color(image, color_ordering=0):
    if color_ordering == 0:
        image = tf.image.random_brightness(image, max_delta=32./255.)
        image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
        image = tf.image.random_hue(image, max_delta=0.2)
        image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
    else:
        image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
        image = tf.image.random_brightness(image, max_delta=32./255.)
        image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
        image = tf.image.random_hue(image, max_delta=0.2)

    return tf.clip_by_value(image, 0.0, 1.0)
def preprocess_for_train(image, height, width, bbox):
    # 查看是否存在标注框。
    if bbox is None:
        bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
    if image.dtype != tf.float32:
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)

    # 随机的截取图片中一个块。
    bbox_begin, bbox_size, _ = tf.image.sample_distorted_bounding_box(
        tf.shape(image), bounding_boxes=bbox, min_object_covered=0.4)
    bbox_begin, bbox_size, _ = tf.image.sample_distorted_bounding_box(
        tf.shape(image), bounding_boxes=bbox, min_object_covered=0.4)
    distorted_image = tf.slice(image, bbox_begin, bbox_size)

    # 将随机截取的图片调整为神经网络输入层的大小。
    distorted_image = tf.image.resize_images(distorted_image, [height, width], method=np.random.randint(4))
    distorted_image = tf.image.random_flip_left_right(distorted_image)
    distorted_image = distort_color(distorted_image, np.random.randint(2))
    return distorted_image
def crop_img(dir,target_size,num):
    i=0
    for file in os.listdir(dir):
        if('xml'in file):
            s=dir+file
            tree=et.parse(s)
            root=tree.getroot()
            filename=root.find('filename').text
            for Object in root.findall('object'):
                bndbox=Object.find('bndbox')
                xmin=bndbox.find('xmin').text
                ymin=bndbox.find('ymin').text
                xmax=bndbox.find('xmax').text
                ymax=bndbox.find('ymax').text
                image_raw_data = tf.gfile.FastGFile(dir+filename, "rb").read()
                with tf.Session() as sess:
                    img_data = tf.image.decode_jpeg(image_raw_data)
                    img_data.set_shape([2560,1920,3])
                    a=img_data.get_shape()
                    print(a)
                    box=tf.constant([int(ymin)/int(a[1]),int(xmin)/int(a[0]),int(ymax)/int(a[1]),int(xmax)/int(a[0])], dtype=tf.float32, shape=[1, 1, 4])
                    print(box.eval())
                    for j in range(num):
                        result = preprocess_for_train(img_data, target_size, target_size, box)
                        scipy.misc.imsave(dir+"zengqiang_xc_"+str(i)+".jpg", result.eval())
                        i=i+1
    sess.close()



crop_img(dir="xml/",target_size=599,num=1)

哇哇哇,这里有一个地方调bug好久:(哭唧唧)

调试总结:
bug集中在这一句:

box=tf.constant([int(ymin)/int(a[1]),int(xmin)/int(a[0]),int(ymax)/int(a[1]),int(xmax)/int(a[0])], dtype=tf.float32, shape=[1, 1, 4])

刚开始的报错是我不知道这里的box是0~1之间的数,我给的是xmin之类的图像坐标,这个问题通过除以长宽解决(当然这中间还涉及没有对a[1]取int,报错 int 与dimension“/”操作不允许……)

接下来的报错原因更奇葩:
boxes: 形状 [batch, num_bounding_boxes, 4] 的三维矩阵, num_bounding_boxes 是标注框的数量,标注框由四个数字标示 [y_min, x_min, y_max, x_max],数组类型为float32。例如:tf.constant([[[0.05, 0.05, 0.9, 0.7], [0.35, 0.47, 0.5, 0.56]]]) shape 为 [1,2,4] 表示一张图片中的两个标注框;tf.constant([[[ 0. 0. 1. 1.]]]) 的 shape 为 [1,1,4]表示一张图片中的一个标注框

我前面的顺序写错了,我给了[xmin,ymin,xmax,ymax],,,内存访问冲突……
这个错误找了超级久……

现在终于搞定啦,还是比较开心的~上面的代码可以直接跑,只需要改参数就行啦:
你的图像和xml在哪个文件夹,dir就是啥

targetsize是目标尺寸

num是你准备把一张图片增强到几张
还有一个就是图像尺寸,我的是2560,1920,3,这一句可以删了~

8月2日补充:
tensorflow自带的resize之后图片会有1一些波纹,这些波纹非常影响图片的训练效果,换用opencv的resize就可以解决这一现象:

import  cv2
import numpy as np
import xml.etree.cElementTree as et
import os
import matplotlib.pyplot as plt
from PIL import Image
#import random
import tensorflow as tf
import scipy.misc

def distort_color(image, color_ordering):
    if color_ordering == 0:
        image = tf.image.random_brightness(image, max_delta=32./255.)
        image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
        image = tf.image.random_hue(image, max_delta=0.2)
        image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
    else:
        image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
        image = tf.image.random_brightness(image, max_delta=32./255.)
        image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
        image = tf.image.random_hue(image, max_delta=0.2)

    return tf.clip_by_value(image, 0.0, 1.0)
def preprocess_for_train(image, height, width, bbox):
    # 查看是否存在标注框。
    if bbox is None:
        bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
    if image.dtype != tf.float32:
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)

    # 随机的截取图片中一个块。
    bbox_begin, bbox_size, _ = tf.image.sample_distorted_bounding_box(
        tf.shape(image), bounding_boxes=bbox, min_object_covered=1)
    #bbox_begin, bbox_size, _ = tf.image.sample_distorted_bounding_box(
        #tf.shape(image), bounding_boxes=bbox, min_object_covered=0.4)
    distorted_image = tf.slice(image, bbox_begin, bbox_size)

    # 将随机截取的图片调整为神经网络输入层的大小。
    distorted_image = tf.image.random_flip_left_right(distorted_image)
    distorted_image = tf.image.random_flip_up_down(distorted_image)
    distorted_image = distort_color(distorted_image, np.random.randint(2))
    scipy.misc.imsave('src/label_1/'+'zengqiang_xc.jpeg', distorted_image.eval())
    img=cv2.imread('src/label_1/'+'zengqiang_xc.jpeg')
    res = cv2.resize(img,(height,width), interpolation = cv2.INTER_AREA)

    return res
def crop_img(dir,target_size,num,savepath):
    i=0
    with tf.Session() as sess:
        for dir2 in os.listdir(dir):  
            for file in os.listdir(dir+dir2+"/"):  
                if('xml'in file):
                    s=dir+dir2+"/"+file
                    tree=et.parse(s)
                    root=tree.getroot()
                    filename=root.find('filename').text
                    for Object in root.findall('object'):
                        bndbox=Object.find('bndbox')
                        xmin=bndbox.find('xmin').text
                        ymin=bndbox.find('ymin').text
                        xmax=bndbox.find('xmax').text
                        ymax=bndbox.find('ymax').text
                        image_raw_data = tf.gfile.FastGFile(dir+dir2+"/"+filename, 'rb').read()
                        img_data = tf.image.decode_jpeg(image_raw_data)       
                        img_data.set_shape([1920,2560,3])
                        a=img_data.get_shape()
                        box=tf.constant([[[int(ymin)/int(a[0]),int(xmin)/int(a[1]),int(ymax)/int(a[0]),int(xmax)/int(a[1])]]], dtype=tf.float32)
                        for j in range(num):
                             res = preprocess_for_train(img_data, target_size, target_size, box)
                             cv2.imwrite(savepath+"/zengqiang_xc_"+str(i)+".jpg",res)
                             i=i+1