Add a constant attention model type to attention class

pull/10/head
Eren Golge 2018-05-23 06:18:09 -07:00
parent 819011e1a2
commit 14f9d06b31
1 changed file with 4 additions and 3 deletions

@@ -219,7 +219,7 @@ class Decoder(nn.Module):
# memory -> |Prenet| -> processed_memory
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
- self.attention_rnn = AttentionRNN(256, in_features, 128)
+ self.attention_rnn = AttentionRNN(256, in_features, 128, align_model='ls')
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
self.project_to_decoder_in = nn.Linear(256+in_features, 256)
# decoder_RNN_input -> |RNN| -> RNN_state
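
Note: the hunk above only changes the call site, passing an alignment model by name ('ls', presumably a location-sensitive aligner). Below is a minimal sketch of how such an align_model flag can select the alignment module inside an attention RNN; the class and layer names (AttentionRNNCell, ContentAttention) and the content-based variant spelled out here are illustrative assumptions, not the repository's actual implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ContentAttention(nn.Module):
    """Plain content-based (Bahdanau-style) alignment over encoder annotations."""

    def __init__(self, annot_dim, query_dim, hidden_dim):
        super().__init__()
        self.annot_layer = nn.Linear(annot_dim, hidden_dim)
        self.query_layer = nn.Linear(query_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, annots, query):
        # annots: (B, T, annot_dim), query: (B, query_dim) -> alignment (B, T)
        scores = self.v(torch.tanh(self.annot_layer(annots) +
                                   self.query_layer(query).unsqueeze(1)))
        return F.softmax(scores.squeeze(-1), dim=1)

class AttentionRNNCell(nn.Module):
    """Attention RNN that picks its alignment model from a string flag,
    mirroring the align_model argument added in this commit. Only the
    content-based ('b') variant is shown; 'ls' would plug in a
    location-sensitive module the same way."""

    def __init__(self, out_dim, annot_dim, memory_dim, align_model='b'):
        super().__init__()
        self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, out_dim)
        if align_model == 'b':
            self.alignment_model = ContentAttention(annot_dim, out_dim, out_dim)
        else:
            raise ValueError("Unknown align_model: {}".format(align_model))

    def forward(self, memory, context, rnn_state, annots):
        # previous context + current memory frame drive the attention RNN
        rnn_state = self.rnn_cell(torch.cat((memory, context), -1), rnn_state)
        alignment = self.alignment_model(annots, rnn_state)               # (B, T)
        context = torch.bmm(alignment.unsqueeze(1), annots).squeeze(1)    # (B, annot_dim)
        return rnn_state, context, alignment
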
@@ -257,14 +257,15 @@ class Decoder(nn.Module):
" !! Dimension mismatch {} vs {} * {}".format(memory.size(-1),
self.memory_dim, self.r)
T_decoder = memory.size(1)
- # go frame - 0 frames tarting the sequence
+ # go frame as zeros matrix
initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
- # Init decoder states
+ # decoder states
attention_rnn_hidden = inputs.data.new(B, 256).zero_()
decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_()
for _ in range(len(self.decoder_rnns))]
current_context_vec = inputs.data.new(B, 256).zero_()
stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
+ # attention states
attention = inputs.data.new(B, T).zero_()
attention_cum = inputs.data.new(B, T).zero_()
# Time first (T_decoder, B, memory_dim)
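
Note: the second hunk only rewords comments and labels the per-step alignment plus its running sum among the decoder's zero-initialized states. Below is a minimal sketch of that initialization pattern, using new_zeros in place of the data.new(...).zero_() idiom from the diff; B, T, memory_dim, r and the layer widths are illustrative assumptions, not values taken from this commit.

import torch

B, T, memory_dim, r = 2, 37, 80, 5
inputs = torch.zeros(B, T, 256)          # stand-in for the encoder outputs

attention_rnn_hidden = inputs.new_zeros(B, 256)
decoder_rnn_hiddens = [inputs.new_zeros(B, 256) for _ in range(2)]
current_context_vec = inputs.new_zeros(B, 256)
stopnet_rnn_hidden = inputs.new_zeros(B, r * memory_dim)
# per-step alignment plus its running sum; a location-sensitive ('ls')
# aligner conditions on the cumulative term at later decoding steps
attention = inputs.new_zeros(B, T)
attention_cum = inputs.new_zeros(B, T)

# inside the decoding loop the running sum would be updated as:
# attention_cum = attention_cum + attention
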