renamed query_rnn back to attention_rnn

pull/10/head
Thomas Werkmeister 2019-07-24 11:46:34 +02:00
parent 5db302179a
commit a6118564d5
2 changed files with 16 additions and 14 deletions


@@ -291,7 +291,9 @@ class Decoder(nn.Module):
             prenet_dropout,
             out_features=[256, 128])
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
-        self.query_rnn = nn.GRUCell(in_features + 128, self.query_dim)
+        # attention_rnn generates queries for the attention mechanism
+        self.attention_rnn = nn.GRUCell(in_features + 128, self.query_dim)
         self.attention = Attention(query_dim=self.query_dim,
                                    embedding_dim=in_features,
                                    attention_dim=128,
@@ -311,7 +313,7 @@ class Decoder(nn.Module):
         # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
         # learn init values instead of zero init.
-        self.query_rnn_init = nn.Embedding(1, 256)
+        self.attention_rnn_init = nn.Embedding(1, 256)
         self.memory_init = nn.Embedding(1, self.memory_size * memory_dim)
         self.decoder_rnn_inits = nn.Embedding(2, 256)
         self.stopnet = StopNet(256 + memory_dim * r)
@@ -348,7 +350,7 @@ class Decoder(nn.Module):
         self.memory_input = self.memory_init(inputs.data.new_zeros(B).long())
         # decoder states
-        self.query = self.query_rnn_init(
+        self.query = self.attention_rnn_init(
             inputs.data.new_zeros(B).long())
         self.decoder_rnn_hiddens = [
             self.decoder_rnn_inits(inputs.data.new_tensor([idx] * B).long())
@@ -369,8 +371,8 @@ class Decoder(nn.Module):
         # Prenet
         processed_memory = self.prenet(self.memory_input)
         # Attention RNN
-        self.query = self.query_rnn(torch.cat((processed_memory, self.context_vec), -1), self.query)
+        self.query = self.attention_rnn(torch.cat((processed_memory, self.context_vec), -1), self.query)
         # Attention
         self.context_vec = self.attention(self.query, inputs, self.processed_inputs, mask)
         # Concat query and attention context vector
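
For context on the GRUCell-based decoder above: attention_rnn is the recurrent cell that turns the prenet output plus the previous attention context into the query consumed by the attention module, and attention_rnn_init provides a learned initial hidden state in place of a zero init. A minimal sketch of a single step, with batch size and feature sizes chosen as illustrative assumptions rather than the repo's configuration:

    import torch
    import torch.nn as nn

    B, in_features, prenet_out, query_dim = 4, 256, 128, 256  # assumed sizes

    attention_rnn = nn.GRUCell(in_features + prenet_out, query_dim)

    processed_memory = torch.randn(B, prenet_out)  # prenet output for this step
    context_vec = torch.randn(B, in_features)      # previous attention context
    query = torch.zeros(B, query_dim)              # previous attention-RNN state

    # one decoder step: the GRU cell maps (prenet output, context) to a new query
    query = attention_rnn(torch.cat((processed_memory, context_vec), -1), query)
    print(query.shape)  # torch.Size([4, 256])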


@@ -116,8 +116,8 @@ class Decoder(nn.Module):
             prenet_dropout,
             [self.prenet_dim, self.prenet_dim], bias=False)
-        self.query_rnn = nn.LSTMCell(self.prenet_dim + in_features,
-                                     self.query_dim)
+        self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
+                                         self.query_dim)
         self.attention = Attention(query_dim=self.query_dim,
                                    embedding_dim=in_features,
@@ -145,7 +145,7 @@ class Decoder(nn.Module):
                 bias=True,
                 init_gain='sigmoid'))
-        self.query_rnn_init = nn.Embedding(1, self.query_dim)
+        self.attention_rnn_init = nn.Embedding(1, self.query_dim)
         self.go_frame_init = nn.Embedding(1, self.mel_channels * r)
         self.decoder_rnn_inits = nn.Embedding(1, self.decoder_rnn_dim)
         self.memory_truncated = None
@@ -160,9 +160,9 @@ class Decoder(nn.Module):
         # T = inputs.size(1)
         if not keep_states:
-            self.query = self.query_rnn_init(
+            self.query = self.attention_rnn_init(
                 inputs.data.new_zeros(B).long())
-            self.query_rnn_cell_state = Variable(
+            self.attention_rnn_cell_state = Variable(
                 inputs.data.new(B, self.query_dim).zero_())
             self.decoder_hidden = self.decoder_rnn_inits(
@@ -194,12 +194,12 @@ class Decoder(nn.Module):
     def decode(self, memory):
         query_input = torch.cat((memory, self.context), -1)
-        self.query, self.query_rnn_cell_state = self.query_rnn(
-            query_input, (self.query, self.query_rnn_cell_state))
+        self.query, self.attention_rnn_cell_state = self.attention_rnn(
+            query_input, (self.query, self.attention_rnn_cell_state))
         self.query = F.dropout(
             self.query, self.p_attention_dropout, self.training)
-        self.query_rnn_cell_state = F.dropout(
-            self.query_rnn_cell_state, self.p_attention_dropout, self.training)
+        self.attention_rnn_cell_state = F.dropout(
+            self.attention_rnn_cell_state, self.p_attention_dropout, self.training)
         self.context = self.attention(self.query, self.inputs,
                                       self.processed_inputs, self.mask)
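
For the LSTMCell-based decoder in the second file, the renamed state pairs fit together the same way but with an explicit cell state: each decode() step concatenates the prenet-processed previous frame with the previous context, runs attention_rnn, and applies dropout to both returned states before querying the attention module. A minimal sketch under assumed sizes (the dimensions and dropout probability below are placeholders, not the repo's values):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    B, prenet_dim, in_features, query_dim = 4, 256, 512, 1024  # assumed sizes
    p_attention_dropout = 0.1                                   # assumed value

    attention_rnn = nn.LSTMCell(prenet_dim + in_features, query_dim)

    memory = torch.randn(B, prenet_dim)                    # prenet-processed previous frame
    context = torch.randn(B, in_features)                  # previous attention context
    query = torch.zeros(B, query_dim)                      # hidden state (the query)
    attention_rnn_cell_state = torch.zeros(B, query_dim)   # LSTM cell state

    # one decode step: the LSTM cell returns (hidden, cell); both get dropout
    query_input = torch.cat((memory, context), -1)
    query, attention_rnn_cell_state = attention_rnn(
        query_input, (query, attention_rnn_cell_state))
    query = F.dropout(query, p_attention_dropout, training=True)
    attention_rnn_cell_state = F.dropout(
        attention_rnn_cell_state, p_attention_dropout, training=True)

The new query then feeds the unchanged self.attention call to produce the next context vector, as in the last two lines of the hunk above.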