Skip to content

Commit

Permalink
Creating a Data Structure for the Metric System
Browse files Browse the repository at this point in the history
  • Loading branch information
pedrojlazevedo committed Mar 15, 2020
1 parent d76923a commit 8e262ff
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 70 deletions.
Empty file added data/dev_relevant_docs.jsonl
Empty file.
18 changes: 9 additions & 9 deletions generate_rte_preds.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
instances.append(line)


def createTestSet(claim, candidateEvidences, claim_num):
def create_test_set(claim, candidateEvidences, claim_num):
testset = []
for elem in candidateEvidences:
testset.append({"hypothesis": claim, "premise": elem})
Expand All @@ -39,8 +39,8 @@ def createTestSet(claim, candidateEvidences, claim_num):

def run_rte(claim, evidence, claim_num):
    """Run textual-entailment prediction for one claim against its evidence.

    Parameters:
        claim: the claim text, used as the hypothesis of every pair.
        evidence: list of candidate evidence sentences (the premises).
        claim_num: numeric claim id; kept for interface compatibility
            (the filename previously built from it was never used and
            has been removed).

    Returns:
        The batch predictions produced by the module-level ``predictor``
        for the hypothesis/premise pairs.
    """
    test_set = create_test_set(claim, evidence, claim_num)
    return predictor.predict_batch_json(test_set)


Expand All @@ -50,9 +50,8 @@ def run_rte(claim, evidence, claim_num):
evidence = instances[i]['predicted_sentences']
potential_evidence_sentences = []
for sentence in evidence:
#print(sentence)
#print(sentence[0])

# print(sentence)
# print(sentence[0])
# load document from TF-IDF
relevant_doc = ud.normalize('NFC', sentence[0])
relevant_doc = relevant_doc.replace("/", "-SLH-")
Expand All @@ -74,13 +73,14 @@ def run_rte(claim, evidence, claim_num):
if len(potential_evidence_sentences) == 0:
zero_results += 1
potential_evidence_sentences.append("Nothing")
evidence.append(["Nothing", 0])

preds = run_rte(claim, potential_evidence_sentences, claim_num)

saveFile = codecs.open("rte/entailment_predictions/claim_" + str(claim_num) + ".json", mode="w+", encoding="utf-8")
for i in range(len(preds)):
#print(preds)
#print(evidence)
# print(preds)
# print(evidence)
preds[i]['claim'] = claim
preds[i]['premise_source_doc_id'] = evidence[i][0]
preds[i]['premise_source_doc_line_num'] = evidence[i][1]
Expand All @@ -91,4 +91,4 @@ def run_rte(claim, evidence, claim_num):
claim_num += 1
print(claim_num)

print("Number of Zero Sentences Found: " + str(zero_results))
print("Number of Zero Sentences Found: " + str(zero_results))
86 changes: 28 additions & 58 deletions metrics.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
import jsonlines
import sys
from scorer import fever_score
from metrics.claim import Claim

claims = []

train_file = "data/subsample_train.jsonl"
train_relevant_file = "data/subsample_train_relevant_docs.jsonl"
train_concatenate_file = "data/subsample_train_concatenation.jsonl"
train_predictions_file = "predictions/predictions_train.jsonl"

# loading for dev
train_file = "data/dev.jsonl"
train_relevant_file = "data/dev_relevant_docs.jsonl"
train_concatenate_file = "data/subsample_train_concatenation.jsonl"
train_predictions_file = "predictions/new_predictions_dev.jsonl"

train_file = jsonlines.open(train_file)
train_relevant_file = jsonlines.open(train_relevant_file)
train_concatenate_file = jsonlines.open(train_concatenate_file)
Expand Down Expand Up @@ -35,6 +44,7 @@
# this evidence addition is irrelevant
info_by_id = dict((d['id'], dict(d, index=index)) for (index, d) in enumerate(train_set))
for lines in train_predictions_file:
#print(lines['id'])
lines['evidence'] = info_by_id.get(lines['id'])['evidence']
train_prediction.append(lines)

Expand All @@ -53,50 +63,10 @@
gold_data = []

for claim in train_set:

# init gold dict
gold_dict = {'id': claim['id']}

if claim['verifiable'] == "VERIFIABLE":
gold_dict['verifiable'] = 1
else:
gold_dict['verifiable'] = 0

# get gold inputs
gold_documents = set()
gold_documents_separated = set()
sentences_pair = set()
evidences = claim['evidence']
difficulties = []
for evidence in evidences:
doc_name = ''
difficulty = 0
if len(evidence) > 1: # needs more than 1 doc to be verifiable
for e in evidence:
doc_name += str(e[2])
doc_name += " "
sentences_pair.add((str(e[2]), str(e[3]))) # add gold sentences
gold_documents_separated.add(str(e[2])) # add the document
difficulty += 1
doc_name = doc_name[:-1] # erase the last blank space
else:
doc_name = str(evidence[0][2])
gold_documents_separated.add(str(evidence[0][2]))
sentences_pair.add((str(evidence[0][2]), str(evidence[0][3])))
difficulty = 1
difficulties.append(difficulty)
gold_documents.add(doc_name)
gold_dict['difficulties'] = difficulties
gold_dict['docs'] = gold_documents
gold_dict['evidences'] = sentences_pair
gold_dict['docs_sep'] = gold_documents_separated

gold_data.append(gold_dict)

# flag to stop if needed
stop += 1
if stop == -1:
break
_claim = Claim(claim['id'], claim['claim'], claim['verifiable'])
_claim.add_gold_evidences(claim['evidence'])
claims.append(_claim)
# print(_claim.get_gold_documents())

gold_data = dict((item['id'], item) for item in gold_data)

Expand Down Expand Up @@ -144,8 +114,8 @@
doc_incorrect += 1
docs.add(doc)

precision_correct += doc_correct / len(docs)
precision_incorrect += doc_incorrect / len(docs)
precision_correct += doc_correct / (len(docs) + 0.0001)
precision_incorrect += doc_incorrect / (len(docs) + 0.0001)
recall_correct += doc_correct / len(gold_docs)
recall_incorrect += doc_incorrect / len(gold_docs)

Expand Down Expand Up @@ -180,8 +150,8 @@
if doc_correct and flag:
sent_found_if_doc_found += 1

precision_sent_correct += sent_correct / len(sentences)
precision_sent_incorrect += sent_incorrect / len(sentences)
precision_sent_correct += sent_correct / (len(sentences) + 0.00001)
precision_sent_incorrect += sent_incorrect / (len(sentences) + 0.00001)
recall_sent_correct += sent_correct / len(evidences)
recall_sent_incorrect += sent_incorrect / len(evidences)

Expand All @@ -204,11 +174,11 @@
print("\n#############")
print("# DOCUMENTS #")
print("#############")
print("Precision (Document Retrieved):\t\t\t\t\t\t " + str(precision_correct)) # precision
print("Fall-out (incorrect documents):\t\t\t\t\t\t " + str(precision_incorrect)) # precision
print("Recall (Relevant Documents):\t\t\t\t\t\t " + str(recall_correct)) # recall
print("Percentage of gold documents NOT found:\t\t\t\t " + str(recall_incorrect)) # recall
print("Fall-out: " + str(specificity))
print("Precision (Document Retrieved):\t\t\t " + str(precision_correct)) # precision
print("Fall-out (incorrect documents):\t\t\t " + str(precision_incorrect)) # precision
print("Recall (Relevant Documents):\t\t\t " + str(recall_correct)) # recall
print("Percentage of gold documents NOT found:\t\t " + str(recall_incorrect)) # recall
print("Fall-out:\t\t\t\t\t " + str(specificity))
print("Percentage of at least one document found correctly: " + str(doc_found)) # recall

precision_sent_correct /= total_claim
Expand All @@ -222,10 +192,10 @@
print("\n#############")
print("# SENTENCES #")
print("#############")
print("Precision (Sentences Retrieved):\t\t\t\t\t " + str(precision_sent_correct)) # precision
print("Precision (incorrect Sentences):\t\t\t\t\t " + str(precision_sent_incorrect)) # precision
print("Recall (Relevant Sentences):\t\t\t\t\t\t " + str(recall_sent_correct)) # recall
print("Percentage of gold Sentences NOT found:\t\t\t\t " + str(recall_sent_incorrect)) # recall
print("Precision (Sentences Retrieved):\t\t\t " + str(precision_sent_correct)) # precision
print("Precision (incorrect Sentences):\t\t\t " + str(precision_sent_incorrect)) # precision
print("Recall (Relevant Sentences):\t\t\t\t " + str(recall_sent_correct)) # recall
print("Percentage of gold Sentences NOT found:\t\t " + str(recall_sent_incorrect)) # recall
print("Percentage of at least one Sentence found correctly: " + str(sent_found)) # recall
print("Percentage of at least one Sentence found correctly: " + str(sent_found_if_doc_found)) # recall
print("Percentage of at least one Sentence found correctly: " + str(another_sent)) # recall
Expand All @@ -236,7 +206,7 @@
print("\n#########")
print("# FEVER #")
print("#########")
print("Strict_score: \t\t" + str(results[0]))
print("Strict_score: \t\t\t" + str(results[0]))
print("Acc_score: \t\t\t" + str(results[1]))
print("Precision: \t\t\t" + str(results[2]))
print("Recall: \t\t\t" + str(results[3]))
Expand Down
Empty file added metrics/__init__.py
Empty file.
33 changes: 33 additions & 0 deletions metrics/claim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from metrics.evidence import Evidence


class Claim:
    """A FEVER claim together with its gold evidence groups.

    Attributes:
        id: the claim's numeric id.
        name: the claim text.
        verifiable: 1 when the gold label is "VERIFIABLE", else 0.
        gold_evidence: list of Evidence objects, one per gold evidence group.
    """

    def __init__(self, _id, name, verifiable):
        self.id = _id
        self.name = name
        # Map the FEVER label string to a numeric flag.
        self.verifiable = 1 if verifiable == "VERIFIABLE" else 0
        self.gold_evidence = []

    def add_gold_evidence(self, document, evidence, line_num):
        """Add a single gold (document, sentence, line_num) triple."""
        # Build the wrapper directly instead of rebinding the `evidence`
        # parameter, which shadowed the argument in the original code.
        self.gold_evidence.append(Evidence(document, evidence, line_num))

    def add_gold_evidences(self, evidences):
        """Build one Evidence per gold evidence group from raw annotations.

        Each group is a list of FEVER evidence entries of the form
        [annotation_id, evidence_id, document, line]; a group with more
        than one entry needs multiple sentences to verify the claim.
        """
        for evidence in evidences:
            _evidence = Evidence()
            if len(evidence) > 1:  # multi-sentence evidence group
                for e in evidence:
                    _evidence.add_pair(str(e[2]), str(e[3]))
            else:
                _evidence.add_pair(str(evidence[0][2]), str(evidence[0][3]))
            self.gold_evidence.append(_evidence)

    def get_gold_documents(self):
        """Return the set of all document ids across the gold evidence groups."""
        docs = set()
        for e in self.gold_evidence:
            docs |= e.documents
        return docs
32 changes: 32 additions & 0 deletions metrics/evidence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
class Evidence:
    """One gold evidence group: documents, sentences, and (doc, line) pairs.

    Attributes:
        documents: set of document ids in this group.
        sentences: set of sentence texts in this group.
        pairs: set of (document, line_num) tuples identifying gold sentences.
    """

    def __init__(self, document="", sentence="", line_num=None):
        """Optionally seed the group with a single evidence triple.

        ``line_num`` uses None as the "not given" sentinel: FEVER sentence
        indices start at 0, so the previous ``line_num != 0`` check silently
        dropped legitimate line-0 evidence from ``pairs``.
        """
        self.documents = set()
        self.sentences = set()
        self.pairs = set()
        # Only record arguments that were actually supplied.
        if document != "":
            self.documents.add(document)
        if sentence != "":
            self.sentences.add(sentence)
        if line_num is not None:
            self.pairs.add((document, line_num))

    def add_document(self, doc):
        """Record a document id for this evidence group."""
        self.documents.add(doc)

    def add_sentence(self, sentence):
        """Record a sentence text for this evidence group."""
        self.sentences.add(sentence)

    def add_pair(self, doc, line_num):
        """Record a gold (document, line) pair and its document id."""
        self.pairs.add((doc, line_num))
        self.add_document(doc)

    def get_difficulty_documents(self):
        """Number of distinct documents needed to verify the claim."""
        return len(self.documents)

    def get_difficulty_sentences(self):
        """Number of distinct sentence texts in the group."""
        return len(self.sentences)

    def get_difficulty(self):
        """Number of distinct (document, line) pairs in the group."""
        return len(self.pairs)
7 changes: 4 additions & 3 deletions train_label_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,13 @@ def predict_test(predictions_test, entailment_predictions_test, new_predictions_


predictions_train = "predictions/predictions_train.jsonl"
predictions_test = "predictions/predictions.jsonl"
new_predictions_file = "predictions/new_predictions.jsonl"

gold_train = "data/subsample_train_relevant_docs.jsonl"
entailment_predictions_train = "rte/entailment_predictions_train"
entailment_predictions_test = "rte/entailment_predictions_test"

predictions_test = "data/dev.jsonl"
entailment_predictions_test = "rte/entailment_predictions"
new_predictions_file = "predictions/new_predictions_dev.jsonl"

x_train, y_train = populate_train(gold_train, entailment_predictions_train)
# x_test = x_train[7000:]
Expand Down

0 comments on commit 8e262ff

Please sign in to comment.