欢迎您访问程序员文章站，本站旨在为大家提供并分享程序员计算机编程知识！
您现在的位置是: 首页

总结训练SSD时的小trick

程序员文章站 2024-03-16 23:35:58
...

删除xml中某些obj（宽或高超过100像素的目标）:

import os  
import os.path  
import xml.dom.minidom  
from xml.dom.minidom import parse
import xml.dom.minidom
import os,shutil
import numpy as np
import cv2
from PIL import Image, ImageDraw

path="/data_1/SSD/cpls/cpls/xml_ls"

# Objects whose box is wider or taller than this many pixels are removed.
MAX_BOX_SIZE = 100


def _box_value(object_, tag):
    """Return the integer value of the first <tag> element inside *object_*.

    The text goes through float() first so that coordinates written as
    "12.0" are accepted as well as plain integers.
    """
    text = object_.getElementsByTagName(tag)[0].childNodes[0].nodeValue
    return int(float(text))


def remove_large_objects(dom, max_size=MAX_BOX_SIZE):
    """Remove every <object> whose box is wider or taller than *max_size*.

    dom      -- a parsed xml.dom.minidom document (modified in place)
    max_size -- largest width/height (xmax-xmin / ymax-ymin) that is kept

    Returns the number of objects that were removed.
    """
    removed = 0
    for object_ in dom.documentElement.getElementsByTagName("object"):
        width = _box_value(object_, "xmax") - _box_value(object_, "xmin")
        height = _box_value(object_, "ymax") - _box_value(object_, "ymin")
        if width > max_size or height > max_size:
            object_.parentNode.removeChild(object_)
            removed += 1
    return removed


def process_dir(xml_dir, max_size=MAX_BOX_SIZE):
    """Strip oversized objects from every annotation file in *xml_dir*.

    Sub-directories are skipped.  A file is rewritten once, and only when
    at least one object was removed from it.
    """
    k = 0
    for xmlFile in os.listdir(xml_dir):
        k += 1
        print(k)
        print(xmlFile)
        xml_path = os.path.join(xml_dir, xmlFile)
        # Test the full path: a bare file name would be resolved against the
        # current working directory instead of xml_dir.
        if os.path.isdir(xml_path):
            continue
        dom = xml.dom.minidom.parse(xml_path)
        if remove_large_objects(dom, max_size):
            with open(xml_path, 'w') as fh:
                dom.writexml(fh)


if __name__ == "__main__":
    process_dir(path)

修改xml和img文件名:

import os
import shutil

## datadir AND savedir CAN NOT BE SAME
img_datadir="old/JPEGImages/"
img_savedir="new/JPEGImages/"

xml_datadir="old/Annotations/"
xml_savedir="new/Annotations/"


def _same_dir(first, second):
    """Return True when both paths name the same directory."""
    return os.path.abspath(first) == os.path.abspath(second)


def rename_dataset(img_src, img_dst, xml_src, xml_dst, prefix='mhq_181105_'):
    """Copy every image/annotation pair to the save dirs under a new name.

    The k-th image (in sorted file-name order, counting from 1) is copied to
    ``<img_dst>/<prefix><k>.jpg`` and its annotation, looked up by the
    image's base name, to ``<xml_dst>/<prefix><k>.xml``.

    Returns a list of ``(old_image_name, new_image_name)`` tuples.
    Raises ValueError when a source and its save directory are the same;
    a missing annotation makes shutil.copy raise as before.
    """
    if _same_dir(img_src, img_dst) or _same_dir(xml_src, xml_dst):
        raise ValueError("datadir AND savedir CAN NOT BE SAME")
    for folder in (img_dst, xml_dst):
        if not os.path.isdir(folder):
            os.makedirs(folder)

    copied = []
    # Sorting makes the numbering reproducible; os.listdir order is arbitrary.
    for k, img in enumerate(sorted(os.listdir(img_src)), 1):
        print(k)
        # splitext only touches the extension, so a "jpg" inside the base
        # name is left alone.
        old_stem = os.path.splitext(img)[0]
        new_stem = prefix + str(k)
        # Annotation first: if it is missing, no orphan image is left behind.
        shutil.copy(os.path.join(xml_src, old_stem + ".xml"),
                    os.path.join(xml_dst, new_stem + ".xml"))
        shutil.copy(os.path.join(img_src, img),
                    os.path.join(img_dst, new_stem + ".jpg"))
        copied.append((img, new_stem + ".jpg"))
    return copied


if __name__ == "__main__":
    rename_dataset(img_datadir, img_savedir, xml_datadir, xml_savedir)

修改xml中filename:

# -*- coding:utf-8 -*-
import os  
import os.path  
import xml.dom.minidom  
from xml.dom.minidom import parse
import xml.dom.minidom
import os,shutil
#import numpy as np
#import cv2
import urllib
from PIL import Image, ImageDraw

path="new/Annotations/"


def set_filename(dom, imgname):
    """Set the text of the first <filename> element of *dom* to *imgname*.

    An empty ``<filename/>`` element gets a new text node.  Returns the
    previous text ("" when the element was empty).  Raises IndexError when
    the document has no <filename> element.
    """
    node = dom.documentElement.getElementsByTagName('filename')[0]
    text = node.firstChild
    if text is None:
        node.appendChild(dom.createTextNode(imgname))
        return ""
    previous = text.data
    text.data = imgname
    return previous


def fix_filenames(xml_dir):
    """Make <filename> in every annotation of *xml_dir* match the file name.

    ``foo.xml`` gets ``<filename>foo.jpg</filename>``; sub-directories are
    skipped.  Each file is rewritten in place.
    """
    num = 0
    for xmlFile in os.listdir(xml_dir):
        num += 1
        print(num)
        xml_path = os.path.join(xml_dir, xmlFile)
        # Test the full path: a bare file name would be resolved against the
        # current working directory instead of xml_dir.
        if os.path.isdir(xml_path):
            continue
        print(xmlFile)
        imgname = os.path.splitext(xmlFile)[0] + ".jpg"
        dom = xml.dom.minidom.parse(xml_path)
        print(set_filename(dom, imgname))
        print(imgname)
        with open(xml_path, 'w') as fh:
            dom.writexml(fh)


if __name__ == "__main__":
    fix_filenames(path)

修改xml中objname:

import os  
import os.path  
import xml.dom.minidom  
from xml.dom.minidom import parse
import xml.dom.minidom
import os,shutil
import numpy as np
import cv2
from PIL import Image, ImageDraw

path="/xml"


def rename_objects(dom, new_name='cpls'):
    """Set the first <name> of every <object> in *dom* to *new_name*.

    An empty ``<name/>`` element gets a new text node.  Returns the number
    of objects that were relabelled.  Raises IndexError when an object has
    no <name> element.
    """
    count = 0
    for object_ in dom.documentElement.getElementsByTagName("object"):
        node = object_.getElementsByTagName('name')[0]
        if node.firstChild is None:
            node.appendChild(dom.createTextNode(new_name))
        else:
            node.firstChild.data = new_name
        count += 1
    return count


def relabel_dir(xml_dir, new_name='cpls'):
    """Relabel all objects of every annotation file in *xml_dir*.

    Sub-directories are skipped.  A file that contains at least one object
    is rewritten once, after all of its objects have been updated.
    """
    k = 0
    for xmlFile in os.listdir(xml_dir):
        k += 1
        print(k)
        print(xmlFile)
        xml_path = os.path.join(xml_dir, xmlFile)
        # Test the full path: a bare file name would be resolved against the
        # current working directory instead of xml_dir.
        if os.path.isdir(xml_path):
            continue
        dom = xml.dom.minidom.parse(xml_path)
        if rename_objects(dom, new_name):
            with open(xml_path, 'w') as fh:
                dom.writexml(fh)


if __name__ == "__main__":
    relabel_dir(path)

xml可视化:

from xml.dom.minidom import parse
import matplotlib.pyplot as plt
import xml.dom.minidom
import os,shutil
import matplotlib  
import numpy as np
import cv2
from PIL import Image, ImageDraw
##########################################################
# Dataset root directory -- normally the only value that has to be edited.
root = "/data_1/SSD/caffe/data/VOCdevkit/mydataset/"

# Annotation and image folders below the root.  Layouts used at other times
# were '2/' + '1/' and 'xml/' + 'img/'.
annroot = "%sAnnotations/" % root
picroot = "%sJPEGImages/" % root
anns = os.listdir(annroot)
imgs = os.listdir(picroot)

# Class names that are cropped and drawn.
labelmap = ["cpls"]

# Outline colour per labelmap index: the ten base colours twice, then "red".
_base_colors = ["red", "green", "blue", "yellow", "pink",
                "olive", "deeppink", "darkorange", "purple", "cyan"]
colormap = _base_colors * 2 + _base_colors[:1]

def mkdir(path):
    """Create directory *path*, including parents, unless the path already exists."""
    if os.path.exists(path):
        return
    os.makedirs(path)

# Offsets at which each outline is redrawn so that the box looks thicker.
_OUTLINE_OFFSETS = (0, -1, 1, -2, 2, -3, 3)


def parse_coord(text):
    """Return the integer part of a coordinate string such as "12" or "12.7"."""
    return int(text.split('.', 1)[0])


def object_label(object_):
    """Return the text of the object's first <name>, cut at the first '.'."""
    a = object_.getElementsByTagName("name")[0].childNodes[0].nodeValue
    return str(a.split('.', 1)[0])


def read_box(object_):
    """Return (xmin, ymin, xmax, ymax) of the object's <bndbox> as ints.

    When several <bndbox> elements exist the last one is used; None is
    returned when the object has none.
    """
    bndboxs = object_.getElementsByTagName("bndbox")
    if not bndboxs:
        return None
    bndbox = bndboxs[-1]
    return tuple(
        parse_coord(bndbox.getElementsByTagName(tag)[0].childNodes[0].nodeValue)
        for tag in ('xmin', 'ymin', 'xmax', 'ymax'))


def clip_box(box, width, height):
    """Clamp *box* so that it lies inside a width x height image."""
    xmin, ymin, xmax, ymax = box
    return (max(xmin, 0), max(ymin, 0), min(xmax, width), min(ymax, height))


def main():
    """Crop and outline every labelled object of every annotation in anns.

    For each annotation the matching image is loaded twice (PIL for drawing,
    OpenCV for cropping).  Every object whose label is in labelmap is cropped
    to ``root/chcc/<label>_/`` and outlined on the image, which is then saved
    to ``root/check/<labels found>/``.  Objects with an unknown label are
    reported on stdout.
    """
    number = 0
    nn = 0
    for ann in anns:
        number += 1
        print(number)
        print(ann)
        # Swap only the extension; "xml" inside the base name must survive.
        picname = os.path.splitext(ann)[0] + ".jpg"
        annpath = annroot + ann
        picpath = picroot + picname
        im = Image.open(picpath)
        img = cv2.imread(picpath)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            print(ann + " : image could not be read by OpenCV, skipped")
            continue
        height, width = img.shape[:2]
        draw = ImageDraw.Draw(im)
        collection = xml.dom.minidom.parse(annpath).documentElement

        found = []
        for object_ in collection.getElementsByTagName("object"):
            b = object_label(object_)
            if b not in labelmap:
                print(ann+"======"+b+"============================")
                continue
            box = read_box(object_)
            if box is None:
                print(ann + " : object '" + b + "' has no bndbox, skipped")
                continue
            xmin, ymin, xmax, ymax = clip_box(box, width, height)
            if xmax <= xmin or ymax <= ymin:
                print(ann + " : object '" + b + "' has an empty box, skipped")
                continue

            print(b)
            nn += 1
            if b not in found:
                found.append(b)

            save_op = root+'chcc/'+b+"_"+"/"
            mkdir(save_op)
            cv2.imwrite(save_op+str(nn)+"_"+picname, img[ymin:ymax, xmin:xmax])

            color = colormap[labelmap.index(b) % len(colormap)]
            for shift in _OUTLINE_OFFSETS:
                draw.rectangle((xmin+shift, ymin+shift, xmax+shift, ymax+shift),
                               outline=color)

        # The output folder is named after the labels found, in order of
        # first appearance, each followed by "_".
        labelsss = "".join(label + "_" for label in found)
        save_p = root+'check/'+labelsss+"/"
        mkdir(save_p)
        im.save(save_p + picname)


if __name__ == "__main__":
    main()

根据全部数据集Annotations，按文件编号的个位数划分并生成训练验证集(trainval)、验证集(val)和测试集(test)的txt:

make_txt.sh

# Abort when the directory is missing so that find never scans (and appends
# the contents of) whatever directory the script happened to start in.
cd /Annotations || exit 1

# Append the sorted list of everything under Annotations to all.txt.
# The first sorted line is the directory entry "./" itself.
find ./ -name "*" |sort >>../ImageSets/Main/all.txt

cd ../ImageSets/Main || exit 1

# Split all.txt into the test / val / trainval lists.
python2 make_txt.py

make_txt.py

import os

# Folder that receives test.txt, val.txt and trainval.txt.
OUTPUT_DIR = '/data_2/data/train_data/ss_big_obj_object_bdf/ssd_big_obj/ImageSets/Main/'

SUBSETS = ('test', 'val', 'trainval')


def subset_for(name):
    """Return the subset for an entry named like ``mhq_181105_12``.

    The third '_'-separated field is read as a number; a last digit of
    0 or 1 selects 'test', 2 selects 'val', anything else 'trainval'.
    Raises IndexError or ValueError when the name has no numeric third field.
    """
    number = int(name.split('_', 3)[2])
    aa = number % 10
    if aa in (0, 1):
        return 'test'
    if aa == 2:
        return 'val'
    return 'trainval'


def entry_name(line):
    """Strip surrounding whitespace, a leading "./" and a trailing ".xml"."""
    name = line.strip()
    if name.startswith("./"):
        name = name[2:]
    if name.endswith(".xml"):
        name = name[:-4]
    return name


def split_entries(lines):
    """Group the entries of *lines* by subset.

    Returns ``{'test': [...], 'val': [...], 'trainval': [...]}`` with the
    cleaned names in input order.  Lines that are empty after cleaning (such
    as the "./" directory entry that ``find`` lists first) are ignored, and
    names without a numeric index are reported on stdout and skipped.
    """
    groups = dict((subset, []) for subset in SUBSETS)
    for line in lines:
        name = entry_name(line)
        if not name:
            continue
        try:
            subset = subset_for(name)
        except (IndexError, ValueError):
            print("skipped entry without numeric index: " + name)
            continue
        groups[subset].append(name)
    return groups


def main():
    """Read ./all.txt and append each name to its subset list in OUTPUT_DIR."""
    with open("./all.txt") as source:
        groups = split_entries(source)
    for subset in SUBSETS:
        names = groups[subset]
        if not names:
            continue
        # Append mode, as before: existing lists are extended, not replaced.
        with open(OUTPUT_DIR + subset + '.txt', 'a') as f:
            f.writelines(name + '\n' for name in names)


if __name__ == "__main__":
    main()