-
Notifications
You must be signed in to change notification settings - Fork 19
/
eval_metrics.py
executable file
·155 lines (128 loc) · 5.61 KB
/
eval_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from __future__ import print_function, absolute_import
import numpy as np
import copy
from collections import defaultdict
import sys
def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, N=100):
    """Evaluate a distance matrix with the CUHK03 metric.

    Key: one image for each gallery identity is randomly sampled for each
    query identity. Random sampling is performed N times (default: N=100)
    and the CMC curve / average precision are averaged over the repeats.

    Args:
        distmat: 2-D array of shape (num_query, num_gallery); smaller values
            are ranked first.
        q_pids: 1-D numpy array of query identity labels.
        g_pids: 1-D numpy array of gallery identity labels.
        q_camids: 1-D numpy array of query camera ids.
        g_camids: 1-D numpy array of gallery camera ids.
        max_rank: number of ranks to report; clipped to the gallery size.
        N: number of random sampling repeats per query.

    Returns:
        Tuple ``(all_cmc, mAP)``: the CMC curve averaged over valid queries
        (float array) and the mean of the per-query average precisions.

    Raises:
        AssertionError: if no query identity appears in the gallery after
            same-identity/same-camera gallery samples are removed.

    Note:
        Sampling uses the global ``np.random`` state; seed it beforehand for
        reproducible results.
    """
    num_q, num_g = distmat.shape
    if num_g < max_rank:
        max_rank = num_g
        print("Note: number of gallery samples is quite small, got {}".format(num_g))
    indices = np.argsort(distmat, axis=1)
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

    # compute cmc curve for each query
    all_cmc = []
    all_AP = []
    num_valid_q = 0.  # number of valid query
    for q_idx in range(num_q):
        # get query pid and camid
        q_pid = q_pids[q_idx]
        q_camid = q_camids[q_idx]

        # remove gallery samples that have the same pid and camid with query
        order = indices[q_idx]
        remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
        keep = np.invert(remove)

        # binary vector, positions with value 1 are correct matches
        orig_cmc = matches[q_idx][keep]
        if not np.any(orig_cmc):
            # this condition is true when query identity does not appear in gallery
            continue

        # group the ranked positions of the kept gallery samples by identity
        kept_g_pids = g_pids[order][keep]
        g_pids_dict = defaultdict(list)
        for idx, pid in enumerate(kept_g_pids):
            g_pids_dict[pid].append(idx)

        # Every repeat keeps exactly one sample per identity, so the masked
        # vector always has len(g_pids_dict) entries; the rank denominators
        # can therefore be built once per query.
        ranks = np.arange(1, len(g_pids_dict) + 1, dtype=np.float64)

        cmc, AP = 0., 0.
        for _ in range(N):
            # builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24
            mask = np.zeros(len(orig_cmc), dtype=bool)
            for idxs in g_pids_dict.values():
                # randomly sample one image for each gallery person
                mask[np.random.choice(idxs)] = True
            masked_orig_cmc = orig_cmc[mask]
            hits = masked_orig_cmc.cumsum()

            # CMC: 1 from the first correct match onwards
            _cmc = np.minimum(hits, 1)
            cmc += _cmc[:max_rank].astype(np.float32)

            # AP: precision at each rank, counted only at correct matches
            num_rel = masked_orig_cmc.sum()
            AP += ((hits / ranks) * masked_orig_cmc).sum() / num_rel
        cmc /= N
        AP /= N
        all_cmc.append(cmc)
        all_AP.append(AP)
        num_valid_q += 1.

    assert num_valid_q > 0, "Error: all query identities do not appear in gallery"

    all_cmc = np.asarray(all_cmc).astype(np.float32)
    all_cmc = all_cmc.sum(0) / num_valid_q
    mAP = np.mean(all_AP)

    return all_cmc, mAP
def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank,
                    img_names=None, g_a_pids=None, g_a_camids=None):
    """Compute per-query CMC curves and average precision (market1501 style).

    NOTE: unlike the standard market1501 protocol, this variant does NOT
    discard gallery images that share identity and camera with the query;
    ``q_camids``, ``g_camids`` and ``g_a_camids`` are accepted for interface
    compatibility but are not used.

    Args:
        distmat: 2-D array of shape (num_query, num_gallery); smaller values
            are ranked first.
        q_pids: 1-D numpy array of query identity labels.
        g_pids: 1-D numpy array of gallery identity labels (the gallery that
            ``distmat`` was computed against).
        q_camids: unused.
        g_camids: unused.
        max_rank: number of ranks to report; clipped to the gallery size.
        img_names: optional array indexed by gallery position; it is
            reordered by rank but not used further in this function.
        g_a_pids: identity labels used to count how many matches are
            "present" for each query; this count is the AP denominator.
            Defaults to ``g_pids`` when omitted.
        g_a_camids: unused.

    Returns:
        Tuple ``(all_cmc, all_AP, num_valid_q, tot_found, tot_pres)``:
        a list of per-query CMC vectors (length ``max_rank``), a list of
        per-query APs, the number of queries with at least one match in
        ``g_pids`` (float), and the match counts in ``g_pids`` / ``g_a_pids``
        summed over ALL queries (including skipped ones).
    """
    num_q, num_g = distmat.shape
    if num_g < max_rank:
        max_rank = num_g
        print("Note: number of gallery samples is quite small, got {}".format(num_g))
    indices = np.argsort(distmat, axis=1)
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
    print("dist mat:", np.sort(distmat)[0][:5])
    print("all matches:", np.sum(matches))
    if img_names is not None:
        img_names = img_names[indices]

    # Without a separate reference gallery, "present" equals "found"; the
    # original code crashed on len(None) in that case.
    if g_a_pids is None:
        g_a_pids = g_pids
    g_a_pids = np.asarray(g_a_pids)

    # rank denominators 1..num_g are the same for every query
    ranks = np.arange(1, num_g + 1, dtype=np.float64)

    # compute cmc curve for each query
    all_cmc = []
    all_AP = []
    num_valid_q = 0.  # number of valid query
    tot_found = 0
    tot_pres = 0
    for q_idx in range(num_q):
        print("matches", matches[q_idx])
        q_pid = q_pids[q_idx]

        # check frac of matches retained
        num_found = np.sum(g_pids == q_pid)
        num_pres = np.sum(g_a_pids == q_pid)
        print("found:", num_found, "present:", num_pres)
        tot_found += num_found
        tot_pres += num_pres

        # NOTE: same-pid/same-camid gallery filtering is intentionally
        # disabled here, so the full ranked list is scored.

        # binary vector, positions with value 1 are correct matches
        orig_cmc = matches[q_idx]
        if not np.any(orig_cmc):
            # this condition is true when query identity does not appear in gallery
            continue

        hits = orig_cmc.cumsum()
        cmc = np.minimum(hits, 1)
        all_cmc.append(cmc[:max_rank])
        num_valid_q += 1.

        # compute average precision
        # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
        # The denominator is the number of matches present in g_a_pids, not
        # just those found in g_pids.
        precision_at_hits = (hits / ranks) * orig_cmc
        AP = precision_at_hits.sum() / num_pres
        all_AP.append(AP)

    return all_cmc, all_AP, num_valid_q, tot_found, tot_pres
def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50, use_metric_cuhk03=False,
             img_names=None, g_a_pids=None, g_a_camids=None):
    """Dispatch to the CUHK03 or market1501 evaluation routine.

    When ``use_metric_cuhk03`` is truthy, ``eval_cuhk03`` is called and the
    optional ``img_names`` / ``g_a_pids`` / ``g_a_camids`` arguments are
    ignored. Otherwise all arguments are forwarded to ``eval_market1501``.

    The result is whatever the selected routine returns; note that the two
    routines return tuples of different shapes.
    """
    if not use_metric_cuhk03:
        # default path: market1501-style evaluation gets the extra options
        return eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank,
                               img_names=img_names, g_a_pids=g_a_pids,
                               g_a_camids=g_a_camids)
    return eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)