-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
118 lines (107 loc) · 4.87 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import glob
import os
import numpy as np
import shutil
def write_results_no_score(filename, results):
"""Writes results in MOT style to filename."""
save_format = "{frame},{id},{x1},{y1},{w},{h},-1,-1,-1,-1\n"
with open(filename, "w") as f:
for frame_id, tlwhs, track_ids in results:
for tlwh, track_id in zip(tlwhs, track_ids):
if track_id < 0:
continue
x1, y1, w, h = tlwh
line = save_format.format(
frame=frame_id,
id=track_id,
x1=round(x1, 1),
y1=round(y1, 1),
w=round(w, 1),
h=round(h, 1),
)
f.write(line)
def filter_targets(online_targets, aspect_ratio_thresh, min_box_area):
"""Removes targets not meeting threshold criteria.
Returns (list of tlwh, list of ids).
"""
online_tlwhs = []
online_ids = []
for t in online_targets:
tlwh = [t[0], t[1], t[2] - t[0], t[3] - t[1]]
tid = t[4]
vertical = tlwh[2] / tlwh[3] > aspect_ratio_thresh
if tlwh[2] * tlwh[3] > min_box_area and not vertical:
online_tlwhs.append(tlwh)
online_ids.append(tid)
return online_tlwhs, online_ids
def dti(txt_path, save_path, n_min=30, n_dti=20):
def dti_write_results(filename, results):
save_format = "{frame},{id},{x1},{y1},{w},{h},{s},-1,-1,-1\n"
with open(filename, "w") as f:
for i in range(results.shape[0]):
frame_data = results[i]
frame_id = int(frame_data[0])
track_id = int(frame_data[1])
x1, y1, w, h = frame_data[2:6]
line = save_format.format(frame=frame_id, id=track_id, x1=x1, y1=y1, w=w, h=h, s=-1)
f.write(line)
seq_txts = sorted(glob.glob(os.path.join(txt_path, "*.txt")))
# breakpoint()
for seq_txt in seq_txts:
seq_name = seq_txt.replace("\\", "/").split("/")[-1] ## To better play along with windows paths
seq_data = np.loadtxt(seq_txt, dtype=np.float64, delimiter=",")
min_id = int(np.min(seq_data[:, 1]))
max_id = int(np.max(seq_data[:, 1]))
seq_results = np.zeros((1, 10), dtype=np.float64)
for track_id in range(min_id, max_id + 1):
index = seq_data[:, 1] == track_id
tracklet = seq_data[index]
tracklet_dti = tracklet
if tracklet.shape[0] == 0:
continue
n_frame = tracklet.shape[0]
n_conf = np.sum(tracklet[:, 6] > 0.5)
if n_frame > n_min:
frames = tracklet[:, 0]
frames_dti = {}
for i in range(0, n_frame):
right_frame = frames[i]
if i > 0:
left_frame = frames[i - 1]
else:
left_frame = frames[i]
# disconnected track interpolation
if 1 < right_frame - left_frame < n_dti:
num_bi = int(right_frame - left_frame - 1)
right_bbox = tracklet[i, 2:6]
left_bbox = tracklet[i - 1, 2:6]
for j in range(1, num_bi + 1):
curr_frame = j + left_frame
curr_bbox = (curr_frame - left_frame) * (right_bbox - left_bbox) / (
right_frame - left_frame
) + left_bbox
frames_dti[curr_frame] = curr_bbox
num_dti = len(frames_dti.keys())
if num_dti > 0:
data_dti = np.zeros((num_dti, 10), dtype=np.float64)
for n in range(num_dti):
data_dti[n, 0] = list(frames_dti.keys())[n]
data_dti[n, 1] = track_id
data_dti[n, 2:6] = frames_dti[list(frames_dti.keys())[n]]
data_dti[n, 6:] = [1, -1, -1, -1]
tracklet_dti = np.vstack((tracklet, data_dti))
seq_results = np.vstack((seq_results, tracklet_dti))
save_seq_txt = os.path.join(save_path, seq_name)
seq_results = seq_results[1:]
seq_results = seq_results[seq_results[:, 0].argsort()]
dti_write_results(save_seq_txt, seq_results)
if __name__ == "__main__":
post_folder = "results/trackers/MOT17-val/1122_final_test_post"
pre_folder = "results/trackers/MOT17-val/1122_final_test"
if os.path.exists(post_folder):
print(f"Overwriting previous results in {post_folder}")
shutil.rmtree(post_folder)
shutil.copytree(pre_folder, post_folder)
post_folder_data = os.path.join(post_folder, "data")
dti(post_folder_data, post_folder_data)
print(f"Linear interpolation post-processing applied, saved to {post_folder_data}.")