使用imgaug库实现随机方式的图片增强(同时生成新的标签文件【精灵标注助手】)
程序员文章站
2022-04-08 14:57:30
...
前言
这段时间在做一个行为识别的项目,数据集是自己拍的视频转换成的RGB图像,标签文件是用精灵标注助手做的标签。由于数据量不够,且为了增加模型训练后的泛化能力,打算做一下数据增强,但是自己精力有限又比较懒,就想着做数据增强的同时也把标签文件自动转换了(也就是说把原图的标签文件中坐标点映射到增强后的图片上,把新的标签点以同样格式保存在xml文件中,按照yolov5的要求,图片名和xml文件名要保持一致),在网上查了很多帖子,没找到合适的方法,只好自己动手写了一个脚本。
直接上代码
data_aug.py:主程序入口,包含了解析xml文件,生成新的xml文件等功能
imgaug_utils.py:这个脚本是数据增强的部分
data_aug.py
精灵标注助手生成的xml文件格式:
我们看到,一个文件中可能含有多个item,每个item中包含了标签名称,标签框的四个坐标值,
首先我们需要把这几个值解析出来。
主函数入口
from xml.dom.minidom import parse, Document
from my_utils.imgaug_utils import get_inner_bbs
import numpy as np
def main_change(src_xml_path, src_img_dir, dst_img_dir, dst_xml_dir, p_number):
'''
:param p_number: Numbers of images to enhance
'''
print(f"src: {src_xml_path}")
dom = parse(src_xml_path)
root = dom.documentElement
img_name = root.getElementsByTagName("path")[0].childNodes[0].data
img_name = os.path.split(img_name)[-1].split(".")[0]
img_path = f"{src_img_dir}/{img_name}.jpg"
item = root.getElementsByTagName("item")
# label = root.getElementsByTagName("name")[0].childNodes[0].data
coor_list = []
# 便利循环item,获取每个box中包含的信息,保存到coor_list中
for box in item:
cls_name = box.getElementsByTagName("name")[0].childNodes[0].data
x1 = max(0, int(box.getElementsByTagName("xmin")[0].childNodes[0].data))
y1 = max(0, int(box.getElementsByTagName("ymin")[0].childNodes[0].data))
x2 = max(0, int(box.getElementsByTagName("xmax")[0].childNodes[0].data))
y2 = max(0, int(box.getElementsByTagName("ymax")[0].childNodes[0].data))
cls_name = trans_cls_name(cls_name)
coor_list.append([x1,y1,x2,y2,cls_name])
# 将coor_list中的信息,传入get_inner_bbs
aug_list = get_inner_bbs(img_path, dst_img_dir, np.array(coor_list), p_number)
if not aug_list:
return
# 将返回的坐标以及图片信息传入save_xml函数,保存到对应的xml文件中
for aug_info in aug_list:
save_xml(aug_info, dst_xml_dir)
这两个函数用来转换标签值和标签名
def trans_cls_name(name):
if name == "call":
return 0
elif name == "smoke":
return 1
elif name == "drink":
return 2
else:
raise ValueError(f"wrong class name! {name}")
def inv_trans_cls_name(value):
if value == 0:
return "call"
elif value == 1:
return "smoke"
elif value == 2:
return "drink"
else:
raise ValueError(f"wrong class number! {value}")
保存xml函数
def save_xml(aug_info, dst_xml_dir):
# 返回的aug_info包含两部分,坐标信息和图片信息,坐标信息的格式是[x1, y1, x2, y2, cls_name]
coor_array, img_info = aug_info
# 图片信息中包含增强后保存的图片路径,图片宽,高,通道数
img_name, img_h, img_w, img_c = list(map(str, img_info))
xml_name = os.path.split(img_name)[-1].split(".")[0]
# 1.创建DOM树对象
dom = Document()
# 2.创建根节点。每次都要用DOM对象来创建任何节点。
root_node = dom.createElement('root')
# 3.用DOM对象添加根节点
dom.appendChild(root_node)
path_node = dom.createElement('path')
root_node.appendChild(path_node)
path_text = dom.createTextNode(img_name)
path_node.appendChild(path_text)
outputs_node = dom.createElement('outputs')
root_node.appendChild(outputs_node)
object_node = dom.createElement('object')
outputs_node.appendChild(object_node)
# 遍历出每一个item中的坐标和cls name
for row_data in coor_array:
cls_name = inv_trans_cls_name(int(row_data[4]))
item_node = dom.createElement('item')
object_node.appendChild(item_node)
name_node = dom.createElement('name')
item_node.appendChild(name_node)
name_text = dom.createTextNode(cls_name)
name_node.appendChild(name_text)
bndbox_node = dom.createElement('bndbox')
item_node.appendChild(bndbox_node)
xmin_node = dom.createElement('xmin')
bndbox_node.appendChild(xmin_node)
xmin_text = dom.createTextNode(str(row_data[0]))
xmin_node.appendChild(xmin_text)
ymin_node = dom.createElement('ymin')
bndbox_node.appendChild(ymin_node)
ymin_text = dom.createTextNode(str(row_data[1]))
ymin_node.appendChild(ymin_text)
xmax_node = dom.createElement('xmax')
bndbox_node.appendChild(xmax_node)
xmax_text = dom.createTextNode(str(row_data[2]))
xmax_node.appendChild(xmax_text)
ymax_node = dom.createElement('ymax')
bndbox_node.appendChild(ymax_node)
ymax_text = dom.createTextNode(str(row_data[3]))
ymax_node.appendChild(ymax_text)
size_node = dom.createElement('size')
root_node.appendChild(size_node)
width_node = dom.createElement('width')
size_node.appendChild(width_node)
width_text = dom.createTextNode(img_w)
width_node.appendChild(width_text)
height_node = dom.createElement('height')
size_node.appendChild(height_node)
height_text = dom.createTextNode(img_h)
height_node.appendChild(height_text)
depth_node = dom.createElement('depth')
size_node.appendChild(depth_node)
depth_text = dom.createTextNode(img_c)
depth_node.appendChild(depth_text)
try:
with open(rf'{dst_xml_dir}/{xml_name}.xml', 'w') as f:
# writexml()第一个参数是目标文件对象,第二个参数是根节点的缩进格式,
# 第三个参数是其他子节点的缩进格式,
# 第四个参数制定了换行格式,第五个参数制定了xml内容的编码。
dom.writexml(f, indent='', addindent='\t', newl='\n', encoding='utf-8')
print(rf'dst: {dst_xml_dir}/{xml_name}.xml')
except Exception as err:
print('错误:{err}'.format(err=err))
imgaug_utils.py
get_inner_bbs()
图片增强主函数
import os
from PIL import Image
import numpy as np
import imgaug as ia
import imgaug.augmenters as iaa
from imgaug.augmentables.bbs import BoundingBoxesOnImage
ia.seed(1)
GREEN = [0, 255, 0]
ORANGE = [255, 140, 0]
RED = [255, 0, 0]
def get_inner_bbs(image_path, dst_img_dir, array_info, p_numbers):
'''
:param image_path: src img path
:param dst_img_dir: img save path
:param coor_array: label coor array
:param p_numbers: Numbers of images to enhance
:return: [(bbs_array, img_info),
(bbs_array, img_info)]
'''
try:
# 这里将4个坐标值和类别名拆分开,后续再将新的坐标值和标签合为一个数组
assert array_info.shape[1] == 5
coor_array = array_info[:, :-1]
cls_array = array_info[:, -1]
image = Image.open(image_path)
image = np.array(image)
img_name = os.path.split(image_path)[-1].split(".")[0]
bbs = BoundingBoxesOnImage.from_xyxy_array(coor_array, shape=image.shape)
except Exception as e:
print(f"err:{e}")
print(array_info.shape)
print(image_path)
return None
# # Draw the original picture
# image_before = draw_bbs(image, bbs, 100)
# ia.imshow(image_before)
# Image augmentation sequence
# 此增强序列可以自行定义,API可以查询imgaug官方文档
seq = iaa.Sequential([
iaa.Fliplr(0.5),
iaa.Crop(percent=(0, 0.1)),
iaa.Sometimes(
0.5,
iaa.GaussianBlur(sigma=(0, 0.5))
),
# Strengthen or weaken the contrast in each image.
iaa.LinearContrast((0.75, 1.5)),
iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05 * 255), per_channel=0.5),
# change illumination
iaa.Multiply((0.3, 1.2), per_channel=0.2),
# affine transformation
iaa.Affine(
scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},
translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)},
rotate=(-5, 5),
shear=(-8, 8)
)
], random_order=True) # apply augmenters in random order
res_list = []
# gen img and coor
try:
for epoch in range(p_numbers):
# 同时对图片和标签进行变换
image_aug, bbs_aug = seq(image=image, bounding_boxes=bbs)
# 这个方法可以将增强后标签框在图像外部的坐标,变为图片内
# bbs_aug = bbs_aug.remove_out_of_image().clip_out_of_image()
# # draw aug img and label
image_after = bbs_aug.draw_on_image(image_aug, size=2, color=[0, 0, 255])
ia.imshow(image_after)
# save img
h, w, c = image_aug.shape
img_aug_name = rf'{dst_img_dir}/{img_name}_{epoch}.jpg'
im = Image.fromarray(image_aug)
im.save(img_aug_name)
# 将新标签和类名合为一个数组,保存到列表
bbs_array = bbs_aug.to_xyxy_array()
result_array = np.column_stack((bbs_array, cls_array))
res_list.append([result_array, (img_aug_name, h, w, c)])
except Exception as e:
print(e)
print(img_aug_name)
return None
# return coor and img info
return res_list
填充和画框,可以用来观察转换后的坐标和图片
# Pad image with a 1px white and (BY-1)px black border
def _pad(image, by):
image_border1 = ia.augmenters.size.pad(image, top=1, right=1, bottom=1, left=1,
mode="constant", cval=255)
image_border2 = ia.augmenters.size.pad(image_border1, top=by-1, right=by-1,
bottom=by-1, left=by-1,
mode="constant", cval=0)
return image_border2
# Draw BBs on an image
# and before doing that, extend the image plane by BORDER pixels.
# Mark BBs inside the image plane with green color, those partially inside
# with orange and those fully outside with red.
def draw_bbs(image, bbs, border):
image_border = _pad(image, border)
for bb in bbs.bounding_boxes:
if bb.is_fully_within_image(image.shape):
color = GREEN
elif bb.is_partly_within_image(image.shape):
color = ORANGE
else:
color = RED
image_border = bb.shift(x=border, y=border)\
.draw_on_image(image_border, size=2, color=color)
return image_border
结语
如果觉得有用的朋友可以点个赞,如果代码中有错误或不足的地方也可以留言
附上git项目链接
https://github.com/lieweiAI/action-detection/tree/master/my_utils
上一篇: Django进阶之CSRF的解决
下一篇: 一个用于图片处理的工具类