mirror of https://github.com/coqui-ai/TTS.git
Do not avg cumulative attention weight
parent fe99baec5a
commit ad943120ae
@@ -286,7 +286,7 @@ class Decoder(nn.Module):
         processed_memory = self.prenet(memory_input)
         # Attention RNN
         attention_cat = torch.cat((attention.unsqueeze(1),
-                                   attention_cum.unsqueeze(1) / (t + 1)),
+                                   attention_cum.unsqueeze(1)),
                                   dim=1)
         attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
             processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention_cat)
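For context, a minimal sketch of the decoder loop around this change, assuming Tacotron-style location-sensitive attention; every name other than attention, attention_cum, and attention_cat from the diff is hypothetical. Before this commit the cumulative alignment was divided by the step count (t + 1) to form a mean; after it, the raw running sum is concatenated instead.

import torch

# Hypothetical shapes: batch of 2, 50 encoder timesteps.
B, T_in = 2, 50
attention = torch.softmax(torch.randn(B, T_in), dim=1)  # current alignment
attention_cum = torch.zeros(B, T_in)                     # running sum of past alignments

for t in range(3):  # a few decoder steps
    # Old behaviour: attention_cum.unsqueeze(1) / (t + 1)  -> mean alignment
    # New behaviour: attention_cum.unsqueeze(1)            -> raw cumulative sum
    attention_cat = torch.cat((attention.unsqueeze(1),
                               attention_cum.unsqueeze(1)),
                              dim=1)                      # (B, 2, T_in) location features
    attention_cum = attention_cum + attention             # accumulate for the next step

print(attention_cat.shape)  # torch.Size([2, 2, 50])

The likely rationale, reading only the commit message: dividing by (t + 1) shrinks the cumulative weights toward a per-step mean, whereas the raw sum grows monotonically over already-attended positions, giving the attention module a stronger cue to keep moving forward.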