pull/10/head
Eren Golge 2019-04-18 18:55:37 +02:00
parent 9ba13b2d2f
commit 38213dffe9
1 changed file with 1 addition and 1 deletion

@@ -183,7 +183,7 @@ class Attention(nn.Module):
     def apply_forward_attention(self, inputs, alignment, processed_query):
         # forward attention
         prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
-        alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha)) * alignment
+        alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) + 1e-8) * alignment
         self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
         # compute context
         context = torch.bmm(self.alpha.unsqueeze(1), inputs)
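
The added 1e-8 guards the normalization on the next line: if the forward-attention recursion and the new alignment place their mass on disjoint positions, every entry of alpha is zero, alpha.sum() is zero, and the division produces NaNs that propagate through all later decoder steps. A minimal sketch of the failure mode (not from the commit; the shapes, the transition weight u, and the concrete values are illustrative assumptions):

import torch

u = 0.5
alpha = torch.tensor([[0.0, 0.0, 1.0]])       # attention fully on the last token
prev_alpha = torch.tensor([[0.0, 0.0, 0.0]])  # alpha shifted right, per the forward rule
alignment = torch.tensor([[1.0, 0.0, 0.0]])   # new energies point elsewhere

# Without the epsilon: the product is all zeros, so normalizing divides by zero.
raw = ((1 - u) * alpha + u * prev_alpha) * alignment
print(raw / raw.sum(dim=1).unsqueeze(1))      # tensor([[nan, nan, nan]])

# With the epsilon: the sum stays strictly positive, so the result is finite.
safe = (((1 - u) * alpha + u * prev_alpha) + 1e-8) * alignment
print(safe / safe.sum(dim=1).unsqueeze(1))    # tensor([[1., 0., 0.]])

The epsilon is small enough to leave a healthy attention distribution effectively unchanged while keeping the normalization well-defined in the degenerate case.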