From efcfc1cf10266dc41ed08fe9f7c88bdbff8f94d4 Mon Sep 17 00:00:00 2001
From: Sasha Rush
Date: Wed, 11 Apr 2018 13:03:51 -0400
Subject: [PATCH] Fix softmaxes

---
 onmt/ModelConstructor.py        | 2 +-
 onmt/modules/GlobalAttention.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/onmt/ModelConstructor.py b/onmt/ModelConstructor.py
index 2d60626cfb..3c417c903c 100644
--- a/onmt/ModelConstructor.py
+++ b/onmt/ModelConstructor.py
@@ -191,7 +191,7 @@ def make_base_model(model_opt, fields, gpu, checkpoint=None):
     if not model_opt.copy_attn:
         generator = nn.Sequential(
             nn.Linear(model_opt.rnn_size, len(fields["tgt"].vocab)),
-            nn.LogSoftmax())
+            nn.LogSoftmax(dim=-1))
         if model_opt.share_decoder_embeddings:
             generator[0].weight = decoder.embeddings.word_lut.weight
     else:
diff --git a/onmt/modules/GlobalAttention.py b/onmt/modules/GlobalAttention.py
index 228f0658d7..8c74b2c664 100644
--- a/onmt/modules/GlobalAttention.py
+++ b/onmt/modules/GlobalAttention.py
@@ -76,7 +76,7 @@ def __init__(self, dim, coverage=False, attn_type="dot"):
         out_bias = self.attn_type == "mlp"
         self.linear_out = nn.Linear(dim*2, dim, bias=out_bias)

-        self.sm = nn.Softmax()
+        self.sm = nn.Softmax(dim=-1)
         self.tanh = nn.Tanh()

         if coverage:
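
Note (not part of the patch): a minimal sketch of what the added dim=-1 argument does. In newer PyTorch versions, constructing nn.Softmax() or nn.LogSoftmax() without an explicit dim emits a deprecation warning and the implicit axis choice depends on the input's dimensionality; dim=-1 pins normalization to the last axis, which is what both changed call sites assume (vocabulary scores in the generator, source positions in GlobalAttention). The tensor shapes below are illustrative, not taken from OpenNMT-py.

    # Sketch only: shows that dim=-1 normalizes over the last axis.
    import torch
    import torch.nn as nn

    # Illustrative attention scores, e.g. (batch, tgt_len, src_len).
    scores = torch.randn(4, 10, 7)
    attn = nn.Softmax(dim=-1)(scores)
    print(torch.allclose(attn.sum(dim=-1), torch.ones(4, 10)))  # True: each row sums to 1

    # Illustrative generator logits, e.g. (batch, vocab_size).
    logits = torch.randn(4, 100)
    log_probs = nn.LogSoftmax(dim=-1)(logits)
    print(torch.allclose(log_probs.exp().sum(dim=-1), torch.ones(4)))  # True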