diff --git a/layers/common_layers.py b/layers/common_layers.py index 5365d605..a768e684 100644 --- a/layers/common_layers.py +++ b/layers/common_layers.py @@ -136,8 +136,8 @@ class GravesAttention(nn.Module): torch.nn.init.constant_(self.N_a[2].bias[self.K:(2*self.K)], 10) def init_states(self, inputs): - if self.J is None or inputs.shape[1] > self.J.shape[-1]: - self.J = torch.arange(0, inputs.shape[1]+1).to(inputs.device) + 0.5 + if self.J is None or inputs.shape[1]+1 > self.J.shape[-1]: + self.J = torch.arange(0, inputs.shape[1]+2).to(inputs.device) + 0.5 self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device) self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device) @@ -165,24 +165,25 @@ class GravesAttention(nn.Module): # attention GMM parameters sig_t = torch.nn.functional.softplus(b_t) + self.eps - mu_t = self.mu_prev + torch.nn.functional.softplus(k_t) g_t = torch.softmax(g_t, dim=-1) / sig_t + self.eps - j = self.J[:inputs.size(1)+1] # attention weights - phi_t = g_t.unsqueeze(-1) * torch.exp(-0.5 * (mu_t.unsqueeze(-1) - j)**2 / (sig_t.unsqueeze(-1)**2)) + phi_t = g_t.unsqueeze(-1) * (1 / (1 + torch.exp((mu_t.unsqueeze(-1) - j) / sig_t.unsqueeze(-1)))) # discritize attention weights - alpha_t = self.COEF * torch.sum(phi_t, 1) + alpha_t = torch.sum(phi_t, 1) alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1] + alpha_t[alpha_t == 0] = 1e-8 # apply masking if mask is not None: alpha_t.data.masked_fill_(~mask, self._mask_value) context = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1) + # for better visualization + # self.attention_weights = torch.clamp(alpha_t, min=0) self.attention_weights = alpha_t self.mu_prev = mu_t return context @@ -355,7 +356,7 @@ class OriginalAttention(nn.Module): if self.forward_attn: alignment = self.apply_forward_attention(alignment) self.alpha = alignment - + context = torch.bmm(alignment.unsqueeze(1), inputs) context = context.squeeze(1) self.attention_weights = alignment