-
Notifications
You must be signed in to change notification settings - Fork 1
/
beam_search_from_pointer_generator.py
150 lines (125 loc) · 6.03 KB
/
beam_search_from_pointer_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Modifications Copyright 2017 Abigail See
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This file contains code to run beam search decoding"""
# --- Beam search hyperparameters -------------------------------------------
BEAM_SIZE = 10    # number of hypotheses kept alive at each decoder step
MAX_DEC_LEN = 60  # hard cap on the number of decoder steps
MIN_DEC_LEN = 1   # a hypothesis emitting EOS_ID at a step index below this is discarded

# --- Vocabulary and special token ids --------------------------------------
VOCAB_SIZE = 2276
# NOTE(review): UNK_ID equals VOCAB_SIZE, so the id substituted for
# out-of-vocabulary tokens is itself outside range(VOCAB_SIZE). Confirm that
# the decoder's embedding table really has a row for id 2276.
SOS_ID, EOS_ID, UNK_ID = 2274, 2275, 2276
class Hypothesis(object):
    """A single partial (or finished) output sequence tracked by beam search.

    Attributes:
        tokens: list of token ids decoded so far.
        log_probs: list of per-token log probabilities, parallel to ``tokens``.
        state: decoder state after the most recent token (an LSTMStateTuple
            in the model this was written for).
    """

    def __init__(self, tokens, log_probs, state):
        """Record the decoded ids, their log probabilities and the decoder state.

        Args:
            tokens: List of integer token ids forming the sequence so far.
            log_probs: List of floats, one log probability per entry of ``tokens``.
            state: Decoder state after the last token, a LSTMStateTuple.
        """
        self.tokens = tokens
        self.log_probs = log_probs
        self.state = state

    def extend(self, token, log_prob, state):
        """Build a new, one-token-longer Hypothesis from this one.

        The receiver is not modified: its lists are concatenated into fresh
        lists instead of being appended to in place.

        Args:
            token: Integer id of the token chosen at this beam search step.
            log_prob: Float log probability of ``token``.
            state: Decoder state after consuming ``token``, a LSTMStateTuple.

        Returns:
            A new Hypothesis carrying the extended sequence and ``state``.
        """
        grown_tokens = self.tokens + [token]
        grown_log_probs = self.log_probs + [log_prob]
        return Hypothesis(tokens=grown_tokens,
                          log_probs=grown_log_probs,
                          state=state)

    @property
    def latest_token(self):
        """The most recently decoded token id."""
        return self.tokens[-1]

    @property
    def log_prob(self):
        """Log probability of the whole sequence: the sum of the per-token values."""
        return sum(self.log_probs)

    @property
    def avg_log_prob(self):
        """Sequence log probability divided by the number of entries in ``tokens``.

        Without this normalisation, longer sequences would always score lower,
        because every extra token adds another (non-positive) log probability.
        """
        total = self.log_prob
        length = len(self.tokens)
        return total / length
def run_beam_search(sess, model, vocab, batch):
    """Perform beam search decoding on one example and return the best hypothesis.

    Args:
        sess: a tf.Session, forwarded to the model's encoder and decoder calls.
        model: a seq2seq model providing ``run_encoder(sess, batch)`` and
            ``decode_onestep(sess=..., batch=..., latest_tokens=...,
            enc_states=..., dec_init_states=...)``.
        vocab: Vocabulary object. Not used here (the special token ids come
            from the module constants); kept so existing callers keep working.
        batch: Batch object that is the same example repeated across the batch.

    Returns:
        best_hyp: the Hypothesis with the highest average log probability among
        the finished hypotheses. If no hypothesis finished within MAX_DEC_LEN
        steps, it is chosen from the unfinished beam instead.
    """
    # Encode once; the encoder outputs are reused at every decoder step.
    # Per the original comments, dec_in_state is a LSTMStateTuple and
    # enc_states has shape [batch_size, <=max_enc_steps, 2*hidden_dim].
    enc_states, dec_in_state = model.run_encoder(sess, batch)

    # Start with BEAM_SIZE identical hypotheses holding only the start token.
    hyps = [Hypothesis(tokens=[SOS_ID], log_probs=[0.0], state=dec_in_state)
            for _ in range(BEAM_SIZE)]
    results = []  # finished hypotheses (those that emitted EOS_ID)
    steps = 0
    while steps < MAX_DEC_LEN and len(results) < BEAM_SIZE:
        # Replace ids outside [0, VOCAB_SIZE), e.g. temporary in-article OOV
        # ids, with UNK_ID so the decoder can look up a word embedding.
        # ``0 <= t < VOCAB_SIZE`` is used instead of ``t in range(VOCAB_SIZE)``.
        # Range membership is constant-time only for exact ``int``. For other
        # integer types (e.g. numpy scalars, presumably what indexing
        # ``topk_ids[i, j]`` yields), it iterates over every element of the range.
        latest_tokens = [h.latest_token for h in hyps]
        latest_tokens = [t if 0 <= t < VOCAB_SIZE else UNK_ID for t in latest_tokens]
        states = [h.state for h in hyps]

        # Run one decoder step for every hypothesis in the beam.
        topk_ids, topk_log_probs, new_states = model.decode_onestep(
            sess=sess,
            batch=batch,
            latest_tokens=latest_tokens,
            enc_states=enc_states,
            dec_init_states=states)

        # Extend each hypothesis with each of its top 2*BEAM_SIZE candidates.
        # On the first step all hypotheses are identical, so only the first one
        # is extended. Otherwise every candidate would appear BEAM_SIZE times.
        num_orig_hyps = 1 if steps == 0 else len(hyps)
        all_hyps = []
        for i in range(num_orig_hyps):
            h, new_state = hyps[i], new_states[i]
            for j in range(BEAM_SIZE * 2):
                all_hyps.append(h.extend(token=topk_ids[i, j],
                                         log_prob=topk_log_probs[i, j],
                                         state=new_state))

        # Walk the candidates best-first. Finished ones go to ``results`` if
        # they are long enough; the rest form the beam for the next step.
        hyps = []
        for h in sort_hyps(all_hyps):
            if h.latest_token == EOS_ID:
                # Hypotheses that stop before MIN_DEC_LEN steps are discarded.
                if steps >= MIN_DEC_LEN:
                    results.append(h)
            else:
                hyps.append(h)
            # Stop once the next beam is full or enough hypotheses have finished.
            if len(hyps) == BEAM_SIZE or len(results) == BEAM_SIZE:
                break
        steps += 1

    # Either BEAM_SIZE hypotheses finished or MAX_DEC_LEN was reached. If none
    # finished, fall back to the current (unfinished) beam.
    if not results:
        results = hyps
    # Best by average log probability.
    return sort_hyps(results)[0]
def sort_hyps(hyps):
    """Rank hypotheses best-first by average log probability.

    The input is not modified; a new list is returned. Hypotheses with equal
    scores keep their original relative order (the sort is stable).
    """
    ranked = list(hyps)
    ranked.sort(key=lambda hyp: hyp.avg_log_prob, reverse=True)
    return ranked