欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

训练所用到的一些代码

程序员文章站 2022-04-29 19:02:00
...

整理一些深度学习训练会用到的代码:(从网上找到的,非本人所写,有进行一些修改)
如有冒犯到各位大佬,麻烦直接联系我,谢谢
(数量太多,我找不到原来的链接了)

针对.xml标签文件进行处理: files.py

files.py

#!/usr/lib64/python2.7
# -*- coding:utf-8 -*-

import os
import xml.dom.minidom
import xml.etree.ElementTree
import sys  
#reload(sys)  
#sys.setdefaultencoding('utf8')   


# Directory holding the Pascal VOC .xml annotation files. Paths are built with
# os.path.join, so a trailing '/' is not required.
xmldir = './Annotations'

# Repair degenerate boxes: whenever xmin >= xmax, set xmax to xmin + 1 and
# write the corrected annotation back in place.
for xmlfile in os.listdir(xmldir):
    # Skip anything that is not an annotation file so a stray image or text
    # file in the directory does not abort the whole run.
    if not xmlfile.lower().endswith('.xml'):
        continue
    xmlname = os.path.splitext(xmlfile)[0]
    print(xmlname)
    xmlpath = os.path.join(xmldir, xmlfile)
    # Read the xml file.
    dom = xml.dom.minidom.parse(xmlpath)
    root = dom.documentElement

    # Look the node lists up once per file instead of once per box.
    xmin_nodes = root.getElementsByTagName('xmin')
    xmax_nodes = root.getElementsByTagName('xmax')
    modified = False
    for xmin_node, xmax_node in zip(xmin_nodes, xmax_nodes):
        xmin = float(xmin_node.firstChild.data)
        xmax = float(xmax_node.firstChild.data)
        if xmin >= xmax:
            # Zero or negative width: force a width of 1.
            xmax_node.firstChild.data = str(xmin + 1)
            modified = True
        else:
            print(xmlname)

    # Rewrite the file once, and only when at least one box was repaired.
    if modified:
        with open(xmlpath, 'w') as fh:
            dom.writexml(fh)

# Disabled variant kept for reference. Placed inside the file loop above, it
# rescales every box by 1/3.75 and can also overwrite the filename/size tags;
# the file must then be written back unconditionally.
#
#     for i in range(len(root.getElementsByTagName('xmin'))):
#         xmin = float(root.getElementsByTagName('xmin')[i].firstChild.data)
#         xmax = float(root.getElementsByTagName('xmax')[i].firstChild.data)
#         ymin = float(root.getElementsByTagName('ymin')[i].firstChild.data)
#         ymax = float(root.getElementsByTagName('ymax')[i].firstChild.data)
#         root.getElementsByTagName('xmin')[i].firstChild.data = str(round(xmin/3.75))
#         root.getElementsByTagName('xmax')[i].firstChild.data = str(round(xmax/3.75))
#         root.getElementsByTagName('ymin')[i].firstChild.data = str(round(ymin/3.75))
#         root.getElementsByTagName('ymax')[i].firstChild.data = str(round(ymax/3.75))
#     #root.getElementsByTagName('filename')[0].firstChild.data = xmlname + '.jpg'
#     #root.getElementsByTagName('width')[0].firstChild.data = '512'
#     #root.getElementsByTagName('height')[0].firstChild.data = '288'
#     with open(os.path.join(xmldir, xmlfile), 'w') as fh:
#         dom.writexml(fh)

遍历标签,得到数据集所有类别:getclass.py

#-*- coding: UTF-8 -*-
#查找VOC数据格式中有几个类别

#数据格式是这样
#+--Annotations
#+--JPEGImages

#最后输出
#+--class.txt

import os
import os.path as osp
import cv2
import xml.etree.ElementTree as ET

# Collect every distinct object class name used in the VOC annotations that
# sit next to this script and write them, one per line, to class.txt.
thisDir = osp.abspath(os.path.dirname(__file__))
imDir = osp.join(thisDir, 'JPEGImages')
xmlDir = osp.join(thisDir, 'Annotations')
resDir = osp.join(thisDir, 'class.txt')

# Only .jpg images are considered; each one must have a same-named .xml file.
subList = [x for x in os.listdir(imDir) if osp.splitext(x)[1] == '.jpg']

print('iamge num is',len(subList))
clasNameTot = []
for name in subList:
    # splitext instead of name[:-4] so the stem does not depend on suffix length.
    fileName = osp.join(xmlDir, osp.splitext(name)[0] + '.xml')
    tree = ET.parse(fileName)
    for obj in tree.findall('object'):
        clasName = obj.find('name').text
        # Keep first-seen order, so a list rather than a set.
        if clasName not in clasNameTot:
            clasNameTot.append(clasName)
print(clasNameTot)

# Open the result file only after parsing succeeded: mode 'w' truncates any
# previous class.txt, and the with-block closes the handle even on error.
with open(resDir, 'w') as txt:
    for line in clasNameTot:
        txt.write(line + '\n')

从图片文件夹中获得所有图片名字(去除.jpg),得到一个txt文档: get_imgnum.py

import os
dir1 = './reshape_images'  # directory holding the image files
txt1 = 'train.txt'  # text file that receives one image name per line
jpg_suffix = '.jpg'

# Append mode is kept from the original script: an existing train.txt is
# extended, not overwritten.
with open(txt1, 'a') as f1:
    for filename in os.listdir(dir1):
        # str.rstrip('.jpg') removes any trailing '.', 'j', 'p' or 'g'
        # characters (e.g. 'dog.jpg' -> 'do'), so cut the exact suffix instead.
        if filename.endswith(jpg_suffix):
            filename = filename[:-len(jpg_suffix)]
        f1.write(filename + "\n")

从标签文件夹中找出包含某一个类别的标签文件,输出txt文档: get_num.py

#-*- coding: UTF-8 -*-
# Find the annotation files in a VOC-format dataset that contain the 'meter_oil' class

#数据格式是这样
#+--Annotations
#+--JPEGImages

#最后输出
#+--meter_num.txt

import os
import os.path as osp
import cv2
import xml.etree.ElementTree as ET

# Write the path of every annotation file that contains a 'meter_oil' object
# to meter_num.txt. A file is written once per matching object, so a file
# with several 'meter_oil' objects appears on several lines.
thisDir = osp.abspath(os.path.dirname(__file__))
imDir = osp.join(thisDir, 'JPEGImages')
xmlDir = osp.join(thisDir, 'Annotations')
resDir = osp.join(thisDir, 'meter_num.txt')

# Only .jpg images are considered; each one must have a same-named .xml file.
subList = [x for x in os.listdir(imDir) if osp.splitext(x)[1] == '.jpg']

print('iamge num is',len(subList))

# Mode 'w' truncates any previous result; the with-block guarantees the
# handle is closed even if an annotation fails to parse.
with open(resDir, 'w') as txt:
    for name in subList:
        fileName = osp.join(xmlDir, osp.splitext(name)[0] + '.xml')
        tree = ET.parse(fileName)
        for obj in tree.findall('object'):
            if obj.find('name').text == 'meter_oil':
                print(fileName)
                txt.write(fileName + '\n')

从标签文件中获得txt文本中名字所对应的标签: get_xml.py

# -*- coding:utf-8 -*-
# Author: Agent Xu
import os
import glob
import shutil
from PIL import Image
 
#指定找到文件后,另存为的文件夹路径
# Directory that receives the copied annotation files.
outDir = os.path.abspath('./train_xml')

# Text file listing one annotation base name (no extension) per line.
txtDir1 = os.path.abspath('train.txt')

# strip() rather than strip('\n') so Windows line endings and stray spaces
# do not prevent a match.
with open(txtDir1, 'r') as f:
    imgname1 = [line.strip() for line in f]

# Directory holding the candidate .xml annotation files.
imageDir2 = os.path.abspath('./ground_oil_stone')
imageList2 = glob.glob(os.path.join(imageDir2, '*.xml'))
image2 = [os.path.basename(item) for item in imageList2]
imgname2 = [os.path.splitext(item)[0] for item in image2]

# If outDir is missing, shutil.copy treats it as a destination *file name*
# and every copy overwrites the previous one, so create it first.
if not os.path.isdir(outDir):
    os.makedirs(outDir)

# Match by base name with a dictionary lookup instead of a nested scan.
xml_by_name = dict(zip(imgname2, imageList2))
for item1 in imgname1:
    xmlname = xml_by_name.get(item1)
    if xmlname is None:
        continue
    print(xmlname)
    shutil.copy(xmlname, outDir)

对所有图片进行统一尺寸的处理:reshape.py

# -*- coding: utf-8 -*-
"""
Created on Thu Aug 23 16:06:35 2018
@author: libo
"""
from PIL import Image
import os
 
 
def image_resize(image_path, new_path, size=(512, 512)):
    """Resize every image in *image_path* and save it under the same name in *new_path*.

    Args:
        image_path: Directory containing the source images. Every directory
            entry is opened as an image.
        new_path: Directory that receives the resized copies; it is created
            when it does not exist yet.
        size: Target ``(width, height)`` handed to ``Image.resize``.
            Defaults to the 512x512 the script always used.
    """
    print('============>>修改图片尺寸')
    if not os.path.isdir(new_path):
        os.makedirs(new_path)
    for img_name in os.listdir(image_path):
        img_path = os.path.join(image_path, img_name)
        # The with-block releases the source file handle as soon as the
        # resized copy has been produced.
        with Image.open(img_path) as image:
            resized = image.resize(size)
        # NOTE(review): the original had a "process the 1 channel image"
        # remark here with no code behind it; single-channel images are
        # saved in whatever mode they were opened with.
        resized.save(os.path.join(new_path, img_name))
    print("end the processing!")
 
 
if __name__ == '__main__':
    print("ready for ::::::::  ")
    # Folder that receives the resized images.
    new_path = '/home/zmy/CenterNet/data/oilspill/reshape_images'
    # Folder holding the original input images.
    ori_path = r"/home/zmy/CenterNet/data/oilspill/JPEGImages"
    image_resize(image_path=ori_path, new_path=new_path)

获得所有图片(大小统一)的均值和标准差(代码中变量名为var,实际计算的是标准差):mean.py

import os
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
#from scipy.misc import imread
import imread
 
filepath = r'/home/zmy/CenterNet/data/oilspill/reshape_images'  # dataset directory
pathDir = os.listdir(filepath)


def _load_rgb(path):
    """Load the image at *path* as an (H, W, 3) float64 array scaled by 1/255."""
    # The file-level ``import imread`` binds a module, and calling a module
    # raises TypeError, so the already-imported PIL decodes the image.
    # convert('RGB') guarantees three channels for grayscale or RGBA files.
    with Image.open(path) as im:
        return np.asarray(im.convert('RGB'), dtype=np.float64) / 255.0


# First pass: per-channel sums plus the real pixel count, so the result no
# longer depends on every image being exactly 512x512.
channel_sum = np.zeros(3)
num = 0
for filename in pathDir:
    img = _load_rgb(os.path.join(filepath, filename))
    channel_sum += img.sum(axis=(0, 1))
    num += img.shape[0] * img.shape[1]

if num == 0:
    raise SystemExit('no images found in %s' % filepath)

R_mean, G_mean, B_mean = channel_sum / num

# Second pass: per-channel sum of squared deviations from the mean.
channel_mean = np.array([R_mean, G_mean, B_mean])
sq_dev_sum = np.zeros(3)
for filename in pathDir:
    img = _load_rgb(os.path.join(filepath, filename))
    sq_dev_sum += ((img - channel_mean) ** 2).sum(axis=(0, 1))

# The *_var names are kept from the original script, but the square root
# makes these standard deviations, not variances.
R_var, G_var, B_var = np.sqrt(sq_dev_sum / num)
print("R_mean is %f, G_mean is %f, B_mean is %f" % (R_mean, G_mean, B_mean))
print("R_var is %f, G_var is %f, B_var is %f" % (R_var, G_var, B_var))

按比例划分数据集:split_dataset.py

##深度学习过程中,需要制作训练集和验证集、测试集。

import os, random, shutil
def moveFile(fileDir, tarDir='./test/', rate=0.2):
    """Move a random fraction of the entries in *fileDir* into *tarDir*.

    Args:
        fileDir: Source directory; a trailing slash is optional.
        tarDir: Destination directory, created when missing. The default is
            the path the script's entry point has always used.
        rate: Fraction of entries to move, e.g. 0.1 moves 10 out of 100.
            The count is truncated with ``int``.
    """
    pathDir = os.listdir(fileDir)
    picknumber = int(len(pathDir) * rate)
    sample = random.sample(pathDir, picknumber)
    print(sample)
    # shutil.move fails when the destination's parent directory is missing.
    if not os.path.isdir(tarDir):
        os.makedirs(tarDir)
    for name in sample:
        # os.path.join instead of '+' so neither path needs a trailing slash.
        shutil.move(os.path.join(fileDir, name), os.path.join(tarDir, name))
    return


if __name__ == '__main__':
    fileDir = "./reshape_images/"  # source image directory
    tarDir = './test/'  # directory the sampled images are moved to
    moveFile(fileDir, tarDir)

xml格式转json: xml2json.py

import xml.etree.ElementTree as ET
import os
import json

# COCO-style result document; the add* functions in this script append to its lists.
coco = {
    'images': [],
    'type': 'instances',
    'annotations': [],
    'categories': [],
}

# Category name -> category id, and the image file names registered so far.
category_set = {}
image_set = set()

# Running id counters, each incremented before a new id is handed out.
image_id = 20200000000
category_item_id = 0
annotation_id = 0

def addCatItem(name):
    """Register *name* as a new category in ``coco`` and return its new id.

    The global category counter is incremented first, so ids start at 1.
    The name-to-id mapping is also recorded in ``category_set``.
    """
    global category_item_id
    category_item_id += 1
    new_id = category_item_id
    coco['categories'].append({
        'supercategory': 'none',
        'id': new_id,
        'name': name,
    })
    category_set[name] = new_id
    return new_id

def addImgItem(file_name, size):
    """Append an image entry to ``coco`` and return the id assigned to it.

    Args:
        file_name: Image file name taken from the annotation.
        size: Mapping with at least the ``'width'`` and ``'height'`` keys.

    Raises:
        Exception: If *file_name*, the width or the height is ``None``.
    """
    global image_id
    # Validate in the same order as the tags are reported: name, width, height.
    if file_name is None:
        raise Exception('Could not find filename tag in xml file.')
    width = size['width']
    if width is None:
        raise Exception('Could not find width tag in xml file.')
    height = size['height']
    if height is None:
        raise Exception('Could not find height tag in xml file.')

    image_id += 1
    coco['images'].append({
        'id': image_id,
        'file_name': file_name,
        'width': width,
        'height': height,
    })
    image_set.add(file_name)
    return image_id

def addAnnoItem(object_name, image_id, category_id, bbox):
    """Append one annotation entry for *bbox* to ``coco``.

    Args:
        object_name: Accepted but not used by this function.
        image_id: Id of the image the box belongs to.
        category_id: Id of the object's category.
        bbox: Sequence ``[x, y, w, h]``; stored as-is under ``'bbox'``.
    """
    global annotation_id
    x, y, w, h = bbox[0], bbox[1], bbox[2], bbox[3]
    # Rectangle outline as a flat coordinate list, in the order
    # left-top, left-bottom, right-bottom, right-top.
    polygon = [
        x, y,
        x, y + h,
        x + w, y + h,
        x + w, y,
    ]
    area = w * h
    annotation_id += 1
    coco['annotations'].append({
        'segmentation': [polygon],
        'area': area,
        'iscrowd': 0,
        'ignore': 0,
        'image_id': image_id,
        'bbox': bbox,
        'category_id': category_id,
        'id': annotation_id,
    })

def parseXmlFiles(xml_path):
    """Convert every Pascal VOC ``.xml`` file in *xml_path* into ``coco`` entries.

    For each file the children of the ``<annotation>`` root are visited in
    document order. The image entry is registered on the first element
    visited once both the file name and the width are known, and one
    annotation is added for every ``<bndbox>`` found inside an ``<object>``.

    Args:
        xml_path: Directory to scan; entries not ending in ``.xml`` are skipped.

    Raises:
        Exception: If the root tag is not ``annotation``, an image file name
            was already registered, a ``<size>`` or ``<bndbox>`` field is
            repeated, or a ``<bndbox>`` is reached before the object name,
            image id or category id is known.
    """
    for f in os.listdir(xml_path):
        if not f.endswith('.xml'):
            continue

        # Per-file state, reset for every annotation file.
        size = {'width': None, 'height': None, 'depth': None}
        current_image_id = None
        current_category_id = None
        file_name = None

        xml_file = os.path.join(xml_path, f)
        print(xml_file)

        tree = ET.parse(xml_file)
        root = tree.getroot()
        if root.tag != 'annotation':
            raise Exception('pascal voc xml root element should be annotation, rather than {}'.format(root.tag))

        # elem is <folder>, <filename>, <size>, <object>
        for elem in root:
            current_parent = elem.tag
            object_name = None

            if elem.tag == 'folder':
                continue

            if elem.tag == 'filename':
                file_name = elem.text
                # Duplicate images are tracked in image_set; the original
                # checked category_set, which only holds category names.
                if file_name in image_set:
                    raise Exception('file_name duplicated')

            # add img item only after parse <size> tag
            elif current_image_id is None and file_name is not None and size['width'] is not None:
                if file_name not in image_set:
                    current_image_id = addImgItem(file_name, size)
                    print('add image with {} and {}'.format(file_name, size))
                else:
                    raise Exception('duplicated image: {}'.format(file_name))

            # subelem is <width>, <height>, <depth>, <name>, <bndbox>
            for subelem in elem:
                # Fresh box per sub-element, so only a <bndbox> fills it in.
                bndbox = {'xmin': None, 'xmax': None, 'ymin': None, 'ymax': None}

                current_sub = subelem.tag
                if current_parent == 'object' and subelem.tag == 'name':
                    object_name = subelem.text
                    if object_name not in category_set:
                        current_category_id = addCatItem(object_name)
                    else:
                        current_category_id = category_set[object_name]

                elif current_parent == 'size':
                    if size[subelem.tag] is not None:
                        raise Exception('xml structure broken at size tag.')
                    size[subelem.tag] = int(subelem.text)

                # option is <xmin>, <ymin>, <xmax>, <ymax>, when subelem is <bndbox>
                for option in subelem:
                    if current_sub == 'bndbox':
                        if bndbox[option.tag] is not None:
                            raise Exception('xml structure corrupted at bndbox tag.')
                        # Go through float first: coordinates may be stored
                        # as e.g. "13.0", which int() alone rejects.
                        bndbox[option.tag] = int(float(option.text))

                # only after parse the <object> tag
                if bndbox['xmin'] is not None:
                    if object_name is None:
                        raise Exception('xml structure broken at bndbox tag')
                    if current_image_id is None:
                        raise Exception('xml structure broken at bndbox tag')
                    if current_category_id is None:
                        raise Exception('xml structure broken at bndbox tag')
                    # COCO box layout: x, y, width, height.
                    bbox = [
                        bndbox['xmin'],
                        bndbox['ymin'],
                        bndbox['xmax'] - bndbox['xmin'],
                        bndbox['ymax'] - bndbox['ymin'],
                    ]
                    print('add annotation with {},{},{},{}'.format(object_name, current_image_id, current_category_id, bbox))
                    addAnnoItem(object_name, current_image_id, current_category_id, bbox)

if __name__ == '__main__':
    xml_path = './test_xml'  # directory with the Pascal VOC .xml files
    json_file = 'test.json'  # COCO-style output file
    parseXmlFiles(xml_path)
    # The with-block flushes and closes the output file; the original passed
    # an anonymous open() handle that was never closed.
    with open(json_file, 'w') as fh:
        json.dump(coco, fh)
相关标签: 深度学习