diff --git a/yoyodyne/data/mappers.py b/yoyodyne/data/mappers.py index 49d4c468..c769c21f 100644 --- a/yoyodyne/data/mappers.py +++ b/yoyodyne/data/mappers.py @@ -89,20 +89,23 @@ def _decode( ) -> List[str]: """Decodes a tensor. + Decoding halts at END; other special symbols are omitted. + Args: indices (torch.Tensor): 1d tensor of indices. Yields: List[str]: Decoded symbols. """ - return [ - self.index.get_symbol(c) - for c in indices - if not special.isspecial(c) - ] - - # These are just here for compatibility but they all have - # the same implementation. + symbols = [] + for idx in indices: + if idx == special.END_IDX: + return symbols + elif not special.isspecial(idx): + symbols.append(self.index.get_symbol(idx)) + return symbols + + # These are here for compatibility; they all have the same implementation. def decode_source( self, diff --git a/yoyodyne/models/base.py b/yoyodyne/models/base.py index 909defd8..b3e40acd 100644 --- a/yoyodyne/models/base.py +++ b/yoyodyne/models/base.py @@ -297,12 +297,11 @@ def predict_step( using beam search, the predictions and scores as a tuple of tensors; if using greedy search, the predictions as a tensor. """ - predictions = self(batch) + if self.beam_width > 1: - predictions, scores = predictions - return predictions, scores + return self(batch) else: - return self._get_predicted(predictions) + return self._get_predicted(self(batch)) def _get_predicted(self, predictions: torch.Tensor) -> torch.Tensor: """Picks the best index from the vocabulary. diff --git a/yoyodyne/models/rnn.py b/yoyodyne/models/rnn.py index 3125ea4f..3dcb8bd8 100644 --- a/yoyodyne/models/rnn.py +++ b/yoyodyne/models/rnn.py @@ -142,17 +142,17 @@ def beam_decode( # Sometimes path lengths does not match so it is neccesary to pad it # all to same length to create a tensor. max_len = max(len(h[1]) for h in histories) + # -> B x beam_size x seq_len. predictions = torch.tensor( [ h[1] + [special.PAD_IDX] * (max_len - len(h[1])) for h in histories ], device=self.device, - ) - # Converts shape to that of `decode`: seq_len x B x target_vocab_size. - predictions = predictions.unsqueeze(0).transpose(0, 2) - # Beam search returns the likelihoods of each history. - return predictions, torch.tensor([h[0] for h in histories]) + ).unsqueeze(0) + # --> B x beam_size. + scores = torch.tensor([h[0] for h in histories]).unsqueeze(0) + return predictions, scores def greedy_decode( self, @@ -225,7 +225,8 @@ def greedy_decode( -1 ): break - predictions = torch.stack(predictions) + # -> B x seq_len x target_vocab_size. + predictions = torch.stack(predictions, dim=1) return predictions def forward( @@ -240,34 +241,28 @@ def forward( Returns: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: beam search returns a tuple with a tensor of predictions of shape - beam_width x seq_len and tensor with the unnormalized sum - of symbol log-probabilities for each prediction. Greedy returns + B x beam_width x seq_len and a tensor of shape beam_width with + the likelihood (the unnormalized sum of sequence + log-probabilities) for each prediction; greedy search returns a tensor of predictions of shape - seq_len x batch_size x target_vocab_size. + B x target_vocab_size x seq_len. """ encoder_out = self.source_encoder(batch.source).output - # Now this function has a polymorphic return because beam search needs - # to return two tensors. For greedy, the return has not been modified - # to match the Tuple[torch.Tensor, torch.Tensor] type because the + # This function has a polymorphic return because beam search needs to + # return two tensors. For greedy, the return has not been modified to + # match the Tuple[torch.Tensor, torch.Tensor] type because the # training and validation functions depend on it. if self.beam_width > 1: - predictions, scores = self.beam_decode( - encoder_out, - batch.source.mask, - ) - # Reduces to beam_width x seq_len - predictions = predictions.transpose(0, 2).squeeze(0) - return predictions, scores + x = self.beam_decode(encoder_out, batch.source.mask) + return x else: - predictions = self.greedy_decode( + x = self.greedy_decode( encoder_out, batch.source.mask, self.teacher_forcing if self.training else False, batch.target.padded if batch.target else None, ) - # -> B x seq_len x target_vocab_size. - predictions = predictions.transpose(0, 1) - return predictions + return x @staticmethod def add_argparse_args(parser: argparse.ArgumentParser) -> None: diff --git a/yoyodyne/predict.py b/yoyodyne/predict.py index 3af0eb30..c4b3602d 100644 --- a/yoyodyne/predict.py +++ b/yoyodyne/predict.py @@ -102,26 +102,28 @@ def predict( if model.beam_width > 1: # Beam search. tsv_writer = csv.writer(sink, delimiter="\t") - for predictions, scores in trainer.predict(model, loader): - predictions = util.pad_tensor_after_end(predictions) - # TODO: beam search requires singleton batches and this - # assumes that. Revise if that restriction is ever lifted. - targets = [ - parser.target_string(mapper.decode_target(target)) - for target in predictions - ] - # Collates target strings and their scores. - row = itertools.chain.from_iterable( - zip(targets, scores.tolist()) - ) - tsv_writer.writerow(row) + for batch_predictions, batch_scores in trainer.predict( + model, loader + ): + # Even though beam search currently assumes batch size of 1, + # this assumption is not baked-in here and should generalize + # if this restriction is lifted. + for beam, beam_scores in zip(batch_predictions, batch_scores): + beam_strings = [ + parser.target_string(mapper.decode_target(prediction)) + for prediction in beam + ] + # Collates target strings and their scores. + row = itertools.chain.from_iterable( + zip(beam_strings, beam_scores.tolist()) + ) + tsv_writer.writerow(row) else: # Greedy search. - for predictions in trainer.predict(model, loader): - predictions = util.pad_tensor_after_end(predictions) - for target in predictions: + for batch in trainer.predict(model, loader): + for prediction in batch: print( - parser.target_string(mapper.decode_target(target)), + parser.target_string(mapper.decode_target(prediction)), file=sink, )