yolov3(keras-tf)多目标检测与数据标注
程序员文章站
2022-04-30 08:38:56
...
数据准备
如果要识别的目标有现成的数据集,可直接下载;标注格式为每张图片一行:image.jpg x1,y1,x2,y2,class1 x1,y1,x2,y2,class2 ...(多个box之间以空格分隔)
使用voc、coco数据集时,可直接运行yolov3(keras)项目中对应的标注转换脚本,生成上述格式的dataset进行训练。
数据标注
这里采用opencv-python交互及多目标跟踪进行多目标标注和保存,生成yolo直接读取的dataset格式。
适用条件为目标连续出现易跟踪的视频。
使用方法:
输入视频文件后,对每个目标box用鼠标左键依次点击左上、右下两个角点,再按键盘1-9输入该box的类别标签;当前帧所有box标注完成后,按esc开始视频多目标跟踪并保存样本。跟踪过程中按esc可重新标注,按q退出。效果如下图:
代码:
# -*- coding: utf-8
import os
import sys

import cv2
import numpy as np
'''
date: 2020/04/03
author: Jiang
function:
根据跟踪来提取目标正样本
先手动标记左上、右下两点
选择了一个box后按‘1-9’进行标签存取;所有box标记完后按esc进行保存,跟踪。
如若跟踪偏差大,按 'esc' 重新标定;按 'q' 退出
bug:
需要增加一个按键命令进行无目标视频放映
'''
tracker_types = ['BOOSTING', 'MIL', 'KCF', 'TLD', 'MEDIANFLOW', 'MOSSE', 'CSRT']

# OpenCV factory function name for each entry of tracker_types.
# The trailing notes are the author's observations from experiments.
TRACKER_FACTORY_NAMES = {
    'BOOSTING': 'TrackerBoosting_create',      # AdaBoost based; jittery
    'MIL': 'TrackerMIL_create',                # poor results
    'KCF': 'TrackerKCF_create',                # too slow
    'TLD': 'TrackerTLD_create',                # high false-positive rate
    'MEDIANFLOW': 'TrackerMedianFlow_create',  # works well
    'MOSSE': 'TrackerMOSSE_create',            # OK, slower than MEDIANFLOW; box grows for distant targets
    'CSRT': 'TrackerCSRT_create',              # poor results
}


def createTrackerByName(choose_tracker):
    """Create a new single-object OpenCV tracker.

    Args:
        choose_tracker: index into ``tracker_types`` selecting the algorithm.

    Returns:
        A fresh tracker instance, to be passed to ``MultiTracker.add``.

    Raises:
        IndexError: if ``choose_tracker`` is not a valid index.
        AttributeError: if the installed OpenCV build has no factory for the
            selected tracker (e.g. the contrib tracking module is missing).
    """
    factory_name = TRACKER_FACTORY_NAMES[tracker_types[choose_tracker]]
    # OpenCV >= 4.5.1 keeps these trackers (and MultiTracker) in cv2.legacy;
    # prefer that namespace when present, else use the top-level one (older builds).
    legacy = getattr(cv2, 'legacy', None)
    factory = getattr(legacy, factory_name, None) or getattr(cv2, factory_name)
    return factory()
#跟踪,保存样本
def _class_colors(class_num):
    """Return one BGR color per class id 1..class_num (list index = id - 1).

    The blue channel ramps up first, then green, then red as the class id
    grows; pure black and pure white are replaced so boxes stay visible.
    Returns an empty list when class_num <= 0.
    """
    if class_num <= 0:
        return []
    color_step = 255 * 3 // class_num
    colors = []
    for i in range(class_num):
        value = i * color_step
        if value < 255:
            color = (value, 0, 0)
        elif value <= 255 * 2:
            color = (255, value - 255, 0)
        else:
            color = (255, 255, value - 255 * 2)
        if color == (0, 0, 0):
            color = (0, 0, 255)
        elif color == (255, 255, 255):
            color = (255, 0, 0)
        colors.append(color)
    return colors


def _kept_annotation_lines(txt_path, keep_upto):
    """Return the lines of an existing annotation file that should survive.

    Lines start with ``<dir>/<index>.jpg``. Lines whose index is greater than
    ``keep_upto`` describe frames that a new tracking run will overwrite, so
    they are dropped, as are lines that do not follow that naming. Returns an
    empty list when ``keep_upto <= 0`` or the file does not exist.
    """
    if keep_upto <= 0 or not os.path.exists(txt_path):
        return []
    kept = []
    with open(txt_path, 'r') as old_file:
        for line in old_file:
            fields = line.split()
            if not fields:
                continue
            stem = os.path.splitext(os.path.basename(fields[0]))[0]
            try:
                index = int(stem)
            except ValueError:
                continue
            if index <= keep_upto:
                kept.append(line if line.endswith('\n') else line + '\n')
    return kept


def tracking_save(cap, frame, target_boxes, classes, path):
    """Track the labeled boxes through the video and save each tracked frame.

    For every frame read from ``cap`` on which the multi-tracker succeeds, the
    undrawn frame is written to ``<pos_name>/<count_rects>.jpg`` and a line
    ``<img_path> x1,y1,x2,y2,class ...`` is written to ``<pos_name>.txt``.
    The frame with boxes, tracker name and FPS drawn on it is shown in the
    'show' window. Annotation lines written by earlier calls are kept for
    frame numbers up to the current global ``count_rects``; lines for
    higher numbers (which this call will overwrite) are dropped.

    Args:
        cap: capture to keep reading frames from.
        frame: frame on which ``target_boxes`` were selected; used to
            initialize the trackers.
        target_boxes: boxes to track, each ``(x, y, w, h)``.
        classes: class id (starting at 1) of each box, in the same order.
        path: video path, or a single-character camera index.

    Returns:
        1 if ESC was pressed (the caller re-labels the boxes);
        0 if 'q' was pressed or no further frame could be read.
    """
    global count_rects
    path = str(path)
    # A one-character "path" is a camera index; otherwise drop the extension.
    pos_name = path if len(path) == 1 else path.split('.')[0]
    choose_tracker = 6
    os.makedirs(pos_name, exist_ok=True)

    # OpenCV >= 4.5.1 only provides MultiTracker under cv2.legacy.
    legacy = getattr(cv2, 'legacy', None)
    multi_tracker_create = (getattr(legacy, 'MultiTracker_create', None)
                            or cv2.MultiTracker_create)
    multi_tracker = multi_tracker_create()
    for bbox in target_boxes:
        multi_tracker.add(createTrackerByName(choose_tracker), frame, bbox)

    # The per-class palette never changes while tracking: build it once.
    colors = _class_colors(max(classes, default=0))

    txt_path = pos_name + '.txt'
    # Read the surviving lines before 'w' truncates the file.
    kept_lines = _kept_annotation_lines(txt_path, count_rects)
    with open(txt_path, 'w') as label_file:
        label_file.writelines(kept_lines)
        while True:
            ok, frame = cap.read()
            if not ok:
                # Out of frames: report it like 'q' so the caller stops.
                return 0
            save_frame = frame.copy()  # undrawn copy, this is what gets saved
            timer = cv2.getTickCount()
            ok, boxes = multi_tracker.update(frame)
            fps = cv2.getTickFrequency() / (cv2.getTickCount() - timer)
            if ok:
                box_fields = []
                for th_class, newbox in zip(classes, boxes):
                    p1 = (int(newbox[0]), int(newbox[1]))
                    p2 = (int(newbox[0] + newbox[2]), int(newbox[1] + newbox[3]))
                    cv2.rectangle(frame, p1, p2, colors[th_class - 1], 2, 1)
                    box_fields.append(" %s,%s,%s,%s,%s" % (p1[0], p1[1], p2[0], p2[1], th_class))
                count_rects += 1
                img_path = pos_name + '/' + str(count_rects) + '.jpg'
                label_file.write(img_path + "".join(box_fields) + '\n')
                cv2.imwrite(img_path, save_frame)
            else:
                cv2.putText(frame, "Tracking failure detected", (100, 80),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255), 2)
            cv2.putText(frame, tracker_types[choose_tracker] + " Tracker", (100, 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.75, (50, 170, 50), 2)
            cv2.putText(frame, "FPS : " + str(int(fps)), (100, 50),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.75, (50, 170, 50), 2)
            cv2.imshow("show", frame)
            k = cv2.waitKey(10) & 0xff
            if k == 27:  # ESC: stop so the caller can re-label
                return 1
            if k == ord('q'):
                return 0
#鼠标回调函数
def DrawCornerPoint(event, x, y, flags, param):
    """Mouse callback for the 'show' window: record up to two box corners.

    While fewer than two corners have been recorded (``click_n <= 1``), a
    left-button press appends the point to the global ``corner_data``, marks
    it with a filled dot of BGR color (0, 0, 255) on the global ``frame``,
    increments ``click_n``, prints it and refreshes the window. Other events
    and extra clicks are ignored. ``flags`` and ``param`` are unused.
    """
    global click_n, frame, corner_data
    if event != cv2.EVENT_LBUTTONDOWN or click_n > 1:
        return
    point = (x, y)
    corner_data.append(point)
    cv2.circle(frame, point, 2, (0, 0, 255), -1)  # mark the clicked corner
    click_n += 1
    print(click_n)
    cv2.imshow('show', frame)
#手动定位框
def get_bbox(cap):
    """Read the next frame and let the user label target boxes on it.

    In the 'show' window the user left-clicks two opposite corners of a box
    (recorded by the DrawCornerPoint mouse callback into the global
    ``corner_data``), then presses a digit key '1'-'9' to assign the box's
    class. ESC finishes labeling for this frame.

    Args:
        cap: video capture to read the frame from.

    Returns:
        ``(img0, boxes, classes)``: an undrawn copy of the frame, the labeled
        boxes as ``(x, y, w, h)`` tuples, and the class id of each box.

    Exits the program (``sys.exit``) if ``cap`` is not opened or no frame
    can be read.
    """
    global click_n, frame, corner_data
    if not cap.isOpened():
        print("Could not open video")
        sys.exit()
    ok, frame = cap.read()
    if not ok:
        print('Cannot read video file')
        sys.exit()
    # Clean copy: the mouse callback draws corner dots onto the global frame.
    img0 = frame.copy()
    cv2.imshow('show', frame)
    cv2.setMouseCallback("show", DrawCornerPoint)
    boxes = []
    classes = []
    while True:
        # Keep only the low byte; -1 ("no key") becomes 255 and matches nothing.
        clicked_k = cv2.waitKey(20) & 0xFF
        if ord('1') <= clicked_k <= ord('9'):
            if len(corner_data) < 2:
                print('Click both box corners before pressing a class key')
                continue
            (xa, ya), (xb, yb) = corner_data[0], corner_data[1]
            classes.append(clicked_k - ord('0'))  # class ids start at 1
            print(classes)
            # Accept the corners in either order: normalize to (x, y, w, h).
            box = (min(xa, xb), min(ya, yb), abs(xb - xa), abs(yb - ya))
            boxes.append(box)
            click_n = 0
            corner_data = []  # start collecting corners for the next box
            print(clicked_k)
        elif clicked_k == 27:
            break
    print(boxes, classes)
    return img0, boxes, classes
#test 样本与txt是否对应
def test_samples(pathtxt):
    """Check that the images listed in an annotation file exist and preview them.

    Prints the number of lines in ``pathtxt``. For each non-blank line, the
    first whitespace-separated field is the image path: if it is missing or
    unreadable, the line index and line are printed; otherwise the image is
    shown in the 'show' window for about 100 ms.
    """
    with open(pathtxt, 'r') as label_file:
        lines = label_file.readlines()
    print(len(lines))
    for i, line in enumerate(lines):
        fields = line.split()
        if not fields:
            continue  # blank line, nothing to check
        img_path = fields[0]
        img = cv2.imread(img_path) if os.path.exists(img_path) else None
        if img is None:
            print(i, line.rstrip())
            continue
        cv2.imshow('show', img)
        cv2.waitKey(100)
    cv2.destroyAllWindows()
if __name__ == "__main__":
    # YOLO row format: image_file_path box1 box2 ... boxN
    # Box format: x_min,y_min,x_max,y_max,class_id
    path = '104.mp4'
    cap = cv2.VideoCapture(path)
    frame = 0          # global: frame being labeled (the mouse callback draws on it)
    count_rects = 0    # global: number of the most recently saved frame
    click_n = 0        # global: corner clicks recorded for the current box
    corner_data = []   # global: corner points of the current box
    try:
        img0, boxes, classes = get_bbox(cap)
        key = tracking_save(cap, img0, boxes, classes, path)
        # Only ESC (1) means "re-label"; any other result (0 for 'q', or a
        # missing/None result at end of video) ends the session.
        while key == 1:
            # Re-use (overwrite) the two most recently saved frame numbers,
            # never going below 0.
            count_rects = max(count_rects - 2, 0)
            click_n = 0
            corner_data = []
            img0, boxes, classes = get_bbox(cap)
            key = tracking_save(cap, img0, boxes, classes, path)
    finally:
        cap.release()
        cv2.destroyAllWindows()
    test_samples(path.split('.')[0] + '.txt')