mirror of https://github.com/coqui-ai/TTS.git
renamed query_rnn back to attention_rnn

parent 5db302179a
commit a6118564d5

@@ -291,7 +291,9 @@ class Decoder(nn.Module):
                              prenet_dropout,
                              out_features=[256, 128])
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
-        self.query_rnn = nn.GRUCell(in_features + 128, self.query_dim)
+        # attention_rnn generates queries for the attention mechanism
+        self.attention_rnn = nn.GRUCell(in_features + 128, self.query_dim)
+
         self.attention = Attention(query_dim=self.query_dim,
                                    embedding_dim=in_features,
                                    attention_dim=128,
@@ -311,7 +313,7 @@ class Decoder(nn.Module):
         # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
         # learn init values instead of zero init.
-        self.query_rnn_init = nn.Embedding(1, 256)
+        self.attention_rnn_init = nn.Embedding(1, 256)
         self.memory_init = nn.Embedding(1, self.memory_size * memory_dim)
         self.decoder_rnn_inits = nn.Embedding(2, 256)
         self.stopnet = StopNet(256 + memory_dim * r)
@@ -348,7 +350,7 @@ class Decoder(nn.Module):
         self.memory_input = self.memory_init(inputs.data.new_zeros(B).long())

         # decoder states
-        self.query = self.query_rnn_init(
+        self.query = self.attention_rnn_init(
             inputs.data.new_zeros(B).long())
         self.decoder_rnn_hiddens = [
             self.decoder_rnn_inits(inputs.data.new_tensor([idx] * B).long())
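
The *_init embeddings renamed in these hunks implement learned initial decoder states: a one-row nn.Embedding holds a trainable vector, and indexing it with a zero index per batch item expands it to batch size. A minimal sketch of that trick, with assumed sizes (query_dim=256, batch of 4), not code from this commit:

import torch
import torch.nn as nn

B, query_dim = 4, 256
attention_rnn_init = nn.Embedding(1, query_dim)  # single trainable init vector

inputs = torch.randn(B, 37, 512)  # stands in for the encoder outputs tensor
# new_zeros(B).long() builds a [B] index of zeros on the same device as inputs,
# so the one embedding row is repeated to shape [B, query_dim].
query = attention_rnn_init(inputs.data.new_zeros(B).long())
print(query.shape)  # torch.Size([4, 256])
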
@@ -369,8 +371,8 @@ class Decoder(nn.Module):
         # Prenet
         processed_memory = self.prenet(self.memory_input)

-        # Attention RNN
-        self.query = self.query_rnn(torch.cat((processed_memory, self.context_vec), -1), self.query)
+        # Attention
+        self.query = self.attention_rnn(torch.cat((processed_memory, self.context_vec), -1), self.query)
         self.context_vec = self.attention(self.query, inputs, self.processed_inputs, mask)

         # Concat query and attention context vector

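For orientation, the step renamed above works as follows: the prenet output and the previous context vector are concatenated and fed to the GRUCell attention_rnn, whose new hidden state is the query handed to the attention module, which then produces the next context vector. Below is a self-contained sketch of one such step with assumed sizes (in_features=256, 128-dim prenet output, query_dim=256) and a plain dot-product attention as a stand-in for the repo's Attention class:

import torch
import torch.nn as nn
import torch.nn.functional as F

B, T_in, in_features, prenet_dim, query_dim = 2, 50, 256, 128, 256

attention_rnn = nn.GRUCell(in_features + prenet_dim, query_dim)  # generates attention queries
query_proj = nn.Linear(query_dim, in_features)                   # maps the query into encoder space

inputs = torch.randn(B, T_in, in_features)     # encoder outputs
processed_memory = torch.randn(B, prenet_dim)  # prenet(memory_input)
context_vec = torch.zeros(B, in_features)      # previous attention context
query = torch.zeros(B, query_dim)              # previous attention_rnn hidden state

# Attention RNN step: prenet output + previous context -> new query.
query = attention_rnn(torch.cat((processed_memory, context_vec), -1), query)

# Dot-product attention stand-in: score each encoder step against the projected query.
scores = torch.bmm(inputs, query_proj(query).unsqueeze(-1)).squeeze(-1)  # [B, T_in]
alignment = F.softmax(scores, dim=-1)
context_vec = torch.bmm(alignment.unsqueeze(1), inputs).squeeze(1)       # [B, in_features]
print(query.shape, context_vec.shape)
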
@@ -116,8 +116,8 @@ class Decoder(nn.Module):
                              prenet_dropout,
                              [self.prenet_dim, self.prenet_dim], bias=False)

-        self.query_rnn = nn.LSTMCell(self.prenet_dim + in_features,
+        self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
                                      self.query_dim)

         self.attention = Attention(query_dim=self.query_dim,
                                    embedding_dim=in_features,
@@ -145,7 +145,7 @@ class Decoder(nn.Module):
                 bias=True,
                 init_gain='sigmoid'))

-        self.query_rnn_init = nn.Embedding(1, self.query_dim)
+        self.attention_rnn_init = nn.Embedding(1, self.query_dim)
         self.go_frame_init = nn.Embedding(1, self.mel_channels * r)
         self.decoder_rnn_inits = nn.Embedding(1, self.decoder_rnn_dim)
         self.memory_truncated = None
@@ -160,9 +160,9 @@ class Decoder(nn.Module):
         # T = inputs.size(1)

         if not keep_states:
-            self.query = self.query_rnn_init(
+            self.query = self.attention_rnn_init(
                 inputs.data.new_zeros(B).long())
-            self.query_rnn_cell_state = Variable(
+            self.attention_rnn_cell_state = Variable(
                 inputs.data.new(B, self.query_dim).zero_())

             self.decoder_hidden = self.decoder_rnn_inits(
@@ -194,12 +194,12 @@ class Decoder(nn.Module):

     def decode(self, memory):
         query_input = torch.cat((memory, self.context), -1)
-        self.query, self.query_rnn_cell_state = self.query_rnn(
-            query_input, (self.query, self.query_rnn_cell_state))
+        self.query, self.attention_rnn_cell_state = self.attention_rnn(
+            query_input, (self.query, self.attention_rnn_cell_state))
         self.query = F.dropout(
             self.query, self.p_attention_dropout, self.training)
-        self.query_rnn_cell_state = F.dropout(
-            self.query_rnn_cell_state, self.p_attention_dropout, self.training)
+        self.attention_rnn_cell_state = F.dropout(
+            self.attention_rnn_cell_state, self.p_attention_dropout, self.training)

         self.context = self.attention(self.query, self.inputs,
                                       self.processed_inputs, self.mask)

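The second decoder uses an LSTMCell rather than a GRUCell, so the attention RNN carries both a hidden state (the query) and a cell state, and dropout is applied to each after the step. A minimal sketch of that decode() fragment with illustrative sizes only (prenet_dim=256, in_features=512, query_dim=1024) and the attention/context update omitted:

import torch
import torch.nn as nn
import torch.nn.functional as F

B, in_features, prenet_dim, query_dim = 2, 512, 256, 1024
p_attention_dropout = 0.1  # assumed value for the sketch

attention_rnn = nn.LSTMCell(prenet_dim + in_features, query_dim)

memory = torch.randn(B, prenet_dim)                   # prenet output for this step
context = torch.zeros(B, in_features)                 # previous attention context
query = torch.zeros(B, query_dim)                     # LSTM hidden state
attention_rnn_cell_state = torch.zeros(B, query_dim)  # LSTM cell state

# One attention RNN step: both states are advanced, then regularized with dropout.
query_input = torch.cat((memory, context), -1)
query, attention_rnn_cell_state = attention_rnn(
    query_input, (query, attention_rnn_cell_state))
query = F.dropout(query, p_attention_dropout, training=True)
attention_rnn_cell_state = F.dropout(
    attention_rnn_cell_state, p_attention_dropout, training=True)
print(query.shape, attention_rnn_cell_state.shape)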