mirror of https://github.com/coqui-ai/TTS.git
Add a constant attnetion model type to attention class
parent
819011e1a2
commit
14f9d06b31
|
@ -219,7 +219,7 @@ class Decoder(nn.Module):
|
|||
# memory -> |Prenet| -> processed_memory
|
||||
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
|
||||
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
|
||||
self.attention_rnn = AttentionRNN(256, in_features, 128)
|
||||
self.attention_rnn = AttentionRNN(256, in_features, 128, align_model='ls')
|
||||
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
|
||||
self.project_to_decoder_in = nn.Linear(256+in_features, 256)
|
||||
# decoder_RNN_input -> |RNN| -> RNN_state
|
||||
|
@ -257,14 +257,15 @@ class Decoder(nn.Module):
|
|||
" !! Dimension mismatch {} vs {} * {}".format(memory.size(-1),
|
||||
self.memory_dim, self.r)
|
||||
T_decoder = memory.size(1)
|
||||
# go frame - 0 frames tarting the sequence
|
||||
# go frame as zeros matrix
|
||||
initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
|
||||
# Init decoder states
|
||||
# decoder states
|
||||
attention_rnn_hidden = inputs.data.new(B, 256).zero_()
|
||||
decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_()
|
||||
for _ in range(len(self.decoder_rnns))]
|
||||
current_context_vec = inputs.data.new(B, 256).zero_()
|
||||
stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
|
||||
# attention states
|
||||
attention = inputs.data.new(B, T).zero_()
|
||||
attention_cum = inputs.data.new(B, T).zero_()
|
||||
# Time first (T_decoder, B, memory_dim)
|
||||
|
|
Loading…
Reference in New Issue