From ce46b8e45c23f4fd8047debfd9e82f487b465661 Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Mon, 9 Dec 2024 20:29:30 -0500 Subject: [PATCH] some typos --- yoyodyne/models/rnn.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/yoyodyne/models/rnn.py b/yoyodyne/models/rnn.py index 3dcb8bd..ad9708b 100644 --- a/yoyodyne/models/rnn.py +++ b/yoyodyne/models/rnn.py @@ -150,7 +150,7 @@ def beam_decode( ], device=self.device, ).unsqueeze(0) - # --> B x beam_size. + # -> B x beam_size. scores = torch.tensor([h[0] for h in histories]).unsqueeze(0) return predictions, scores @@ -253,16 +253,14 @@ def forward( # match the Tuple[torch.Tensor, torch.Tensor] type because the # training and validation functions depend on it. if self.beam_width > 1: - x = self.beam_decode(encoder_out, batch.source.mask) - return x + return self.beam_decode(encoder_out, batch.source.mask) else: - x = self.greedy_decode( + return self.greedy_decode( encoder_out, batch.source.mask, self.teacher_forcing if self.training else False, batch.target.padded if batch.target else None, ) - return x @staticmethod def add_argparse_args(parser: argparse.ArgumentParser) -> None: