Skip to content
This repository has been archived by the owner on Mar 1, 2024. It is now read-only.

Commit

Permalink
Merge pull request #10 from fairinternal/evaluation
Browse files Browse the repository at this point in the history
Evaluation
  • Loading branch information
fabiopetroni authored Apr 6, 2020
2 parents 603ac44 + 11092eb commit 9349a67
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 17 deletions.
91 changes: 81 additions & 10 deletions kilt/eval.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import argparse
import json

import jsonlines

import kilt.evaluation_metrics as metrics
from collections import Counter
import string
import re


def exact_match(guess_dataset, gold_dataset):
Expand Down Expand Up @@ -30,22 +36,85 @@ def average_prec_at_1(guess_dataset, gold_dataset):
total_prec_at_1 = 0.0
for guess, gold in zip(guess_dataset, gold_dataset):
predicted_page_ids = [provenance["wikipedia_id"]
for provenance in guess["output"][0]["provenance"]]
for provenance in guess["output"][0].get("provenance",[])]
total_prec_at_1 += metrics.precision_at_1(gold, predicted_page_ids)
count += 1
return total_prec_at_1 / count


def load_records(file_name):
    """Open *file_name* and parse its JSON-lines content into a list of records."""
    with open(file_name) as input_file:
        parsed = load_records_from_file(input_file)
    return parsed
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Return the best ``metric_fn(prediction, ground_truth)`` over all ground truths.

    Args:
        metric_fn: callable scoring one (prediction, ground_truth) pair.
        prediction: the system's predicted answer string.
        ground_truths: iterable of acceptable gold answer strings.

    Returns:
        The maximum score across all ground truths, or 0 when
        ``ground_truths`` is empty (the original ``max()`` would raise
        ``ValueError`` on an empty sequence).
    """
    scores_for_ground_truths = [
        metric_fn(prediction, ground_truth) for ground_truth in ground_truths
    ]
    # Guard: max() on an empty sequence raises ValueError; treat "no gold
    # answers" as a score of 0 instead of crashing the whole evaluation.
    if not scores_for_ground_truths:
        return 0
    return max(scores_for_ground_truths)


def qa_exact_match(guess_dataset, gold_dataset):
    """Average exact-match score of each first guess answer against its gold answers.

    For every (guess, gold) pair, the first guess answer is compared against
    all acceptable gold answers; the best match counts for that example.
    """
    count = 0
    em_sum = 0
    for guess, gold in zip(guess_dataset, gold_dataset):
        # All acceptable gold answers for this example.
        answers = [entry["answer"] for entry in gold["output"]]
        prediction = guess['output'][0]['answer']
        em_sum += metric_max_over_ground_truths(
            exact_match_score_qa, prediction, answers)
        count += 1
    return em_sum / count


def qa_f1(guess_dataset, gold_dataset):
    """Average token-level F1 of each first guess answer against its gold answers.

    For every (guess, gold) pair, the best F1 over all acceptable gold
    answers counts for that example.
    """
    count = 0
    f1_sum = 0
    for guess, gold in zip(guess_dataset, gold_dataset):
        # All acceptable gold answers for this example.
        answers = [entry["answer"] for entry in gold["output"]]
        prediction = guess['output'][0]['answer']
        f1_sum += metric_max_over_ground_truths(
            f1_score, prediction, answers)
        count += 1
    return f1_sum / count


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
    # Lowercase first so the article regex below matches regardless of case.
    text = s.lower()
    # Drop punctuation characters in a single pass.
    text = ''.join(ch for ch in text if ch not in string.punctuation)
    # Replace standalone articles with a space (word-boundary anchored).
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    # Collapse any run of whitespace into single spaces and strip the ends.
    return ' '.join(text.split())


def f1_score(prediction, ground_truth):
    """Token-level F1 between the normalized prediction and ground truth.

    Returns 0 when the two answers share no tokens after normalization.
    """
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    # Multiset intersection: how many tokens (with multiplicity) overlap.
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def exact_match_score_qa(prediction, ground_truth):
    """True iff prediction equals ground truth after answer normalization."""
    normalized_prediction = normalize_answer(prediction)
    normalized_gold = normalize_answer(ground_truth)
    return normalized_prediction == normalized_gold


def load_records(jsonl_file_name):
    """Load every record from the JSON-lines file at *jsonl_file_name*."""
    with jsonlines.open(jsonl_file_name) as jsonl_reader:
        return load_records_from_jsonl_reader(jsonl_reader)


def load_records_from_jsonl_reader(reader):
    """Collect all records yielded by an open jsonlines reader into a list."""
    return list(reader)

Expand All @@ -55,7 +124,9 @@ def calculate_metrics(gold_records, guess_records):
assert gold["id"] == guess["id"], "Items must have same order with same IDs"
return {
"em": exact_match(guess_records, gold_records),
"prec_at_1": average_prec_at_1(guess_records, gold_records)
"prec_at_1": average_prec_at_1(guess_records, gold_records),
"em_qa": qa_exact_match(guess_records, gold_records),
"f1_qa": qa_f1(guess_records, gold_records),
}


Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ spacy
bs4
nltk
pymongo
pytest
pytest
jsonlines
3 changes: 2 additions & 1 deletion tests/test_data/gold_kilt.out
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
{"id": 1, "input": "In one word, how did Chen Yi-hsiung die?", "meta": {"wikidata_relation": "manner of death", "question_template": "In one word, how did XXX die?"}, "output": [{"answer": "suicide", "provenance": [{"wikipedia_id": "10287141", "title": "Chen Yi-hsiung", "start_paragraph_id": 5, "start_character": 190, "end_paragraph_id": 5, "end_character": 197, "bleu_score": 1.0, "meta": {}}]}]}
{"id": 2, "input": "In one word, how did Chen Yi-hsiung die?", "meta": {"wikidata_relation": "manner of death", "question_template": "In one word, how did XXX die?"}, "output": [{"answer": "suicide", "provenance": [{"wikipedia_id": "10287141", "title": "Chen Yi-hsiung", "start_paragraph_id": 5, "start_character": 190, "end_paragraph_id": 5, "end_character": 197, "bleu_score": 1.0, "meta": {}}]}]}
{"id": 2, "input": "In one word, how did Chen Yi-hsiung die?", "meta": {"wikidata_relation": "manner of death", "question_template": "In one word, how did XXX die?"}, "output": [{"answer": "suicide", "provenance": [{"wikipedia_id": "10287141", "title": "Chen Yi-hsiung", "start_paragraph_id": 5, "start_character": 190, "end_paragraph_id": 5, "end_character": 197, "bleu_score": 1.0, "meta": {}}]}]}
{"id": 3, "input": "In one word, how did Chen Yi-hsiung die?", "meta": {"wikidata_relation": "manner of death", "question_template": "In one word, how did XXX die?"}, "output": [{"answer": "it was suicide", "provenance": [{"wikipedia_id": "10287141", "title": "Chen Yi-hsiung", "start_paragraph_id": 5, "start_character": 190, "end_paragraph_id": 5, "end_character": 197, "bleu_score": 1.0, "meta": {}}]}]}
3 changes: 2 additions & 1 deletion tests/test_data/guess_kilt.out
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
{"id": 1, "input": "In one word, how did Chen Yi-hsiung die?", "meta": {"wikidata_relation": "manner of death", "question_template": "In one word, how did XXX die?"}, "output": [{"answer": "suicide", "provenance": [{"wikipedia_id": "10287141", "title": "Chen Yi-hsiung", "start_paragraph_id": 5, "start_character": 190, "end_paragraph_id": 5, "end_character": 197, "bleu_score": 1.0, "meta": {}}]}]}
{"id": 2, "input": "In one word, how did Chen Yi-hsiung die?", "meta": {"wikidata_relation": "manner of death", "question_template": "In one word, how did XXX die?"}, "output": [{"answer": "assasination", "provenance": [{"wikipedia_id": "999", "title": "Chen Yi-hsiung", "start_paragraph_id": 5, "start_character": 190, "end_paragraph_id": 5, "end_character": 197, "bleu_score": 1.0, "meta": {}}]}]}
{"id": 2, "input": "In one word, how did Chen Yi-hsiung die?", "meta": {"wikidata_relation": "manner of death", "question_template": "In one word, how did XXX die?"}, "output": [{"answer": "a suicide", "provenance": [{"wikipedia_id": "999", "title": "Chen Yi-hsiung", "start_paragraph_id": 5, "start_character": 190, "end_paragraph_id": 5, "end_character": 197, "bleu_score": 1.0, "meta": {}}]}]}
{"id": 3, "input": "In one word, how did Chen Yi-hsiung die?", "meta": {"wikidata_relation": "manner of death", "question_template": "In one word, how did XXX die?"}, "output": [{"answer": "suicide", "provenance": [{"wikipedia_id": "999", "title": "Chen Yi-hsiung", "start_paragraph_id": 5, "start_character": 190, "end_paragraph_id": 5, "end_character": 197, "bleu_score": 1.0, "meta": {}}]}]}
10 changes: 6 additions & 4 deletions tests/test_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
def test_calculate_metrics():
with importlib.resources.open_text(test_data, "gold_kilt.out") as gold_file:
with importlib.resources.open_text(test_data, "guess_kilt.out") as guess_file:
gold_records = kilt.eval.load_records_from_file(gold_file)
guess_records = kilt.eval.load_records_from_file(guess_file)
gold_records = kilt.eval.load_records(gold_file.name)
guess_records = kilt.eval.load_records(guess_file.name)
result = kilt.eval.calculate_metrics(gold_records, guess_records)
assert result["em"] == 0.5
assert result["prec_at_1"] == 0.5
assert result["em"] == 1/3
assert result["prec_at_1"] == 1/3
assert result["em_qa"] == 2/3
assert result["f1_qa"] == 0.8333333333333334

0 comments on commit 9349a67

Please sign in to comment.