diff --git a/yoyodyne/models/rnn.py b/yoyodyne/models/rnn.py index 3dcb8bd..ad9708b 100644 --- a/yoyodyne/models/rnn.py +++ b/yoyodyne/models/rnn.py @@ -150,7 +150,7 @@ def beam_decode( ], device=self.device, ).unsqueeze(0) - # --> B x beam_size. + # -> B x beam_size. scores = torch.tensor([h[0] for h in histories]).unsqueeze(0) return predictions, scores @@ -253,16 +253,14 @@ def forward( # match the Tuple[torch.Tensor, torch.Tensor] type because the # training and validation functions depend on it. if self.beam_width > 1: - x = self.beam_decode(encoder_out, batch.source.mask) - return x + return self.beam_decode(encoder_out, batch.source.mask) else: - x = self.greedy_decode( + return self.greedy_decode( encoder_out, batch.source.mask, self.teacher_forcing if self.training else False, batch.target.padded if batch.target else None, ) - return x @staticmethod def add_argparse_args(parser: argparse.ArgumentParser) -> None: