Mirror of https://github.com/coqui-ai/TTS.git
Bug fix #2 — commit 38213dffe9 (parent commit 9ba13b2d2f)

@@ -183,7 +183,7 @@ class Attention(nn.Module):
     def apply_forward_attention(self, inputs, alignment, processed_query):
         # forward attention
         prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
-        alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha)) * alignment
+        alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) + 1e-8) * alignment
         self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
         # compute context
         context = torch.bmm(self.alpha.unsqueeze(1), inputs)