Add attention_cum

pull/10/head
Eren G 2018-07-17 15:59:18 +02:00
parent adbe603af1
commit 90d7e885e7
1 changed file with 6 additions and 6 deletions


@@ -252,7 +252,7 @@ class Decoder(nn.Module):
         stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
         # attention states
         attention = inputs.data.new(B, T).zero_()
-        # attention_cum = inputs.data.new(B, T).zero_()
+        attention_cum = inputs.data.new(B, T).zero_()
         # Time first (T_decoder, B, memory_dim)
         if memory is not None:
             memory = memory.transpose(0, 1)
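
The two buffers set up here are per-encoder-timestep alignment vectors: attention holds the alignment produced at the previous decoder step, and the newly enabled attention_cum keeps a running sum of every alignment produced so far. Below is a minimal, self-contained sketch of the shapes involved; names not in the hunk (such as encoder_outputs) are placeholders, and new_zeros is simply the modern spelling of the .data.new(...).zero_() idiom used in the diff.

import torch

B, T = 2, 120                                 # batch size, encoder timesteps
encoder_outputs = torch.randn(B, T, 256)      # stand-in for `inputs`

# `x.data.new(B, T).zero_()` builds a zero tensor with x's dtype and device;
# `x.new_zeros(B, T)` does the same in current PyTorch.
attention = encoder_outputs.new_zeros(B, T)      # alignment of the previous step
attention_cum = encoder_outputs.new_zeros(B, T)  # running sum of past alignments

# The second hunk stacks both as channels for a location-aware attention filter.
attention_cat = torch.cat((attention.unsqueeze(1),
                           attention_cum.unsqueeze(1)), dim=1)
print(attention_cat.shape)  # torch.Size([2, 2, 120])
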
@@ -270,13 +270,13 @@ class Decoder(nn.Module):
             # Prenet
             processed_memory = self.prenet(memory_input)
             # Attention RNN
-            # attention_cat = torch.cat((attention.unsqueeze(1),
-            #                            attention_cum.unsqueeze(1)),
-            #                           dim=1)
+            attention_cat = torch.cat((attention.unsqueeze(1),
+                                       attention_cum.unsqueeze(1)),
+                                      dim=1)
             attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
                 processed_memory, current_context_vec, attention_rnn_hidden,
-                inputs, attention.unsqueeze(1), input_lens)
-            # attention_cum += attention
+                inputs, attention_cat, input_lens)
+            attention_cum += attention
             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(
                 torch.cat((attention_rnn_hidden, current_context_vec), -1))
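
The concatenation only pays off if the attention module applies a location-sensitive filter over its two channels: location-sensitive attention (Chorowski et al., 2015) convolves over the previous alignment, and Tacotron 2 extends this by also feeding the cumulative alignment, which biases the decoder toward moving forward instead of re-attending to positions it has already covered. The diff does not show the repository's attention_rnn, so the module below is only a generic sketch of that idea, not the project's implementation; the class name LocationSensitiveAttention and all constructor arguments are illustrative.

import torch
import torch.nn as nn
import torch.nn.functional as F

class LocationSensitiveAttention(nn.Module):
    """Additive attention scored from the query, the encoder annotations, and
    location features drawn from the previous + cumulative alignments."""

    def __init__(self, query_dim, annot_dim, attn_dim, filters=32, kernel_size=31):
        super().__init__()
        # Two input channels: previous alignment and cumulative alignment.
        self.loc_conv = nn.Conv1d(2, filters, kernel_size,
                                  padding=(kernel_size - 1) // 2, bias=False)
        self.loc_linear = nn.Linear(filters, attn_dim, bias=False)
        self.query_layer = nn.Linear(query_dim, attn_dim, bias=False)
        self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=False)
        self.v = nn.Linear(attn_dim, 1, bias=False)

    def forward(self, query, annots, attention_cat):
        # query: (B, query_dim), annots: (B, T, annot_dim),
        # attention_cat: (B, 2, T) as built in the hunk above.
        loc_feats = self.loc_conv(attention_cat).transpose(1, 2)    # (B, T, filters)
        scores = self.v(torch.tanh(self.query_layer(query).unsqueeze(1)
                                   + self.annot_layer(annots)
                                   + self.loc_linear(loc_feats))).squeeze(-1)  # (B, T)
        alignment = F.softmax(scores, dim=-1)                        # (B, T)
        context = torch.bmm(alignment.unsqueeze(1), annots).squeeze(1)  # (B, annot_dim)
        return context, alignment

With something like this inside attention_rnn, each decoder step returns a fresh alignment, the loop adds it into attention_cum, and the next step sees both tensors again through attention_cat.
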