# pointer_funcs.py
import sys
def get_closest_match(token, options, backup_options):
    """Return the option that contains `token` with the fewest extra
    characters, preferring `options` over `backup_options`.

    The second return value is True when the match had to come from
    `backup_options`, i.e. the caller needs to look backward over the
    whole source sequence instead of continuing forward.
    """
    additions = 10000  # smallest number of extra characters seen so far
    best_replacement = None
    backward = False
    # Fast path: the next candidate already starts with the token.
    if len(options) > 0 and options[0].startswith(token):
        best_replacement = options[0]
    if best_replacement is None:
        for option in options:
            addition_distance = len(option.replace(token, "", 1))
            if token in option and addition_distance < additions:
                best_replacement = option
                additions = addition_distance
    if best_replacement is None:
        for option in backup_options:
            addition_distance = len(option.replace(token, "", 1))
            if token in option and addition_distance < additions:
                best_replacement = option
                additions = addition_distance
                backward = True
    return best_replacement, backward
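

# A hedged sketch of the matching behaviour (example values are made up
# for illustration, not taken from the original repository):
#
#   get_closest_match("cat", ["scatter", "cats"], [])
#       -> ("cats", False)    # "cats" adds the fewest extra characters
#   get_closest_match("dog", ["house", "tree"], ["dogs", "dog-house"])
#       -> ("dogs", True)     # only the backup list contains the token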
def pointer_process(source_seq, target_seq):
    """Rewrite `target_seq` so that every token copied from `source_seq`
    is replaced by a pointer token <pt-i>, where i is the position of the
    matching token in the (re-tokenised) source sequence.

    Bracket tokens (those starting with "[" or "]") are kept verbatim.
    Returns the re-tokenised source and the pointerised target, both as
    space-joined strings.
    """

    def format_pt(pos):
        return "<pt-{}>".format(pos)

    target_seq_out = []
    target_tokens = target_seq.strip().split()
    prev_offset = 0

    # First pass: insert spaces around every target token found inside the
    # source, so that sub-word matches become stand-alone source tokens.
    for token in target_tokens:
        token = token.strip()
        if token.startswith("[") or token.startswith("]"):
            continue
        matched_token, backward = get_closest_match(
            token,
            source_seq[prev_offset:].strip().split(),
            source_seq.strip().split(),
        )
        if matched_token is None:
            # The token never appears in the source; nothing to align.
            continue
        if backward:
            # The match only exists before prev_offset, so search the
            # whole source again from the start.
            offset = source_seq.find(token)
        else:
            offset = prev_offset + source_seq[prev_offset:].find(token)
        source_seq = (
            source_seq[:offset]
            + " "
            + source_seq[offset : offset + len(token)]
            + " "
            + source_seq[offset + len(token) :]
        )
        prev_offset = offset + len(token) + 2

    # Second pass: map every source token to its position(s).
    word_to_pos = dict()
    for position, token in enumerate(source_seq.strip().split()):
        token = token.strip()
        if token in word_to_pos:
            word_to_pos[token].append(position)
        else:
            word_to_pos[token] = [position]

    # Third pass: replace copied target tokens with pointer tokens.
    for token in target_tokens:
        token = token.strip()
        if token in word_to_pos:
            if token.startswith("[") or token.startswith("]"):
                # A structural bracket token should never collide with the
                # source vocabulary; bail out loudly if it does.
                print(source_seq, target_seq, token)
                sys.exit()
            if len(word_to_pos[token]) == 1:
                # Single occurrence: always point at it.
                token_out = format_pt(word_to_pos[token][0])
            else:
                # Multiple occurrences: consume positions left to right.
                token_out = format_pt(word_to_pos[token].pop(0))
        else:
            token_out = token
        target_seq_out.append(token_out)
    return " ".join(source_seq.strip().split()), " ".join(target_seq_out)