-
Notifications
You must be signed in to change notification settings - Fork 0
/
kcf_tracking.py
150 lines (130 loc) · 5.03 KB
/
kcf_tracking.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import cv2
class MessageItem(object):
    """Simple container pairing an image with an accompanying message."""

    def __init__(self, frame, message):
        # Both payloads are kept verbatim; the getters hand them back untouched.
        self._frame, self._message = frame, message

    def getFrame(self):
        """Return the stored image object exactly as it was passed in."""
        stored_frame = self._frame
        return stored_frame

    def getMessage(self):
        """Return the stored message (the original author describes it as JSON-style text info)."""
        stored_message = self._message
        return stored_message
class Tracker(object):
    """Tracker module: follows one selected target using an OpenCV tracker.

    Call ``initWorking`` once with a frame and a bounding box, then call
    ``track`` on each subsequent frame.
    """

    # Public tracker_type name -> OpenCV factory prefix ("<prefix>_create").
    _FACTORY_NAMES = {
        'BOOSTING': 'TrackerBoosting',
        'MIL': 'TrackerMIL',
        'KCF': 'TrackerKCF',
        'TLD': 'TrackerTLD',
        'MEDIANFLOW': 'TrackerMedianFlow',
        'GOTURN': 'TrackerGOTURN',
        'MOSSE': 'TrackerMOSSE',
    }

    def __init__(self, tracker_type="BOOSTING", draw_coord=True):
        """Create the underlying OpenCV tracker.

        Args:
            tracker_type: one of ``self.tracker_types``. An unknown name, or one
                whose factory is missing from the installed OpenCV build, leaves
                ``self.tracker`` as ``None`` so that ``initWorking`` raises.
            draw_coord: when true, ``track`` draws the tracked box and its
                centre onto the frame it receives.
        """
        # Only the major version matters; parsing just that avoids failing on
        # version strings that do not split into exactly three parts.
        major_ver = int(cv2.__version__.split('.')[0])
        self.tracker_types = ['BOOSTING', 'MIL', 'KCF', 'TLD', 'MEDIANFLOW', 'GOTURN', 'MOSSE']
        self.tracker_type = tracker_type
        self.isWorking = False
        self.draw_coord = draw_coord
        self.coord = None
        self.tracker = None
        if major_ver < 3:
            self.tracker = cv2.Tracker_create(tracker_type)
        else:
            self.tracker = self._create_tracker(tracker_type)

    @classmethod
    def _create_tracker(cls, tracker_type):
        """Return a new OpenCV tracker for *tracker_type*, or None if unavailable.

        Newer OpenCV releases (4.5.1+) moved several of these factories into
        ``cv2.legacy``, so the main namespace is tried first, then ``cv2.legacy``.
        """
        factory_name = cls._FACTORY_NAMES.get(tracker_type)
        if factory_name is None:
            return None
        for namespace in (cv2, getattr(cv2, 'legacy', None)):
            # getattr(None, ..., None) is None, so a missing cv2.legacy is harmless.
            factory = getattr(namespace, factory_name + '_create', None)
            if factory is not None:
                return factory()
        return None

    def initWorking(self, frame, box):
        """Initialise tracking.

        Args:
            frame: the frame the target is selected in.
            box: the region to track, passed straight to ``tracker.init``.

        Raises:
            Exception: if no tracker was created, or if the tracker reports
                that initialisation failed.
        """
        if self.tracker is None:
            raise Exception("追踪器未初始化")
        status = self.tracker.init(frame, box)
        # Legacy trackers return a bool; the newer API returns None and signals
        # failure by raising, so only an explicit False counts as failure here.
        if status is not None and not status:
            raise Exception("追踪器工作初始化失败")
        self.coord = box
        self.isWorking = True

    def track(self, frame):
        """Advance the tracker by one frame.

        Returns:
            ``(MessageItem(frame, message), p1, p2)``. On success, ``message``
            is ``{"coord": [(p1, p2)]}`` (plus ``"msg": "is tracking"`` when
            ``draw_coord`` is set), and ``p1`` / ``p2`` are the integer
            top-left / bottom-right corners. When the tracker is not working or
            loses the target, ``message``, ``p1`` and ``p2`` are all ``None``.
        """
        message = None
        p1 = p2 = None
        if self.isWorking:
            status, self.coord = self.tracker.update(frame)
            if status:
                x, y, w, h = self.coord
                p1 = (int(x), int(y))
                p2 = (int(x + w), int(y + h))
                message = {"coord": [(p1, p2)]}
                if self.draw_coord:
                    cv2.rectangle(frame, p1, p2, (255, 0, 0), 2, 1)
                    centre = (int((p2[0] + p1[0]) / 2), int((p2[1] + p1[1]) / 2))
                    cv2.circle(frame, centre, 1, (0, 0, 255), 4, lineType=cv2.LINE_8)
                    message['msg'] = "is tracking"
        return MessageItem(frame, message), p1, p2
if __name__ == '__main__':
    import sys

    # Tracker type used for the first and every later (re)initialisation;
    # change to "MOSSE" to try that tracker instead.
    TRACKER_TYPE = "KCF"

    def _start_tracker(frame, roi):
        """Build a fresh tracker and initialise it on *roi* within *frame*."""
        tracker = Tracker(tracker_type=TRACKER_TYPE)
        tracker.initWorking(frame, roi)
        return tracker

    # Open video capture device 0.
    gVideoDevice = cv2.VideoCapture(0)
    try:
        gCapStatus, gFrame = gVideoDevice.read()
        # Pick the frame on which the target will be selected.
        print("按 n 选择下一帧,按 y 选取当前帧")
        while True:
            # Check the read before displaying: a failed read yields no frame.
            if not gCapStatus:
                print("捕获帧失败")
                sys.exit(1)
            # Show the candidate before waiting so a window exists to receive keys.
            cv2.imshow("pick frame", gFrame)
            _key = cv2.waitKey(0) & 0xFF
            if _key == ord('y'):
                break
            if _key == ord('n'):
                gCapStatus, gFrame = gVideoDevice.read()
        cv2.destroyWindow("pick frame")

        # Select the region of interest. A cancelled selection comes back as
        # (0, 0, 0, 0), which is truthy, so test the width/height instead.
        gROI = cv2.selectROI("ROI frame", gFrame, False)
        if gROI[2] == 0 or gROI[3] == 0:
            print("空框选,退出")
            sys.exit(1)

        gTracker = _start_tracker(gFrame, gROI)

        # Read frames and track until the user quits or capture fails.
        while True:
            gCapStatus, gFrame = gVideoDevice.read()
            if not gCapStatus:
                print("捕获帧失败")
                sys.exit(1)
            _item, _, _ = gTracker.track(gFrame)
            cv2.imshow("track result", _item.getFrame())
            if _item.getMessage():
                print(_item.getMessage())
            else:
                # Target lost: restart from the initial ROI on the current frame.
                print("丢失,重新使用初始ROI开始")
                gTracker = _start_tracker(gFrame, gROI)
            _key = cv2.waitKey(1) & 0xFF
            if _key in (ord('q'), 27):  # 'q' or Esc quits
                break
            if _key == ord('r'):
                # User asked to restart from the initial ROI.
                print("用户请求用初始ROI")
                gTracker = _start_tracker(gFrame, gROI)
    finally:
        # Release the device and close windows on every exit path,
        # including sys.exit().
        gVideoDevice.release()
        cv2.destroyAllWindows()