diff --git a/yoyodyne/models/pointer_generator.py b/yoyodyne/models/pointer_generator.py
index 93f9bc4..1488049 100644
--- a/yoyodyne/models/pointer_generator.py
+++ b/yoyodyne/models/pointer_generator.py
@@ -303,70 +303,6 @@ def greedy_decode(
         predictions = torch.stack(predictions).transpose(0, 1)
         return predictions
 
-    def decode_step(
-        self,
-        symbol: torch.Tensor,
-        last_hiddens: Tuple[torch.Tensor, torch.Tensor],
-        source_indices: torch.Tensor,
-        source_enc: torch.Tensor,
-        source_mask: torch.Tensor,
-        features_enc: Optional[torch.Tensor] = None,
-        features_mask: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Runs a single step of the decoder.
-
-        This predicts a distribution for one symbol.
-
-        Args:
-            symbol (torch.Tensor).
-            last_hiddens (Tuple[torch.Tensor, torch.Tensor]).
-            source_indices (torch.Tensor).
-            source_enc (torch.Tensor).
-            source_mask (torch.Tensor).
-            features_enc (Optional[torch.Tensor]).
-            features_mask (Optional[torch.Tensor]).
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor].
-        """
-        embedded = self.decoder.embed(symbol)
-        last_h0, last_c0 = last_hiddens
-        source_context, attention_weights = self.decoder.attention(
-            last_h0.transpose(0, 1), source_enc, source_mask
-        )
-        if self.has_features_encoder:
-            features_context, _ = self.features_attention(
-                last_h0.transpose(0, 1), features_enc, features_mask
-            )
-            # -> B x 1 x 4*hidden_size.
-            context = torch.cat([source_context, features_context], dim=2)
-        else:
-            context = source_context
-        _, (h, c) = self.decoder.module(
-            torch.cat((embedded, context), 2), (last_h0, last_c0)
-        )
-        # -> B x 1 x hidden_size
-        hidden = h[-1, :, :].unsqueeze(1)
-        output_dist = self.classifier(torch.cat([hidden, context], dim=2))
-        output_dist = nn.functional.softmax(output_dist, dim=2)
-        # -> B x 1 x target_vocab_size.
-        ptr_dist = torch.zeros(
-            symbol.size(0),
-            self.target_vocab_size,
-            device=self.device,
-            dtype=attention_weights.dtype,
-        ).unsqueeze(1)
-        # Gets the attentions to the source in terms of the output generations.
-        # These are the "pointer" distribution.
-        ptr_dist.scatter_add_(
-            2, source_indices.unsqueeze(1), attention_weights
-        )
-        # Probability of generating (from output_dist).
-        gen_probs = self.generation_probability(context, hidden, embedded)
-        scaled_ptr_dist = ptr_dist * (1 - gen_probs)
-        scaled_output_dist = output_dist * gen_probs
-        return torch.log(scaled_output_dist + scaled_ptr_dist), (h, c)
-
 
 class PointerGeneratorGRUModel(PointerGeneratorRNNModel, rnn.GRUModel):
     """Pointer-generator model with an GRU backend."""
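For reference, the deleted decode_step computes one decoding step of the pointer-generator: it builds a generation distribution over the target vocabulary (output_dist), scatters the source attention weights onto the target vocabulary to form a "pointer" distribution (ptr_dist), and returns the log of their mixture weighted by the generation probability, i.e. gen_probs * output_dist + (1 - gen_probs) * ptr_dist. Below is a minimal, self-contained sketch of just that mixing step; the function mix_distributions and all tensor shapes are illustrative assumptions for this sketch, not yoyodyne's API:

# A minimal sketch of the pointer-generator mixing step performed by the
# removed decode_step. Names and shapes here are illustrative assumptions.
import torch


def mix_distributions(
    output_dist: torch.Tensor,  # B x 1 x target_vocab_size, softmaxed.
    attention_weights: torch.Tensor,  # B x 1 x source_len, softmaxed.
    source_indices: torch.Tensor,  # B x source_len; target-vocab index
    # of each source symbol.
    gen_probs: torch.Tensor,  # B x 1 x 1; probability of generating
    # rather than copying.
    target_vocab_size: int,
) -> torch.Tensor:
    # Scatters source attention mass onto the target-vocabulary entries of
    # the corresponding source symbols: the "pointer" distribution.
    ptr_dist = torch.zeros(
        source_indices.size(0),
        1,
        target_vocab_size,
        dtype=attention_weights.dtype,
    )
    ptr_dist.scatter_add_(2, source_indices.unsqueeze(1), attention_weights)
    # Mixture of the two distributions, weighted by gen_probs, in log space.
    return torch.log(gen_probs * output_dist + (1 - gen_probs) * ptr_dist)


# Toy usage: batch of 2, source length 3, target vocabulary of 5.
output_dist = torch.softmax(torch.randn(2, 1, 5), dim=2)
attention_weights = torch.softmax(torch.randn(2, 1, 3), dim=2)
source_indices = torch.randint(0, 5, (2, 3))
gen_probs = torch.sigmoid(torch.randn(2, 1, 1))
log_dist = mix_distributions(
    output_dist, attention_weights, source_indices, gen_probs, 5
)
# Both inputs are proper distributions, so the mixture still sums to 1.
assert torch.allclose(log_dist.exp().sum(dim=2), torch.ones(2, 1))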