Skip to content

Commit

Permalink
Merge pull request #290 from kylebgorman/batch
Browse files Browse the repository at this point in the history
Generalizes predictions for batches > 1
  • Loading branch information
kylebgorman authored Dec 10, 2024
2 parents 749a9ca + ce46b8e commit 1babacb
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 52 deletions.
19 changes: 11 additions & 8 deletions yoyodyne/data/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,20 +89,23 @@ def _decode(
) -> List[str]:
"""Decodes a tensor.
Decoding halts at END; other special symbols are omitted.
Args:
indices (torch.Tensor): 1d tensor of indices.
Yields:
List[str]: Decoded symbols.
"""
return [
self.index.get_symbol(c)
for c in indices
if not special.isspecial(c)
]

# These are just here for compatibility but they all have
# the same implementation.
symbols = []
for idx in indices:
if idx == special.END_IDX:
return symbols
elif not special.isspecial(idx):
symbols.append(self.index.get_symbol(idx))
return symbols

# These are here for compatibility; they all have the same implementation.

def decode_source(
self,
Expand Down
7 changes: 3 additions & 4 deletions yoyodyne/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,12 +297,11 @@ def predict_step(
using beam search, the predictions and scores as a tuple of
tensors; if using greedy search, the predictions as a tensor.
"""
predictions = self(batch)

if self.beam_width > 1:
predictions, scores = predictions
return predictions, scores
return self(batch)
else:
return self._get_predicted(predictions)
return self._get_predicted(self(batch))

def _get_predicted(self, predictions: torch.Tensor) -> torch.Tensor:
"""Picks the best index from the vocabulary.
Expand Down
39 changes: 16 additions & 23 deletions yoyodyne/models/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,17 @@ def beam_decode(
# Sometimes path lengths does not match so it is neccesary to pad it
# all to same length to create a tensor.
max_len = max(len(h[1]) for h in histories)
# -> B x beam_size x seq_len.
predictions = torch.tensor(
[
h[1] + [special.PAD_IDX] * (max_len - len(h[1]))
for h in histories
],
device=self.device,
)
# Converts shape to that of `decode`: seq_len x B x target_vocab_size.
predictions = predictions.unsqueeze(0).transpose(0, 2)
# Beam search returns the likelihoods of each history.
return predictions, torch.tensor([h[0] for h in histories])
).unsqueeze(0)
# -> B x beam_size.
scores = torch.tensor([h[0] for h in histories]).unsqueeze(0)
return predictions, scores

def greedy_decode(
self,
Expand Down Expand Up @@ -225,7 +225,8 @@ def greedy_decode(
-1
):
break
predictions = torch.stack(predictions)
# -> B x seq_len x target_vocab_size.
predictions = torch.stack(predictions, dim=1)
return predictions

def forward(
Expand All @@ -240,34 +241,26 @@ def forward(
Returns:
Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: beam
search returns a tuple with a tensor of predictions of shape
beam_width x seq_len and tensor with the unnormalized sum
of symbol log-probabilities for each prediction. Greedy returns
B x beam_width x seq_len and a tensor of shape beam_width with
the likelihood (the unnormalized sum of sequence
log-probabilities) for each prediction; greedy search returns
a tensor of predictions of shape
seq_len x batch_size x target_vocab_size.
B x target_vocab_size x seq_len.
"""
encoder_out = self.source_encoder(batch.source).output
# Now this function has a polymorphic return because beam search needs
# to return two tensors. For greedy, the return has not been modified
# to match the Tuple[torch.Tensor, torch.Tensor] type because the
# This function has a polymorphic return because beam search needs to
# return two tensors. For greedy, the return has not been modified to
# match the Tuple[torch.Tensor, torch.Tensor] type because the
# training and validation functions depend on it.
if self.beam_width > 1:
predictions, scores = self.beam_decode(
encoder_out,
batch.source.mask,
)
# Reduces to beam_width x seq_len
predictions = predictions.transpose(0, 2).squeeze(0)
return predictions, scores
return self.beam_decode(encoder_out, batch.source.mask)
else:
predictions = self.greedy_decode(
return self.greedy_decode(
encoder_out,
batch.source.mask,
self.teacher_forcing if self.training else False,
batch.target.padded if batch.target else None,
)
# -> B x seq_len x target_vocab_size.
predictions = predictions.transpose(0, 1)
return predictions

@staticmethod
def add_argparse_args(parser: argparse.ArgumentParser) -> None:
Expand Down
36 changes: 19 additions & 17 deletions yoyodyne/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,26 +102,28 @@ def predict(
if model.beam_width > 1:
# Beam search.
tsv_writer = csv.writer(sink, delimiter="\t")
for predictions, scores in trainer.predict(model, loader):
predictions = util.pad_tensor_after_end(predictions)
# TODO: beam search requires singleton batches and this
# assumes that. Revise if that restriction is ever lifted.
targets = [
parser.target_string(mapper.decode_target(target))
for target in predictions
]
# Collates target strings and their scores.
row = itertools.chain.from_iterable(
zip(targets, scores.tolist())
)
tsv_writer.writerow(row)
for batch_predictions, batch_scores in trainer.predict(
model, loader
):
# Even though beam search currently assumes batch size of 1,
# this assumption is not baked-in here and should generalize
# if this restriction is lifted.
for beam, beam_scores in zip(batch_predictions, batch_scores):
beam_strings = [
parser.target_string(mapper.decode_target(prediction))
for prediction in beam
]
# Collates target strings and their scores.
row = itertools.chain.from_iterable(
zip(beam_strings, beam_scores.tolist())
)
tsv_writer.writerow(row)
else:
# Greedy search.
for predictions in trainer.predict(model, loader):
predictions = util.pad_tensor_after_end(predictions)
for target in predictions:
for batch in trainer.predict(model, loader):
for prediction in batch:
print(
parser.target_string(mapper.decode_target(target)),
parser.target_string(mapper.decode_target(prediction)),
file=sink,
)

Expand Down

0 comments on commit 1babacb

Please sign in to comment.