Skip to content

Commit

Permalink
some typos
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebgorman committed Dec 10, 2024
1 parent ffe6095 commit ce46b8e
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions yoyodyne/models/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def beam_decode(
],
device=self.device,
).unsqueeze(0)
# --> B x beam_size.
# -> B x beam_size.
scores = torch.tensor([h[0] for h in histories]).unsqueeze(0)
return predictions, scores

Expand Down Expand Up @@ -253,16 +253,14 @@ def forward(
# match the Tuple[torch.Tensor, torch.Tensor] type because the
# training and validation functions depend on it.
if self.beam_width > 1:
x = self.beam_decode(encoder_out, batch.source.mask)
return x
return self.beam_decode(encoder_out, batch.source.mask)
else:
x = self.greedy_decode(
return self.greedy_decode(
encoder_out,
batch.source.mask,
self.teacher_forcing if self.training else False,
batch.target.padded if batch.target else None,
)
return x

@staticmethod
def add_argparse_args(parser: argparse.ArgumentParser) -> None:
Expand Down

0 comments on commit ce46b8e

Please sign in to comment.