Skip to content

Commit

Permalink
Adds helper for computing the decoder input size.
Browse files Browse the repository at this point in the history
This eliminates a repeated ternary in the definition of the hard
attention classes simply by making it a property of the ABC.
  • Loading branch information
kylebgorman committed Dec 9, 2024
1 parent bd6f171 commit b79404c
Showing 1 changed file with 22 additions and 24 deletions.
46 changes: 22 additions & 24 deletions yoyodyne/models/hard_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,20 @@ def __init__(
self.teacher_forcing
), "Teacher forcing disabled but required by this model"

# Properties.

@property
def decoder_input_size(self) -> int:
    """Returns the input size to pass to the decoder.

    Without a features encoder this is the source encoder's output size;
    with one, the features encoder's output size is added to it.
    """
    if not self.has_features_encoder:
        return self.source_encoder.output_size
    # Both encoders contribute to the decoder input, so their sizes add up.
    return (
        self.source_encoder.output_size + self.features_encoder.output_size
    )

# Implemented interface.

def init_decoding(
self, encoder_out: torch.Tensor, encoder_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
Expand Down Expand Up @@ -394,13 +408,17 @@ def _loss(
loss = -torch.logsumexp(fwd, dim=-1).mean() / target.size(1)
return loss

# Interface.

def get_decoder(self):
    """Placeholder for the decoder factory.

    This version always raises; concrete model classes are expected to
    provide their own implementation.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()

@property
def name(self) -> str:
    """Placeholder for the model's name.

    This version always raises; concrete model classes are expected to
    provide their own implementation.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()

# Flags.

@staticmethod
def add_argparse_args(parser: argparse.ArgumentParser) -> None:
"""Adds HMM configuration options to the argument parser.
Expand Down Expand Up @@ -437,12 +455,7 @@ def get_decoder(self):
return modules.ContextHardAttentionGRUDecoder(
attention_context=self.attention_context,
bidirectional=False,
decoder_input_size=(
self.source_encoder.output_size
+ self.features_encoder.output_size
if self.has_features_encoder
else self.source_encoder.output_size
),
decoder_input_size=self.decoder_input_size,
dropout=self.dropout,
embeddings=self.embeddings,
embedding_size=self.embedding_size,
Expand All @@ -453,12 +466,7 @@ def get_decoder(self):
else:
return modules.HardAttentionGRUDecoder(
bidirectional=False,
decoder_input_size=(
self.source_encoder.output_size
+ self.features_encoder.output_size
if self.has_features_encoder
else self.source_encoder.output_size
),
decoder_input_size=self.decoder_input_size,
dropout=self.dropout,
embedding_size=self.embedding_size,
embeddings=self.embeddings,
Expand Down Expand Up @@ -523,12 +531,7 @@ def get_decoder(self):
return modules.ContextHardAttentionLSTMDecoder(
attention_context=self.attention_context,
bidirectional=False,
decoder_input_size=(
self.source_encoder.output_size
+ self.features_encoder.output_size
if self.has_features_encoder
else self.source_encoder.output_size
),
decoder_input_size=self.decoder_input_size,
dropout=self.dropout,
hidden_size=self.hidden_size,
embeddings=self.embeddings,
Expand All @@ -539,12 +542,7 @@ def get_decoder(self):
else:
return modules.HardAttentionLSTMDecoder(
bidirectional=False,
decoder_input_size=(
self.source_encoder.output_size
+ self.features_encoder.output_size
if self.has_features_encoder
else self.source_encoder.output_size
),
decoder_input_size=self.decoder_input_size,
dropout=self.dropout,
embeddings=self.embeddings,
embedding_size=self.embedding_size,
Expand Down

0 comments on commit b79404c

Please sign in to comment.