mirror of https://github.com/coqui-ai/TTS.git
Add attention-cum
parent
adbe603af1
commit
90d7e885e7
|
@ -252,7 +252,7 @@ class Decoder(nn.Module):
|
|||
stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
|
||||
# attention states
|
||||
attention = inputs.data.new(B, T).zero_()
|
||||
# attention_cum = inputs.data.new(B, T).zero_()
|
||||
attention_cum = inputs.data.new(B, T).zero_()
|
||||
# Time first (T_decoder, B, memory_dim)
|
||||
if memory is not None:
|
||||
memory = memory.transpose(0, 1)
|
||||
|
@ -270,13 +270,13 @@ class Decoder(nn.Module):
|
|||
# Prenet
|
||||
processed_memory = self.prenet(memory_input)
|
||||
# Attention RNN
|
||||
# attention_cat = torch.cat((attention.unsqueeze(1),
|
||||
# attention_cum.unsqueeze(1)),
|
||||
# dim=1)
|
||||
attention_cat = torch.cat((attention.unsqueeze(1),
|
||||
attention_cum.unsqueeze(1)),
|
||||
dim=1)
|
||||
attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
|
||||
processed_memory, current_context_vec, attention_rnn_hidden,
|
||||
inputs, attention.unsqueeze(1), input_lens)
|
||||
# attention_cum += attention
|
||||
inputs, attention_cat, input_lens)
|
||||
attention_cum += attention
|
||||
# Concat RNN output and attention context vector
|
||||
decoder_input = self.project_to_decoder_in(
|
||||
torch.cat((attention_rnn_hidden, current_context_vec), -1))
|
||||
|
|
Loading…
Reference in New Issue