From b79404c9767d04aa842c03329d72e34bdf56d8ee Mon Sep 17 00:00:00 2001
From: Kyle Gorman
Date: Mon, 9 Dec 2024 18:31:30 -0500
Subject: [PATCH] Adds helper for sizing the decoder input size.

This eliminates a repeated ternary in the definition of the hard
attention classes simply by making it a property of the ABC.
---
 yoyodyne/models/hard_attention.py | 46 +++++++++++++++----------------
 1 file changed, 22 insertions(+), 24 deletions(-)

diff --git a/yoyodyne/models/hard_attention.py b/yoyodyne/models/hard_attention.py
index 6459713..9ed7f83 100644
--- a/yoyodyne/models/hard_attention.py
+++ b/yoyodyne/models/hard_attention.py
@@ -62,6 +62,20 @@ def __init__(
             self.teacher_forcing
         ), "Teacher forcing disabled but required by this model"
 
+    # Properties
+
+    @property
+    def decoder_input_size(self) -> int:
+        if self.has_features_encoder:
+            return (
+                self.source_encoder.output_size
+                + self.features_encoder.output_size
+            )
+        else:
+            return self.source_encoder.output_size
+
+    # Implemented interface.
+
     def init_decoding(
         self, encoder_out: torch.Tensor, encoder_mask: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
@@ -394,6 +408,8 @@ def _loss(
         loss = -torch.logsumexp(fwd, dim=-1).mean() / target.size(1)
         return loss
 
+    # Interface.
+
     def get_decoder(self):
         raise NotImplementedError
 
@@ -401,6 +417,8 @@ def get_decoder(self):
     @property
     def name(self) -> str:
         raise NotImplementedError
 
+    # Flags.
+
     @staticmethod
     def add_argparse_args(parser: argparse.ArgumentParser) -> None:
         """Adds HMM configuration options to the argument parser.
@@ -437,12 +455,7 @@ def get_decoder(self):
             return modules.ContextHardAttentionGRUDecoder(
                 attention_context=self.attention_context,
                 bidirectional=False,
-                decoder_input_size=(
-                    self.source_encoder.output_size
-                    + self.features_encoder.output_size
-                    if self.has_features_encoder
-                    else self.source_encoder.output_size
-                ),
+                decoder_input_size=self.decoder_input_size,
                 dropout=self.dropout,
                 embeddings=self.embeddings,
                 embedding_size=self.embedding_size,
@@ -453,12 +466,7 @@ def get_decoder(self):
         else:
             return modules.HardAttentionGRUDecoder(
                 bidirectional=False,
-                decoder_input_size=(
-                    self.source_encoder.output_size
-                    + self.features_encoder.output_size
-                    if self.has_features_encoder
-                    else self.source_encoder.output_size
-                ),
+                decoder_input_size=self.decoder_input_size,
                 dropout=self.dropout,
                 embedding_size=self.embedding_size,
                 embeddings=self.embeddings,
@@ -523,12 +531,7 @@ def get_decoder(self):
             return modules.ContextHardAttentionLSTMDecoder(
                 attention_context=self.attention_context,
                 bidirectional=False,
-                decoder_input_size=(
-                    self.source_encoder.output_size
-                    + self.features_encoder.output_size
-                    if self.has_features_encoder
-                    else self.source_encoder.output_size
-                ),
+                decoder_input_size=self.decoder_input_size,
                 dropout=self.dropout,
                 hidden_size=self.hidden_size,
                 embeddings=self.embeddings,
@@ -539,12 +542,7 @@ def get_decoder(self):
         else:
             return modules.HardAttentionLSTMDecoder(
                 bidirectional=False,
-                decoder_input_size=(
-                    self.source_encoder.output_size
-                    + self.features_encoder.output_size
-                    if self.has_features_encoder
-                    else self.source_encoder.output_size
-                ),
+                decoder_input_size=self.decoder_input_size,
                 dropout=self.dropout,
                 embeddings=self.embeddings,
                 embedding_size=self.embedding_size,
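
Reviewer note, not part of the patch: the change hoists a ternary that was
repeated in four get_decoder() implementations into a single property on the
ABC. Below is a minimal, runnable sketch of that pattern for illustration
only. The Encoder stub, the sizes, and the HardAttentionGRUModel subclass
are hypothetical stand-ins; the decoder_input_size property mirrors the one
added above.

    import abc
    from typing import Optional


    class Encoder:
        """Hypothetical stand-in for an encoder; only output_size matters."""

        def __init__(self, output_size: int):
            self.output_size = output_size


    class HardAttentionModel(abc.ABC):
        """Sketch of the ABC holding the shared property."""

        def __init__(
            self,
            source_encoder: Encoder,
            features_encoder: Optional[Encoder] = None,
        ):
            self.source_encoder = source_encoder
            self.features_encoder = features_encoder

        @property
        def has_features_encoder(self) -> bool:
            return self.features_encoder is not None

        @property
        def decoder_input_size(self) -> int:
            # The single remaining copy of the former ternary.
            if self.has_features_encoder:
                return (
                    self.source_encoder.output_size
                    + self.features_encoder.output_size
                )
            else:
                return self.source_encoder.output_size


    class HardAttentionGRUModel(HardAttentionModel):
        def get_decoder(self) -> dict:
            # Each subclass now just reads the shared property.
            return {"decoder_input_size": self.decoder_input_size}


    model = HardAttentionGRUModel(Encoder(512), Encoder(64))
    assert model.decoder_input_size == 512 + 64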