Do not average cumulative attention weight

pull/10/head
Eren Golge 2018-05-25 15:01:16 -07:00
parent fe99baec5a
commit ad943120ae
1 changed file with 1 addition and 1 deletion

@@ -286,7 +286,7 @@ class Decoder(nn.Module):
             processed_memory = self.prenet(memory_input)
             # Attention RNN
             attention_cat = torch.cat((attention.unsqueeze(1),
-                                        attention_cum.unsqueeze(1) / (t + 1)),
+                                        attention_cum.unsqueeze(1)),
                                        dim=1)
             attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
                 processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention_cat)
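
For context, the tensor built here stacks two attention maps along a new channel dimension: the alignment from the current decoder step and the running (cumulative) sum of past alignments. Before this commit the cumulative map was divided by (t + 1), i.e. averaged over the decode steps seen so far; after it, the raw sum is fed to the attention RNN. Below is a minimal sketch of the two variants with hypothetical shapes and dummy values (batch, encoder_steps, and t are illustrative, not taken from the repository):

import torch

# Hypothetical sizes for illustration only.
batch, encoder_steps, t = 2, 5, 3
# Alignment from the current decoder step, shape (batch, encoder_steps).
attention = torch.softmax(torch.randn(batch, encoder_steps), dim=1)
# Running sum of alignments over all previous steps, same shape.
attention_cum = torch.rand(batch, encoder_steps) * (t + 1)

# Before this commit: cumulative weights averaged over the steps seen so far.
attention_cat_old = torch.cat((attention.unsqueeze(1),
                               (attention_cum / (t + 1)).unsqueeze(1)),
                              dim=1)

# After this commit: the raw cumulative weights are passed through unchanged.
attention_cat_new = torch.cat((attention.unsqueeze(1),
                               attention_cum.unsqueeze(1)),
                              dim=1)

print(attention_cat_old.shape)  # torch.Size([2, 2, 5])
print(attention_cat_new.shape)  # torch.Size([2, 2, 5])

Either way the concatenated tensor has shape (batch, 2, encoder_steps); only the scale of the second channel changes, which is what the attention RNN sees as its location feature.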