#!/usr/bin/env python
"""Greedy, recall-oriented prompt selection.

For each source sentence, greedily selects a subset of candidate prompts so
that the selected prompts together cover the source's n-grams, down-weighting
source n-grams that are already covered by earlier selections.
"""
import argparse
import math
import pickle
import re

import numpy as np
from tqdm import tqdm
from sacrebleu.metrics.helpers import extract_all_word_ngrams
from sacrebleu.tokenizers import tokenizer_13a

from utils import *  # read_file() is expected to come from here

max_ngram_order = 4
tok_13a = tokenizer_13a.Tokenizer13a()


def my_log(num: float) -> float:
    """math.log that returns a large negative number for 0 instead of raising."""
    if num == 0.0:
        return -9999999999
    return math.log(num)


def _preprocess(sent, ignore_whitespace):
    sent = sent.rstrip()
    if ignore_whitespace:
        # Remove all whitespace before n-gram extraction.
        sent = re.sub(r"\s+", "", sent)
    else:
        # Apply sacrebleu's 13a tokenization.
        sent = tok_13a(sent)
    return sent


def _compute_score_from_stats(correct, total, effective_order=True):
    """BLEU-style score from per-order match counts, with mteval-style smoothing."""
    scores = [0.0 for x in range(max_ngram_order)]
    smooth_mteval = 1.
    eff_order = max_ngram_order
    if not any(correct):
        return 0.0
    for n in range(1, len(scores) + 1):
        if total[n - 1] == 0:
            break
        if effective_order:
            eff_order = n
        if correct[n - 1] == 0:
            # Halve the smoothed precision each time an order has no matches.
            smooth_mteval *= 2
            scores[n - 1] = 100. / (smooth_mteval * total[n - 1])
        else:
            scores[n - 1] = 100. * correct[n - 1] / total[n - 1]
    # Geometric mean of the per-order precisions up to the effective order.
    score = math.exp(
        sum([my_log(p) for p in scores[:eff_order]]) / eff_order)
    return score
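
# Worked example (illustrative values, not from the original data): with
# correct=[2, 1, 0, 0] and total=[3, 2, 0, 0], the 1-gram precision is
# 100 * 2 / 3 ~= 66.67 and the 2-gram precision is 100 * 1 / 2 = 50.0; the
# loop stops at order 3 (total is 0), so the effective order is 2 and
# _compute_score_from_stats([2, 1, 0, 0], [3, 2, 0, 0]) is the geometric
# mean sqrt(66.67 * 50.0) ~= 57.7.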


def get_ngram_overlap_count(ref_ngrams, hyp_ngrams):
    """Accumulate clipped per-order match counts and reduce them to one score."""
    correct = [0 for i in range(max_ngram_order)]
    total = correct[:]
    for hyp_ngram, hyp_count in hyp_ngrams.items():
        n = len(hyp_ngram) - 1
        total[n] += hyp_count
        if hyp_ngram in ref_ngrams:
            # Clip matches at the reference count, as in standard BLEU.
            correct[n] += min(hyp_count, ref_ngrams[hyp_ngram])
    return _compute_score_from_stats(correct, total)
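
# For instance (illustrative): if the hypothesis contains the unigram
# ('the',) twice but the reference contains it only once, the match is
# clipped to 1, so correct[0] gains 1 while total[0] gains 2.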


def get_recall_overlap_score(hyp_ngrams, ref_ngrams, beta=3.0, epsilon=1e-16):
    """N-gram recall of ref_ngrams against hyp_ngrams (the F-score is computed but not returned)."""
    overlap_ngrams = ref_ngrams & hyp_ngrams
    tp = sum(overlap_ngrams.values())  # True positives.
    tpfp = sum(hyp_ngrams.values())    # True positives + false positives.
    tpfn = sum(ref_ngrams.values())    # True positives + false negatives.
    try:
        prec = tp / tpfp  # Precision.
        rec = tp / tpfn   # Recall.
        factor = beta ** 2
        fscore = (1 + factor) * (prec * rec) / (factor * prec + rec)
    except ZeroDivisionError:
        prec = rec = fscore = epsilon
    return rec
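
# Quick check (illustrative Counters over sacrebleu-style n-gram tuples):
# with hyp = Counter({('a',): 2, ('b',): 1}) and ref = Counter({('a',): 1,
# ('c',): 1}), the multiset intersection is {('a',): 1}, so tp = 1,
# tpfp = 3, tpfn = 2, and the returned recall is 1 / 2 = 0.5.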


def select_prompt_set(source, prompts, weight=0.1, ignore_whitespace=False, min_bleu_threshold=1):
    """Greedily select indices of prompts that cover the source's n-grams.

    Each round picks the prompt with the highest overlap score, down-weights
    the source n-grams that prompt covers, and zeroes the winner's own
    n-grams so it cannot win again. Stops once the best remaining score
    falls below min_bleu_threshold.
    """
    ref_ngrams, ref_len = extract_all_word_ngrams(
        _preprocess(source, ignore_whitespace), 1, max_ngram_order)
    hyp_ngrams_list = {}
    for i, pr_src in enumerate(prompts):
        hyp_ngrams_list[i] = extract_all_word_ngrams(
            _preprocess(pr_src, ignore_whitespace), 1, max_ngram_order)[0]
    selected_prompts = []
    while True:
        overlap_scores = []
        for i in hyp_ngrams_list:
            overlap_score = get_ngram_overlap_count(hyp_ngrams_list[i], ref_ngrams)
            overlap_scores.append(overlap_score)
        top_1 = np.argmax(overlap_scores)
        if overlap_scores[top_1] < min_bleu_threshold:
            break
        if top_1 not in selected_prompts:
            selected_prompts.append(top_1)
        # Down-weight the source n-grams the selected prompt already covers.
        hyp_top_1 = hyp_ngrams_list[top_1]
        for ngram in ref_ngrams:
            if ngram in hyp_top_1:
                ref_ngrams[ngram] *= weight
        # Zero the selected prompt's n-grams so it is never picked again.
        for ngram in hyp_ngrams_list[top_1]:
            hyp_ngrams_list[top_1][ngram] = 0.0
    return selected_prompts
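
# Usage sketch (hypothetical strings; the exact indices depend on the
# overlap scores and the threshold):
#
#     idx = select_prompt_set("the black cat sat",
#                             ["the black cat ran", "a dog barked"],
#                             weight=0.1, min_bleu_threshold=1)
#     # idx might be [0]: the first prompt covers most source n-grams,
#     # while the second may never clear the threshold.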


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--domain", type=str, required=True)
    parser.add_argument("--input-prompt-file", type=str, required=True)
    parser.add_argument("--input-source-file", type=str, required=True)
    parser.add_argument("--weight", type=float, default=0.1)
    parser.add_argument("--output-prompt-file", type=str, required=True)
    parser.add_argument("--ignore-whitespace", action='store_true')
    parser.add_argument("--split", type=str, default="test")
    parser.add_argument("--min-bleu-threshold", type=float, default=1.0)
    args = parser.parse_args()

    eval_samples = read_file(args.input_source_file)
    with open(args.input_prompt_file, "rb") as f:
        pool_prompts = pickle.load(f)

    prompts = {}
    number_of_prompts = []
    for i, source in tqdm(enumerate(eval_samples)):
        prompt_src = [pr.data["src"] for pr in pool_prompts[i]]
        selected_indices = select_prompt_set(
            source, prompt_src, weight=args.weight,
            ignore_whitespace=args.ignore_whitespace,
            min_bleu_threshold=args.min_bleu_threshold)
        prompts[i] = [pool_prompts[i][j] for j in selected_indices]
        number_of_prompts.append(len(selected_indices))
    print("Maximum number of prompts:", max(number_of_prompts))
    with open(args.output_prompt_file, "wb") as f:
        pickle.dump(prompts, f)


if __name__ == '__main__':
    main()
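
# Example invocation (all file names and the domain value are placeholders):
#
#     python create_recall_set_selection.py \
#         --domain medical \
#         --input-source-file data/test.src \
#         --input-prompt-file data/pool_prompts.pkl \
#         --output-prompt-file data/selected_prompts.pkl \
#         --weight 0.1 --min-bleu-threshold 1.0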