Skip to content

Commit

Permalink
Merge pull request #287 from CUNY-CL/unused
Browse files Browse the repository at this point in the history
Removes unused `decode_step` in abstract class
  • Loading branch information
kylebgorman authored Dec 9, 2024
2 parents bd6f171 + a4a7dbf commit 7576772
Showing 1 changed file with 0 additions and 64 deletions.
64 changes: 0 additions & 64 deletions yoyodyne/models/pointer_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,70 +303,6 @@ def greedy_decode(
predictions = torch.stack(predictions).transpose(0, 1)
return predictions

def decode_step(
self,
symbol: torch.Tensor,
last_hiddens: Tuple[torch.Tensor, torch.Tensor],
source_indices: torch.Tensor,
source_enc: torch.Tensor,
source_mask: torch.Tensor,
features_enc: Optional[torch.Tensor] = None,
features_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Runs a single step of the decoder.
This predicts a distribution for one symbol.
Args:
symbol (torch.Tensor).
last_hiddens (Tuple[torch.Tensor, torch.Tensor]).
source_indices (torch.Tensor).
source_enc (torch.Tensor).
source_mask (torch.Tensor).
features_enc (Optional[torch.Tensor]).
features_mask (Optional[torch.Tensor]).
Returns:
Tuple[torch.Tensor, torch.Tensor].
"""
embedded = self.decoder.embed(symbol)
last_h0, last_c0 = last_hiddens
source_context, attention_weights = self.decoder.attention(
last_h0.transpose(0, 1), source_enc, source_mask
)
if self.has_features_encoder:
features_context, _ = self.features_attention(
last_h0.transpose(0, 1), features_enc, features_mask
)
# -> B x 1 x 4*hidden_size.
context = torch.cat([source_context, features_context], dim=2)
else:
context = source_context
_, (h, c) = self.decoder.module(
torch.cat((embedded, context), 2), (last_h0, last_c0)
)
# -> B x 1 x hidden_size
hidden = h[-1, :, :].unsqueeze(1)
output_dist = self.classifier(torch.cat([hidden, context], dim=2))
output_dist = nn.functional.softmax(output_dist, dim=2)
# -> B x 1 x target_vocab_size.
ptr_dist = torch.zeros(
symbol.size(0),
self.target_vocab_size,
device=self.device,
dtype=attention_weights.dtype,
).unsqueeze(1)
# Gets the attentions to the source in terms of the output generations.
# These are the "pointer" distribution.
ptr_dist.scatter_add_(
2, source_indices.unsqueeze(1), attention_weights
)
# Probability of generating (from output_dist).
gen_probs = self.generation_probability(context, hidden, embedded)
scaled_ptr_dist = ptr_dist * (1 - gen_probs)
scaled_output_dist = output_dist * gen_probs
return torch.log(scaled_output_dist + scaled_ptr_dist), (h, c)


class PointerGeneratorGRUModel(PointerGeneratorRNNModel, rnn.GRUModel):
"""Pointer-generator model with an GRU backend."""
Expand Down

0 comments on commit 7576772

Please sign in to comment.