From a6118564d578f314dfa787a80cc288dba2228dfa Mon Sep 17 00:00:00 2001
From: Thomas Werkmeister
Date: Wed, 24 Jul 2019 11:46:34 +0200
Subject: [PATCH] renamed query_rnn back to attention_rnn

---
 layers/tacotron.py  | 12 +++++++-----
 layers/tacotron2.py | 18 +++++++++---------
 2 files changed, 16 insertions(+), 14 deletions(-)

diff --git a/layers/tacotron.py b/layers/tacotron.py
index 31d6cd84..40225fa5 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -291,7 +291,9 @@ class Decoder(nn.Module):
                              prenet_dropout,
                              out_features=[256, 128])
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
-        self.query_rnn = nn.GRUCell(in_features + 128, self.query_dim)
+        # attention_rnn generates queries for the attention mechanism
+        self.attention_rnn = nn.GRUCell(in_features + 128, self.query_dim)
+
         self.attention = Attention(query_dim=self.query_dim,
                                    embedding_dim=in_features,
                                    attention_dim=128,
@@ -311,7 +313,7 @@ class Decoder(nn.Module):
         # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
         # learn init values instead of zero init.
-        self.query_rnn_init = nn.Embedding(1, 256)
+        self.attention_rnn_init = nn.Embedding(1, 256)
         self.memory_init = nn.Embedding(1, self.memory_size * memory_dim)
         self.decoder_rnn_inits = nn.Embedding(2, 256)
         self.stopnet = StopNet(256 + memory_dim * r)
@@ -348,7 +350,7 @@ class Decoder(nn.Module):
         self.memory_input = self.memory_init(inputs.data.new_zeros(B).long())

         # decoder states
-        self.query = self.query_rnn_init(
+        self.query = self.attention_rnn_init(
             inputs.data.new_zeros(B).long())
         self.decoder_rnn_hiddens = [
             self.decoder_rnn_inits(inputs.data.new_tensor([idx] * B).long())
@@ -369,8 +371,8 @@ class Decoder(nn.Module):

         # Prenet
         processed_memory = self.prenet(self.memory_input)
-        # Attention RNN
-        self.query = self.query_rnn(torch.cat((processed_memory, self.context_vec), -1), self.query)
+        # Attention
+        self.query = self.attention_rnn(torch.cat((processed_memory, self.context_vec), -1), self.query)
         self.context_vec = self.attention(self.query, inputs, self.processed_inputs, mask)

         # Concat query and attention context vector
diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index ba52abe2..358d1807 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -116,8 +116,8 @@ class Decoder(nn.Module):
                              prenet_dropout,
                              [self.prenet_dim, self.prenet_dim],
                              bias=False)
-        self.query_rnn = nn.LSTMCell(self.prenet_dim + in_features,
-                                     self.query_dim)
+        self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
+                                         self.query_dim)

         self.attention = Attention(query_dim=self.query_dim,
                                    embedding_dim=in_features,
@@ -145,7 +145,7 @@ class Decoder(nn.Module):
                 bias=True,
                 init_gain='sigmoid'))

-        self.query_rnn_init = nn.Embedding(1, self.query_dim)
+        self.attention_rnn_init = nn.Embedding(1, self.query_dim)
         self.go_frame_init = nn.Embedding(1, self.mel_channels * r)
         self.decoder_rnn_inits = nn.Embedding(1, self.decoder_rnn_dim)
         self.memory_truncated = None
@@ -160,9 +160,9 @@ class Decoder(nn.Module):
         # T = inputs.size(1)

         if not keep_states:
-            self.query = self.query_rnn_init(
+            self.query = self.attention_rnn_init(
                 inputs.data.new_zeros(B).long())
-            self.query_rnn_cell_state = Variable(
+            self.attention_rnn_cell_state = Variable(
                 inputs.data.new(B, self.query_dim).zero_())

             self.decoder_hidden = self.decoder_rnn_inits(
@@ -194,12 +194,12 @@ class Decoder(nn.Module):

     def decode(self, memory):
         query_input = torch.cat((memory, self.context), -1)
-        self.query, self.query_rnn_cell_state = self.query_rnn(
-            query_input, (self.query, self.query_rnn_cell_state))
+        self.query, self.attention_rnn_cell_state = self.attention_rnn(
+            query_input, (self.query, self.attention_rnn_cell_state))
         self.query = F.dropout(
             self.query, self.p_attention_dropout, self.training)
-        self.query_rnn_cell_state = F.dropout(
-            self.query_rnn_cell_state, self.p_attention_dropout, self.training)
+        self.attention_rnn_cell_state = F.dropout(
+            self.attention_rnn_cell_state, self.p_attention_dropout, self.training)
         self.context = self.attention(self.query, self.inputs,
                                       self.processed_inputs, self.mask)