""" Official evaluation script for v1.0 of the ComplexWebQuestions dataset. """
import argparse
import json
import unicodedata
import re
import pandas as pd
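
# Expected input formats, as inferred from the code below (a sketch, not an
# official specification):
#
#   dataset_file:    a JSON list of examples, each with at least
#                    {"ID": ..., "question": ...,
#                     "answers": [{"answer": ..., "aliases": [...]}, ...]}
#   prediction_file: a JSON list of predictions, each {"ID": ..., "answer": ...}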


def compare_span_to_answer(spans, answers, question, question_annotated=None):
    """ Compares answers to spans; multiple matches are possible. """
    if len(spans) == 0:
        # Return an empty frame so callers can treat the result uniformly.
        return pd.DataFrame(columns=['span', 'answer', 'span_index'])
    found_rows = []
    spans_series = pd.Series(spans)
    pre_proc_answers = []
    answers = [answer.lower().strip() for answer in answers]
    for answer in answers:
        # Strip accents / non-ASCII characters.
        proc_answer = unicodedata.normalize('NFKD', answer).encode('ascii', 'ignore').decode(encoding='UTF-8')
        # Replace non-word characters with spaces (handles common endings such as "f.c.").
        proc_answer = re.sub(r'\W', ' ', proc_answer).lower().strip()
        # Remove a leading "the", "a", or "an", as proposed by the SQuAD answer comparison.
        if proc_answer.startswith('the '):
            proc_answer = proc_answer[4:]
        if proc_answer.startswith('a '):
            proc_answer = proc_answer[2:]
        if proc_answer.startswith('an '):
            proc_answer = proc_answer[3:]
        pre_proc_answers.append(proc_answer)
    question = question.lower().strip()
    # processing question:
    # question_annotated = pd.DataFrame(question_annotated)
    # Exact match between a span and either the raw or the pre-processed answer.
    for pre_proc_answer, answer in zip(pre_proc_answers, answers):
        if answer in spans:
            exact_match_ind = spans.index(answer)
            found_rows.append({'span_index': exact_match_ind, 'answer': answer, 'span': answer})
        if pre_proc_answer in spans:
            exact_match_ind = spans.index(pre_proc_answer)
            found_rows.append({'span_index': exact_match_ind, 'answer': answer, 'span': pre_proc_answer})
        # For "year" questions, a four-digit year in the answer should match a year span.
        if question.find('year') > -1:
            year_in_answer = re.search('([1-2][0-9]{3})', answer)
            if year_in_answer is not None:
                year_in_answer = year_in_answer.group(0)
                year_spans = spans_series[spans_series == year_in_answer]
                if len(year_spans) > 0:
                    found_rows.append({'span_index': year_spans.index[0], 'answer': answer,
                                       'span': year_in_answer})
    # DataFrame.append was removed in pandas 2.0, so rows are collected in a list
    # and the frame is built once at the end.
    found_answers = pd.DataFrame(found_rows, columns=['span', 'answer', 'span_index'])
    return found_answers.drop_duplicates()
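
# A minimal usage sketch (hypothetical values, not part of the official script):
#
#     matches = compare_span_to_answer(
#         spans=['barack obama'], answers=['Barack Obama'],
#         question='who was the president of the US?')
#     # -> one row after drop_duplicates():
#     #    span='barack obama', answer='barack obama', span_index=0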


def compute_P1(matched_answers, golden_answer_list, pred_answer):
    """ Precision@1: 100 if the predicted span matched any gold answer, 0 otherwise.
    golden_answer_list and pred_answer are unused but kept for interface stability. """
    P1 = 0
    if len(matched_answers) > 0:
        P1 = 100
    return P1


def evaluate(dataset_df, predictions):
    # Please predict the full file: the score is averaged over all dataset examples.
    if len(dataset_df) != len(predictions):
        print('predictions file does not match dataset file number of examples!!!')
    P1 = 0
    for prediction in predictions:
        # Gather the gold answer and all of its aliases for this example.
        golden_answer_list = []
        for answer in dataset_df.loc[prediction['ID'], 'answers']:
            golden_answer_list.append(answer['answer'])
            golden_answer_list += answer['aliases']
        if None not in golden_answer_list:
            matched_answers = compare_span_to_answer([prediction['answer']], golden_answer_list,
                                                     dataset_df.loc[prediction['ID'], 'question'])
            curr_P1 = compute_P1(matched_answers, golden_answer_list, prediction['answer'])
            P1 += curr_P1
    return P1 / len(dataset_df)
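
# Worked example (hypothetical numbers): for a 3-example dataset where 2 of the
# predictions match a gold answer, evaluate returns (100 + 100 + 0) / 3 = 66.67,
# i.e. P@1 expressed as a percentage.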


if __name__ == '__main__':
    # expected_version = '1.0'
    parser = argparse.ArgumentParser(
        description='Evaluation for ComplexWebQuestions')
    parser.add_argument('dataset_file', help='Dataset file')
    parser.add_argument('prediction_file', help='Prediction File')
    args = parser.parse_args()
    with open(args.dataset_file) as dataset_file:
        dataset_df = pd.DataFrame(json.load(dataset_file)).set_index('ID')
    with open(args.prediction_file) as prediction_file:
        predictions = json.load(prediction_file)
    print(json.dumps(evaluate(dataset_df, predictions)))
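
# Example invocation (file names are placeholders, not shipped with the script):
#     python eval_script.py ComplexWebQuestions_dev.json predictions.json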