-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
executable file
·120 lines (102 loc) · 3.62 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import csv
import functools
import hashlib
import itertools
import json
import os
import random
import re
import string
import sys
from typing import Any, Dict, List, Set, Tuple
import unicodedata
# Translation table that deletes every ASCII punctuation character listed in
# string.punctuation (each one maps to None) when passed to str.translate.
table = str.maketrans("", "", string.punctuation)

# Separator between the fields of a line in the prompt/translation files.
FIELDSEP = "|"
def makeID(text: str) -> str:
    """
    Build a deterministic identifier from the content of ``text``.

    The text is lowercased, encoded to bytes and hashed with MD5; the
    hexadecimal digest is returned with a ``prompt_`` prefix.

    WARNING: This is typically used to create prompt IDs. Whitespace is
    hashed as-is, so prompts that differ only by stray spaces yield
    different IDs, which may not be the ID you are expecting.
    """
    hasher = hashlib.md5()
    hasher.update(text.lower().encode())
    return "prompt_" + hasher.hexdigest()
def read_file(filename):
    """
    Read a UTF-8 text file and return its lines as a list of strings.

    Leading and trailing whitespace (including the newline) is removed from
    every line; blank lines are kept as empty strings.
    """
    with open(filename, encoding="utf-8") as handle:
        return [raw_line.strip() for raw_line in handle]
def write_to_file(d, id_text, filename):
    """
    Write grouped entries to ``filename`` (UTF-8, overwriting any content).

    For every key of ``d`` a group is written: an empty line, then a header
    line ``<key><FIELDSEP><id_text[key]>``, then one line per key of the
    nested mapping ``d[key]``. The values of the nested mapping are not
    written.
    """
    with open(filename, "w", encoding="utf-8") as out:
        for prompt in d:
            source_text = id_text[prompt]
            header = f"\n{prompt}{FIELDSEP}{source_text}\n"
            out.write(header)
            # Only the nested keys are emitted; their values are ignored.
            out.writelines(para + "\n" for para, _value in d[prompt].items())
def read_trans_prompts(lines: List[str]) -> List[Tuple[str, str]]:
    """
    Extract the (ID, text) header of every group in shared-task formatted lines.

    Groups are separated by blank lines. The first non-blank line of a group
    is its header, ``<id><FIELDSEP><prompt text>``; every other line of the
    group is ignored here. Each line is stripped and lowercased before it is
    parsed, so both the ID and the text are returned in lowercase.

    Args:
        lines: The lines of the file, in order.

    Returns:
        One ``(id, prompt_text)`` tuple per group, in file order.

    Raises:
        ValueError: If a header line does not contain ``FIELDSEP``.
    """
    ids_prompts = []
    at_group_start = True
    for raw_line in lines:
        line = raw_line.strip().lower()
        if not line:
            # A blank line closes the current group; the next line is a header.
            at_group_start = True
            continue
        if at_group_start:
            # Split on the first separator only, so a prompt whose text
            # itself contains FIELDSEP is kept intact instead of raising.
            key, prompt = line.split(FIELDSEP, 1)
            ids_prompts.append((key, prompt))
            at_group_start = False
    return ids_prompts
def strip_punctuation(text: str) -> str:
    """
    Return ``text`` with every punctuation character removed.

    A character counts as punctuation when its Unicode general category
    starts with "P", so punctuation of several languages, including
    Japanese, is removed and not just the ASCII set.
    """
    kept_chars = (
        char for char in text if not unicodedata.category(char).startswith("P")
    )
    return "".join(kept_chars)
def read_transfile(lines: List[str], strip_punc=True, weighted=False) -> Dict[str, Dict[str, float]]:
    """
    Parse shared-task formatted lines into a mapping of prompt ID to responses.

    Groups are separated by blank lines. The first non-blank line of a group
    is ``<id><FIELDSEP><prompt text>`` and supplies the key; each following
    line is one response. Every line is stripped and lowercased before it is
    parsed. A group with no responses is not stored.

    Args:
        lines: The lines of the file, in order.
        strip_punc: When true, the characters of ``string.punctuation``
            (ASCII only) are deleted from each response via the module-level
            ``table``. Note this is narrower than ``strip_punctuation``,
            which removes all Unicode punctuation.
        weighted: When true, each response line must end with
            ``<FIELDSEP><weight>``; otherwise the whole line is the response
            and its weight is 1. Gold files are required to carry weights.

    Returns:
        ``{prompt_id: {response_text: weight}}``. If a prompt ID occurs more
        than once a warning is printed and the later group replaces the
        earlier one.

    Raises:
        ValueError: If a header line (or, with ``weighted``, a response line)
            lacks ``FIELDSEP``, or if a weight is not a valid float.
    """
    data = {}
    options = {}
    key = ""

    def _store_group():
        # Record the finished group, warning when it replaces an earlier one.
        if key in data:
            print(f"Warning: duplicate sentence! {key}")
        data[key] = options

    at_group_start = True
    for raw_line in lines:
        line = raw_line.strip().lower()
        if not line:
            # A blank line ends the current group.
            at_group_start = True
            if key and options:
                _store_group()
                options = {}
            continue

        if at_group_start:
            # In a group, the first line is the KEY; split on the first
            # separator only so prompt text containing FIELDSEP is accepted.
            key, _ = line.split(FIELDSEP, 1)
            at_group_start = False
            continue

        if weighted:
            # The weight is the last field; split from the right so a
            # response that itself contains FIELDSEP is kept intact.
            text, weight = line.rsplit(FIELDSEP, 1)
        else:
            text, weight = line, 1
        if strip_punc:
            text = text.translate(table)
        options[text] = float(weight)

    # The last group is not followed by a blank line when the input does not
    # end with one, so flush it here.
    if options:
        _store_group()
    return data