diff --git a/layers/common_layers.py b/layers/common_layers.py
index 6ddd0b6b..716bcfd0 100644
--- a/layers/common_layers.py
+++ b/layers/common_layers.py
@@ -119,9 +119,9 @@ class GravesAttention(nn.Module):
         self.epsilon = 1e-5
         self.J = None
         self.N_a = nn.Sequential(
-            nn.Linear(query_dim, query_dim//2),
+            nn.Linear(query_dim, query_dim),
             nn.Tanh(),
-            nn.Linear(query_dim//2, 3*K))
+            nn.Linear(query_dim, 3*K))
         self.attention_weights = None
         self.mu_prev = None
@@ -157,8 +157,10 @@ class GravesAttention(nn.Module):
         # mu_t = self.mu_prev + self.attention_alignment * torch.exp(k_t)  # mean
         sig_t = torch.pow(torch.nn.functional.softplus(b_t), 2)
         mu_t = self.mu_prev + torch.nn.functional.softplus(k_t)
+        # TODO try sigmoid here
         g_t = (torch.softmax(g_t, dim=-1) / sig_t) * self.COEF
+        # each B x K x T_in
         g_t = g_t.unsqueeze(2).expand(g_t.size(0),
                                       g_t.size(1),
                                       inputs.size(1))
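For context, here is a minimal, self-contained sketch of how the terms touched by the second hunk (sig_t, mu_t, g_t) combine into Graves-style GMM attention weights over the encoder steps. The function name, the way the 3*K output of N_a is split, and the usage snippet are illustrative assumptions rather than the file's actual forward(); only the sig_t / mu_t / g_t expressions mirror the diff.

import torch

def gmm_attention_weights(g_t, b_t, k_t, mu_prev, T_in, eps=1e-5):
    """Illustrative Graves/GMM attention step (not the repo's forward()).

    g_t, b_t, k_t: (B, K) raw mixture parameters split from the 3*K output of N_a.
    mu_prev:       (B, K) mixture means from the previous decoder step.
    Returns (alignment, mu_t), with alignment of shape (B, T_in).
    """
    coef = 0.3989422917366028  # 1 / sqrt(2 * pi), matching self.COEF in the diff
    sig_t = torch.nn.functional.softplus(b_t) ** 2 + eps        # B x K, variance-like width
    mu_t = mu_prev + torch.nn.functional.softplus(k_t)          # B x K, monotonically advancing means
    w_t = (torch.softmax(g_t, dim=-1) / sig_t) * coef           # B x K, normalized mixture weights
    j = torch.arange(T_in, dtype=g_t.dtype, device=g_t.device)  # encoder positions 0..T_in-1
    # B x K x T_in Gaussian bumps centred at mu_t, summed over the K mixtures
    phi = w_t.unsqueeze(-1) * torch.exp(-0.5 * (mu_t.unsqueeze(-1) - j) ** 2 / sig_t.unsqueeze(-1))
    alignment = phi.sum(1)                                      # B x T_in
    return alignment, mu_t

# usage (hypothetical shapes): B=2 utterances, K=5 mixtures, T_in=13 encoder steps
B, K, T_in = 2, 5, 13
gbk = torch.randn(B, 3 * K)                   # stands in for self.N_a(query)
g_t, k_t, b_t = gbk.chunk(3, dim=-1)
align, mu = gmm_attention_weights(g_t, b_t, k_t, torch.zeros(B, K), T_in)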