mirror of https://github.com/coqui-ai/TTS.git
renamed query_rnn back to attention_rnn
parent 5db302179a
commit a6118564d5
@@ -291,7 +291,9 @@ class Decoder(nn.Module):
             prenet_dropout,
             out_features=[256, 128])
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
-        self.query_rnn = nn.GRUCell(in_features + 128, self.query_dim)
+        # attention_rnn generates queries for the attention mechanism
+        self.attention_rnn = nn.GRUCell(in_features + 128, self.query_dim)
+
         self.attention = Attention(query_dim=self.query_dim,
                                    embedding_dim=in_features,
                                    attention_dim=128,
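For reference (not part of the diff), a minimal sketch of how the renamed attention_rnn produces the attention query in this GRU-based decoder; the sizes (in_features=256, prenet output 128, query_dim=256) are assumptions read off the surrounding lines:

import torch
import torch.nn as nn

B, in_features, prenet_out, query_dim = 4, 256, 128, 256  # assumed sizes

# same construction as in the diff above
attention_rnn = nn.GRUCell(in_features + prenet_out, query_dim)

processed_memory = torch.randn(B, prenet_out)   # prenet output for this step
context_vec = torch.randn(B, in_features)       # previous attention context
query = torch.zeros(B, query_dim)               # previous attention RNN state

# the GRU hidden state doubles as the query handed to the attention module
query = attention_rnn(torch.cat((processed_memory, context_vec), -1), query)
print(query.shape)  # torch.Size([4, 256])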
@@ -311,7 +313,7 @@ class Decoder(nn.Module):
         # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
         # learn init values instead of zero init.
-        self.query_rnn_init = nn.Embedding(1, 256)
+        self.attention_rnn_init = nn.Embedding(1, 256)
         self.memory_init = nn.Embedding(1, self.memory_size * memory_dim)
         self.decoder_rnn_inits = nn.Embedding(2, 256)
         self.stopnet = StopNet(256 + memory_dim * r)
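A small illustration (assumed sizes, not from the commit) of the "learn init values instead of zero init" trick used by attention_rnn_init and the other init embeddings: a one-row nn.Embedding is a trainable vector that every batch element looks up with index 0:

import torch
import torch.nn as nn

B, query_dim = 4, 256  # assumed batch size and state size

attention_rnn_init = nn.Embedding(1, query_dim)  # single learnable row

batch_zeros = torch.zeros(B).long()        # index 0 for every batch element
query0 = attention_rnn_init(batch_zeros)   # (B, query_dim), identical trainable rows
print(query0.shape)                        # torch.Size([4, 256])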
@@ -348,7 +350,7 @@ class Decoder(nn.Module):
         self.memory_input = self.memory_init(inputs.data.new_zeros(B).long())

         # decoder states
-        self.query = self.query_rnn_init(
+        self.query = self.attention_rnn_init(
             inputs.data.new_zeros(B).long())
         self.decoder_rnn_hiddens = [
             self.decoder_rnn_inits(inputs.data.new_tensor([idx] * B).long())
@@ -369,8 +371,8 @@ class Decoder(nn.Module):
         # Prenet
         processed_memory = self.prenet(self.memory_input)

         # Attention RNN
-        self.query = self.query_rnn(torch.cat((processed_memory, self.context_vec), -1), self.query)
+        self.query = self.attention_rnn(torch.cat((processed_memory, self.context_vec), -1), self.query)
         # Attention
         self.context_vec = self.attention(self.query, inputs, self.processed_inputs, mask)

         # Concat query and attention context vector

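To see where the renamed module sits in the per-step flow above, here is a rough sketch with stand-in shapes and a toy dot-product scorer in place of the repo's Attention class (sizes and the scoring function are assumptions for illustration only):

import torch
import torch.nn as nn
import torch.nn.functional as F

B, T, in_features, prenet_out, query_dim = 4, 50, 256, 128, 256  # assumed

attention_rnn = nn.GRUCell(in_features + prenet_out, query_dim)
score_proj = nn.Linear(query_dim, in_features)  # toy scorer, not the repo's Attention

inputs = torch.randn(B, T, in_features)        # encoder outputs
processed_memory = torch.randn(B, prenet_out)  # Prenet output
context_vec = torch.zeros(B, in_features)
query = torch.zeros(B, query_dim)

# Attention RNN: new query from prenet output + previous context
query = attention_rnn(torch.cat((processed_memory, context_vec), -1), query)
# Attention: align the query against the encoder steps, then mix a new context
scores = torch.bmm(inputs, score_proj(query).unsqueeze(-1)).squeeze(-1)  # (B, T)
alignment = F.softmax(scores, dim=-1)
context_vec = torch.bmm(alignment.unsqueeze(1), inputs).squeeze(1)       # (B, in_features)
print(context_vec.shape)  # torch.Size([4, 256])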
@@ -116,8 +116,8 @@ class Decoder(nn.Module):
             prenet_dropout,
             [self.prenet_dim, self.prenet_dim], bias=False)

-        self.query_rnn = nn.LSTMCell(self.prenet_dim + in_features,
-                                     self.query_dim)
+        self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
+                                         self.query_dim)

         self.attention = Attention(query_dim=self.query_dim,
                                    embedding_dim=in_features,
@@ -145,7 +145,7 @@ class Decoder(nn.Module):
             bias=True,
             init_gain='sigmoid'))

-        self.query_rnn_init = nn.Embedding(1, self.query_dim)
+        self.attention_rnn_init = nn.Embedding(1, self.query_dim)
         self.go_frame_init = nn.Embedding(1, self.mel_channels * r)
         self.decoder_rnn_inits = nn.Embedding(1, self.decoder_rnn_dim)
         self.memory_truncated = None
@@ -160,9 +160,9 @@ class Decoder(nn.Module):
         # T = inputs.size(1)

         if not keep_states:
-            self.query = self.query_rnn_init(
+            self.query = self.attention_rnn_init(
                 inputs.data.new_zeros(B).long())
-            self.query_rnn_cell_state = Variable(
+            self.attention_rnn_cell_state = Variable(
                 inputs.data.new(B, self.query_dim).zero_())

             self.decoder_hidden = self.decoder_rnn_inits(
@@ -194,12 +194,12 @@ class Decoder(nn.Module):

     def decode(self, memory):
         query_input = torch.cat((memory, self.context), -1)
-        self.query, self.query_rnn_cell_state = self.query_rnn(
-            query_input, (self.query, self.query_rnn_cell_state))
+        self.query, self.attention_rnn_cell_state = self.attention_rnn(
+            query_input, (self.query, self.attention_rnn_cell_state))
         self.query = F.dropout(
             self.query, self.p_attention_dropout, self.training)
-        self.query_rnn_cell_state = F.dropout(
-            self.query_rnn_cell_state, self.p_attention_dropout, self.training)
+        self.attention_rnn_cell_state = F.dropout(
+            self.attention_rnn_cell_state, self.p_attention_dropout, self.training)

         self.context = self.attention(self.query, self.inputs,
                                       self.processed_inputs, self.mask)
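For comparison with the GRU-based decoder in the first file, a sketch of the decode() step above using nn.LSTMCell, which carries a (hidden, cell) pair; both halves are dropped out only while training. The sizes and dropout rate are assumptions, and the attention call is left out:

import torch
import torch.nn as nn
import torch.nn.functional as F

B, prenet_dim, in_features, query_dim = 4, 256, 512, 1024  # assumed sizes
p_attention_dropout, training = 0.1, True                  # assumed values

attention_rnn = nn.LSTMCell(prenet_dim + in_features, query_dim)

memory = torch.randn(B, prenet_dim)           # prenet output for this frame
context = torch.randn(B, in_features)         # previous attention context
query = torch.zeros(B, query_dim)             # LSTM hidden state, used as the query
attention_rnn_cell_state = torch.zeros(B, query_dim)

query_input = torch.cat((memory, context), -1)
query, attention_rnn_cell_state = attention_rnn(
    query_input, (query, attention_rnn_cell_state))
query = F.dropout(query, p_attention_dropout, training)
attention_rnn_cell_state = F.dropout(
    attention_rnn_cell_state, p_attention_dropout, training)
print(query.shape)  # torch.Size([4, 1024])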