欢迎您访问程序员文章站!本站旨在为大家提供、分享程序员计算机编程知识。
您现在的位置是: 首页

yolov3(keras-tf)多目标检测与数据标注

程序员文章站 2022-04-30 08:38:56
...

YOLOv3多目标检测

数据准备

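如果要识别的目标已有现成的数据集,可直接下载使用。YOLO读取的标注格式为每行一张图片:图片路径后跟若干个以空格分隔的box,每个box为 x_min,y_min,x_max,y_max,class_id,例如:
image.jpg x1,y1,x2,y2,class1 x1,y1,x2,y2,class2
若下载的是VOC、COCO数据集,可直接运行所用YOLOv3(keras)项目附带的转换代码生成上述格式的dataset进行训练。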

数据标注

这里采用opencv-python的鼠标/键盘交互,结合多目标跟踪进行标注和保存,直接生成YOLO可读取的dataset格式。
该方法适用于目标连续出现、易于跟踪的视频。
使用方法:
输入视频文件后,在当前帧中对每个目标用鼠标左键依次点击左上、右下两个角点,再按键盘1-9输入该box的类别标签;该帧所有box标注完成后,按esc开始多目标跟踪并逐帧保存样本。跟踪过程中若偏差较大,按esc重新标注;按q退出。效果如下图:
yolov3(keras-tf)多目标检测与数据标注代码:

# -*- coding: utf-8
import os
import sys

import cv2
import numpy as np
'''
	data: 2020/04/03
	author: Jiang
	function:
		根据跟踪来提取目标正样本
		先手动标记左上、右下两点
		选择了一个box后按‘1-9’进行标签存取;所有box标记完后按esc进行保存,跟踪。
		如若偏差大,按 'esc' 标定;按 ‘q' 退出
	bug:
		需要增加一个按键命令进行无目标视频放映
'''

tracker_types = ['BOOSTING', 'MIL', 'KCF', 'TLD', 'MEDIANFLOW', 'MOSSE', 'CSRT']

# Name of the OpenCV factory function for each entry of ``tracker_types``,
# together with the original author's observations for this labelling task.
_TRACKER_FACTORIES = {
    'BOOSTING': 'TrackerBoosting_create',      # AdaBoost-based; boxes jitter
    'MIL': 'TrackerMIL_create',                # poor results
    'KCF': 'TrackerKCF_create',                # too slow
    'TLD': 'TrackerTLD_create',                # high false-positive rate
    'MEDIANFLOW': 'TrackerMedianFlow_create',  # works well
    'MOSSE': 'TrackerMOSSE_create',            # acceptable; lower FPS than MEDIANFLOW,
                                               # box grows when the target is far away
    'CSRT': 'TrackerCSRT_create',              # poor results
}


def createTrackerByName(choose_tracker):
    """Create a single-object OpenCV tracker.

    Args:
        choose_tracker: index into ``tracker_types`` selecting the algorithm.

    Returns:
        A new tracker from the matching ``cv2.Tracker*_create`` factory.

    Raises:
        IndexError: if ``choose_tracker`` is out of range for ``tracker_types``.
        KeyError: if ``tracker_types`` was extended without a factory entry.
        RuntimeError: if the installed ``cv2`` module has no such factory.
            NOTE(review): this presumably happens on builds without the
            contrib tracking module or on newer releases that moved some
            trackers to ``cv2.legacy`` — confirm for your OpenCV version.
    """
    tracker_type = tracker_types[choose_tracker]
    factory_name = _TRACKER_FACTORIES[tracker_type]
    # Look the factory up lazily so a missing tracker only fails when chosen,
    # and with a message naming the missing function.
    factory = getattr(cv2, factory_name, None)
    if factory is None:
        raise RuntimeError("cv2.%s is not available in this OpenCV build (tracker %s)"
                           % (factory_name, tracker_type))
    return factory()

def _class_colors(class_num):
    """Return one BGR colour per class id 1..class_num (list index = id - 1).

    Colours walk a blue -> cyan -> white ramp; black is replaced by red and
    white by blue so every box stays visible on the frame.
    """
    color_step = 255 * 3 // class_num
    colors = []
    for i in range(class_num):
        value = i * color_step
        if value < 255:
            color = (value, 0, 0)
        elif value <= 255 * 2:
            color = (255, value - 255, 0)
        else:
            color = (255, 255, value - 255 * 2)
        if color == (0, 0, 0):
            color = (0, 0, 255)
        elif color == (255, 255, 255):
            color = (255, 0, 0)
        colors.append(color)
    return colors


def _kept_annotation_lines(txt_path, last_index):
    """Return the lines of ``txt_path`` whose image number is <= ``last_index``.

    Sample images are named ``<n>.jpg``; lines for higher numbers (which the
    current session is about to overwrite) or for unrecognised names are
    dropped. Returns [] when the file does not exist yet.
    """
    if not os.path.exists(txt_path):
        return []
    kept = []
    with open(txt_path) as fin:
        for line in fin:
            fields = line.split()
            if not fields:
                continue
            stem = os.path.splitext(os.path.basename(fields[0]))[0]
            if stem.isdigit() and int(stem) <= last_index:
                kept.append(line if line.endswith('\n') else line + '\n')
    return kept


# Track the labelled targets and save every successfully tracked frame as a sample.
def tracking_save(cap,frame,target_boxes,classes,path):
    """Track hand-labelled boxes through the video and save YOLO samples.

    For every frame in which the MultiTracker succeeds, the unannotated frame
    is written to ``<pos_name>/<count_rects>.jpg`` and one line
    ``<img_path> x1,y1,x2,y2,class ...`` is written to ``<pos_name>.txt``,
    where ``pos_name`` is ``path`` without its extension. The global
    ``count_rects`` is incremented once per saved frame.

    Args:
        cap: an opened ``cv2.VideoCapture`` positioned just after ``frame``.
        frame: the frame on which ``target_boxes`` were marked.
        target_boxes: ``(x, y, w, h)`` boxes, one per target.
        classes: class ids (starting at 1), parallel to ``target_boxes``.
        path: video path (or camera index) used to name the outputs.

    Returns:
        1 if ESC was pressed (caller should re-label), 0 if 'q' was pressed
        or the video has no more frames (caller should stop).
    """
    global count_rects
    path = str(path)
    # Strip only the final extension so paths such as "./clips/104.mp4"
    # still give a usable output name (a camera index like "0" is unchanged).
    pos_name = os.path.splitext(path)[0]
    choose_tracker = 6
    os.makedirs(pos_name, exist_ok=True)

    multiTracker = cv2.MultiTracker_create()
    for bbox in target_boxes:
        multiTracker.add(createTrackerByName(choose_tracker), frame, bbox)

    # Colours depend only on the class list: compute them once, and avoid
    # max() on an empty list when no box was labelled.
    colors = _class_colors(max(classes)) if classes else []

    txt_path = pos_name + '.txt'
    # The main script rewinds count_rects before re-labelling so the last
    # samples get overwritten. Keep every earlier line instead of truncating
    # the whole file, which used to discard all previous sessions' labels.
    kept_lines = _kept_annotation_lines(txt_path, count_rects)
    with open(txt_path, 'w') as f:
        f.writelines(kept_lines)
        while True:
            ok, frame = cap.read()
            # Check before touching the frame: when no frame could be read
            # there is nothing to copy, which previously crashed here.
            if not ok:
                return 0
            save_frame = frame.copy()

            timer = cv2.getTickCount()
            ok, boxes = multiTracker.update(frame)
            fps = cv2.getTickFrequency() / (cv2.getTickCount() - timer)

            if ok:
                box_fields = []
                for th_class, newbox in zip(classes, boxes):
                    color = colors[th_class - 1]
                    p1 = (int(newbox[0]), int(newbox[1]))
                    p2 = (int(newbox[0] + newbox[2]), int(newbox[1] + newbox[3]))
                    cv2.rectangle(frame, p1, p2, color, 2, 1)
                    box_fields.append(" %s,%s,%s,%s,%s" % (p1[0], p1[1], p2[0], p2[1], th_class))

                count_rects += 1
                img_path = pos_name + '/' + str(count_rects) + '.jpg'
                f.write(img_path + ''.join(box_fields) + '\n')
                # Save the copy taken before boxes were drawn on ``frame``.
                cv2.imwrite(img_path, save_frame)
            else:
                cv2.putText(frame, "Tracking failure detected", (100, 80),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255), 2)

            cv2.putText(frame, tracker_types[choose_tracker] + " Tracker", (100, 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.75, (50, 170, 50), 2)
            cv2.putText(frame, "FPS : " + str(int(fps)), (100, 50),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.75, (50, 170, 50), 2)
            cv2.imshow("show", frame)

            k = cv2.waitKey(10) & 0xff
            if k == 27:
                return 1
            if k == ord('q'):
                return 0
                
# Mouse callback: record the (at most two) corner clicks of the box being labelled.
def DrawCornerPoint(event,x,y,flags,param):
    """Record a left click as a box corner, accepting at most two per box.

    On a left-button press while fewer than two corners are recorded, appends
    ``(x, y)`` to the global ``corner_data``, draws a small red dot on the
    global ``frame``, increments the global ``click_n``, prints it and
    refreshes the 'show' window. All other events and extra clicks are
    ignored. ``flags`` and ``param`` are unused but required by the OpenCV
    mouse-callback signature.
    """
    global click_n, frame, corner_data
    # Guard clause: ignore non-left-clicks and clicks after the second corner.
    if event != cv2.EVENT_LBUTTONDOWN or click_n > 1:
        return
    point = (x, y)
    corner_data.append(point)
    cv2.circle(frame, point, 2, (0, 0, 255), -1)
    click_n += 1
    print(click_n)
    cv2.imshow('show', frame)

# Manually mark the target boxes on the next frame of the video.
def get_bbox(cap):
    """Let the user mark target boxes and their class labels on one frame.

    Reads the next frame from ``cap``, shows it in the 'show' window and
    installs ``DrawCornerPoint`` as the mouse callback. For each target the
    user clicks two opposite corners, then presses '1'-'9' to give that box
    a class id. ESC finishes labelling.

    Args:
        cap: a ``cv2.VideoCapture``.

    Returns:
        ``(img0, boxes, classes)``:
        - ``img0``: an unannotated copy of the labelled frame;
        - ``boxes``: ``(x, y, w, h)`` tuples with ``(x, y)`` the top-left corner;
        - ``classes``: the class id (1-9) of each box, in the same order.

    Exits the process if the video cannot be opened or no frame can be read.
    Side effects: sets the global ``frame`` and resets the globals
    ``click_n``/``corner_data`` after every label key press.
    """
    global click_n, frame, corner_data
    if not cap.isOpened():
        print("Could not open video")
        sys.exit()

    ok, frame = cap.read()
    if not ok:
        print('Cannot read video file')
        sys.exit()
    # Copy only after a successful read. The mouse callback draws corner dots
    # on ``frame``, so ``img0`` keeps the clean image for tracking/saving.
    img0 = frame.copy()
    cv2.imshow('show', frame)
    cv2.setMouseCallback("show", DrawCornerPoint)

    boxes = []
    classes = []
    while True:
        # Compare only the low byte of the key code ("no key" becomes 255).
        clicked_k = cv2.waitKey(20) & 0xff
        if clicked_k == 27:
            break
        # Only the digits '1'-'9' are labels (class ids 1-9, i.e. 9 classes).
        if not ord('1') <= clicked_k <= ord('9'):
            continue
        if len(corner_data) < 2:
            print('Click two corners of the box before pressing a label key')
            continue
        (xa, ya), (xb, yb) = corner_data[0], corner_data[1]
        # Accept the corners in either order; trackers need a positive size.
        box = (min(xa, xb), min(ya, yb), abs(xb - xa), abs(yb - ya))
        # Start a fresh box whether or not this one is accepted.
        click_n = 0
        corner_data = []
        if box[2] == 0 or box[3] == 0:
            print('Empty box ignored, please click the two corners again')
            continue
        classes.append(clicked_k - ord('0'))
        boxes.append(box)
        print(classes)
    print(boxes, classes)
    return img0, boxes, classes
        
# Check that the saved sample images match the lines of the annotation file.
def test_samples(pathtxt):
    """Verify that every image listed in ``pathtxt`` exists and preview it.

    Each line of ``pathtxt`` is ``<image_path> box1 box2 ...``. The number of
    lines is printed first. A line whose image is missing or unreadable is
    printed as ``index line``. Readable images are shown in the 'show' window
    for about 100 ms each, and all windows are closed at the end.
    """
    with open(pathtxt, 'r') as f:
        lines = f.readlines()
    print(len(lines))

    for i, line in enumerate(lines):
        # split() also drops the trailing newline, so lines with no boxes
        # yield a clean path instead of "<path>\n".
        fields = line.split()
        if not fields:
            continue
        img_path = fields[0]
        if not os.path.exists(img_path):
            print(i, line)
            continue
        img = cv2.imread(img_path)
        if img is None:
            # File exists but could not be decoded; report instead of crashing imshow.
            print(i, line)
            continue
        cv2.imshow('show', img)
        cv2.waitKey(100)
    cv2.destroyAllWindows()

if __name__ == "__main__":
    # YOLO row format: image_file_path box1 box2 ... boxN
    # Box format:      x_min,y_min,x_max,y_max,class_id
    path = '104.mp4'
    cap = cv2.VideoCapture(path)
    frame = 0          # global: frame being labelled (set by get_bbox)
    count_rects = 0    # global: number of the last saved sample image

    click_n = 0        # global: corner clicks made for the current box
    corner_data = []   # global: clicked corner points of the current box

    img0, boxes, classes = get_bbox(cap)
    key = tracking_save(cap, img0, boxes, classes, path)

    # tracking_save returns 1 only when ESC was pressed: re-label and resume.
    # Anything else ('q', or any other return) ends the session.
    while key == 1:
        # Overwrite the last two saved samples (tracking had drifted),
        # but never rewind below zero.
        count_rects = max(count_rects - 2, 0)
        click_n = 0
        corner_data = []
        img0, boxes, classes = get_bbox(cap)
        key = tracking_save(cap, img0, boxes, classes, path)

    cap.release()
    cv2.destroyAllWindows()
    test_samples(os.path.splitext(path)[0] + '.txt')

keras训练

相关标签: 应用