-
Notifications
You must be signed in to change notification settings - Fork 13
/
metrics.py
160 lines (134 loc) · 3.97 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import operator
import math
def is_valid_query(v):
    """Return True if the query has at least one positive and one negative record.

    v is a list of (aid, label, score) triples; label > 0 marks a positive
    record, anything else counts as negative. A query with only positives or
    only negatives cannot be ranked meaningfully, so it is "invalid".
    """
    # any() short-circuits, so each scan stops at the first match instead of
    # counting every record as the original did.
    has_positive = any(label > 0 for _aid, label, _score in v)
    has_negative = any(label <= 0 for _aid, label, _score in v)
    return has_positive and has_negative
def get_num_valid_query(results):
    """Count queries in results that is_valid_query accepts.

    results maps a query id to its list of (aid, label, score) records.
    """
    # sum over a generator replaces the manual counter loop.
    return sum(1 for v in results.values() if is_valid_query(v))
def top_1_precision(results):
    """Fraction of valid queries whose highest-scoring record is positive.

    results maps a query id to a list of (aid, label, score) records; only
    queries containing both a positive (label > 0) and a negative record are
    counted. Returns 0.0 when there is no valid query.
    """
    num_query = 0
    num_correct = 0.0
    for v in results.values():
        if not is_valid_query(v):
            continue
        num_query += 1
        # max() is O(n) versus the original's full O(n log n) sort; ties
        # resolve to the earliest record either way (stable sort's first
        # element is also the first maximal element max() returns).
        aid, label, score = max(v, key=operator.itemgetter(2))
        if label > 0:
            num_correct += 1
    return num_correct / num_query if num_query > 0 else 0.0
def mean_reciprocal_rank(results):
    """Mean reciprocal rank of the first positive record over valid queries.

    results maps a query id to a list of (aid, label, score) records. Records
    are ranked by score descending; each valid query contributes 1/rank of its
    first record with label > 0. Returns 0.0 when no valid query exists.
    """
    reciprocal_sum = 0.0
    valid_count = 0
    for records in results.values():
        if not is_valid_query(records):
            continue
        valid_count += 1
        ranked = sorted(records, key=operator.itemgetter(2), reverse=True)
        for rank, (_aid, label, _score) in enumerate(ranked, start=1):
            if label > 0:
                reciprocal_sum += 1.0 / rank
                break
    return reciprocal_sum / valid_count if valid_count else 0.0
def mean_average_precision(results):
    """Mean average precision (MAP) over valid queries.

    results maps a query id to a list of (aid, label, score) records. For each
    valid query the records are ranked by score descending and average
    precision is the mean of precision@i over every rank i holding a positive
    record. Queries lacking a positive or a negative record are skipped.
    Returns 0.0 when no valid query exists.
    """
    num_query = 0
    map_sum = 0.0
    for k, v in results.items():
        if not is_valid_query(v):
            continue
        num_query += 1
        sorted_v = sorted(v, key=operator.itemgetter(2), reverse=True)
        num_relevant_doc = 0.0
        avp = 0.0
        for i, (aid, label, score) in enumerate(sorted_v):
            # BUG FIX: the original tested label == 1, which misses other
            # positive labels (e.g. 2) and then divided avp by a zero
            # num_relevant_doc; use the label > 0 convention shared with
            # is_valid_query, which also guarantees num_relevant_doc >= 1.
            if label > 0:
                num_relevant_doc += 1
                avp += num_relevant_doc / (i + 1)
        map_sum += avp / num_relevant_doc
    if num_query == 0:
        return 0.0
    return map_sum / num_query
def classification_metrics(results):
    """Compute accuracy, precision, recall, F1 and mean log loss.

    results maps a query id to a list of (aid, label, score) records, where
    score is the predicted probability of the positive class and label > 0
    marks a positive example. A record is predicted positive when score > 0.5
    (a score of exactly 0.5 counts toward neither correctness bucket, matching
    the original behavior).

    Returns (accuracy, precision, recall, F1, mean_loss). Every denominator
    is guarded so degenerate inputs (no records, no positives, no predicted
    positives) yield 0.0 instead of raising ZeroDivisionError — the original
    guarded precision but crashed on recall, accuracy and loss.
    """
    total_num = 0
    total_correct = 0
    true_positive = 0
    positive_correct = 0
    predicted_positive = 0
    loss = 0.0
    for records in results.values():
        for aid, label, score in records:
            total_num += 1
            if score > 0.5:
                predicted_positive += 1
            if label > 0:
                true_positive += 1
                # cross-entropy of the positive class; epsilon guards log(0)
                loss += -math.log(score + 1e-12)
            else:
                loss += -math.log(1.0 - score + 1e-12)
            if score > 0.5 and label > 0:
                total_correct += 1
                positive_correct += 1
            if score < 0.5 and label < 0.5:
                total_correct += 1
    if total_num == 0:
        # no records at all: every metric is defined as 0.0
        return 0.0, 0.0, 0.0, 0.0, 0.0
    accuracy = float(total_correct) / total_num
    precision = float(positive_correct) / (predicted_positive + 1e-12)
    # BUG FIX: original divided by true_positive with no epsilon and raised
    # ZeroDivisionError when the input contained no positive labels.
    recall = float(positive_correct) / (true_positive + 1e-12)
    F1 = 2.0 * precision * recall / (1e-12 + precision + recall)
    return accuracy, precision, recall, F1, loss / total_num
def top_k_precision(results, k=1):
    """Fraction of valid queries with a positive record among the top k.

    results maps a query id to a list of (aid, label, score) records. Records
    are ranked by score descending; a query counts as correct when any of its
    first k records has label > 0. Queries lacking a positive or a negative
    record are skipped. Returns 0.0 when no valid query exists.

    Generalized from the original, which only supported k in {1, 2, 5},
    raised a bare BaseException for any other k, and crashed with IndexError
    on k=2 when a query had fewer than two records.

    Raises ValueError if k < 1.
    """
    if k < 1:
        raise ValueError("k must be >= 1, got %d" % k)
    num_query = 0
    num_correct = 0.0
    for key, records in results.items():
        if not is_valid_query(records):
            continue
        num_query += 1
        ranked = sorted(records, key=operator.itemgetter(2), reverse=True)
        # any() over the slice handles every k uniformly, including queries
        # with fewer than k records.
        if any(label > 0 for _aid, label, _score in ranked[:k]):
            num_correct += 1
    if num_query > 0:
        return num_correct / num_query
    else:
        return 0.0