Add attention-cum

pull/10/head
Eren G 2018-07-17 15:59:18 +02:00
parent adbe603af1
commit 90d7e885e7
1 changed file with 6 additions and 6 deletions

@@ -252,7 +252,7 @@ class Decoder(nn.Module):
         stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
         # attention states
         attention = inputs.data.new(B, T).zero_()
-        # attention_cum = inputs.data.new(B, T).zero_()
+        attention_cum = inputs.data.new(B, T).zero_()
         # Time first (T_decoder, B, memory_dim)
         if memory is not None:
             memory = memory.transpose(0, 1)
@@ -270,13 +270,13 @@ class Decoder(nn.Module):
             # Prenet
             processed_memory = self.prenet(memory_input)
             # Attention RNN
-            # attention_cat = torch.cat((attention.unsqueeze(1),
-            #                            attention_cum.unsqueeze(1)),
-            #                           dim=1)
+            attention_cat = torch.cat((attention.unsqueeze(1),
+                                       attention_cum.unsqueeze(1)),
+                                      dim=1)
             attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
                 processed_memory, current_context_vec, attention_rnn_hidden,
-                inputs, attention.unsqueeze(1), input_lens)
-            # attention_cum += attention
+                inputs, attention_cat, input_lens)
+            attention_cum += attention
             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(
                 torch.cat((attention_rnn_hidden, current_context_vec), -1))
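
In effect, the commit un-comments the cumulative attention path: the decoder keeps a running sum of the alignment weights, stacks it with the current weights into a two-channel tensor, and feeds that to the attention RNN in place of the bare weights, so the attention mechanism can see which encoder steps were already attended (as in location-sensitive attention). Below is a minimal, self-contained sketch of how such a two-channel tensor is typically consumed; the LocationLayer module, its filter count, and kernel size are illustrative assumptions here, not the attention_rnn implementation from this repository.

import torch
import torch.nn as nn

class LocationLayer(nn.Module):
    """Turns stacked (current, cumulative) attention weights into
    location features, in the style of location-sensitive attention.
    Hypothetical helper for illustration; not part of this diff."""
    def __init__(self, attn_dim, filters=32, kernel_size=31):
        super().__init__()
        # 2 input channels: current alignment + its running sum
        self.conv = nn.Conv1d(2, filters, kernel_size,
                              padding=(kernel_size - 1) // 2, bias=False)
        self.proj = nn.Linear(filters, attn_dim, bias=False)

    def forward(self, attention_cat):
        # attention_cat: (B, 2, T) -> location features: (B, T, attn_dim)
        return self.proj(self.conv(attention_cat).transpose(1, 2))

# Usage mirroring the decoder loop above (zeros as dummy alignments):
B, T, attn_dim = 4, 37, 128
attention = torch.zeros(B, T)
attention_cum = torch.zeros(B, T)
attention_cat = torch.cat((attention.unsqueeze(1),
                           attention_cum.unsqueeze(1)), dim=1)  # (B, 2, T)
loc_features = LocationLayer(attn_dim)(attention_cat)           # (B, T, 128)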