mirror of https://github.com/coqui-ai/TTS.git
duration predictor fix 2
parent
07269e639b
commit
aec0b78aff
|
@@ -196,7 +196,6 @@ class AlignTTS(nn.Module):
|
|||
o_en_ex = self._sum_speaker_embedding(o_en_ex, g)
|
||||
# decoder pass
|
||||
o_de = self.decoder(o_en_ex, y_mask, g=g)
|
||||
|
||||
return o_de, attn.transpose(1, 2)
|
||||
|
||||
# def _forward_mas(self, o_en, y, y_lengths, x_mask):
|
||||
|
@@ -225,8 +224,8 @@ class AlignTTS(nn.Module):
|
|||
g: [B, C]
|
||||
"""
|
||||
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
||||
o_dr_log = self.duration_predictor(x, x_mask)
|
||||
dr_mas, mu, log_sigma, logp_max_path = self._forward_mdn(o_en, y, y_lengths, x_mask)
|
||||
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
|
||||
# TODO: compute attn once
|
||||
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
|
||||
dr_mas_log = torch.log(1 + dr_mas).squeeze(1)
|
||||
|
@@ -242,8 +241,8 @@ class AlignTTS(nn.Module):
|
|||
# pad input to prevent dropping the last word
|
||||
x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0)
|
||||
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
||||
o_dr_log = self.duration_predictor(x, x_mask)
|
||||
# duration predictor pass
|
||||
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
|
||||
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
|
||||
y_lengths = o_dr.sum(1)
|
||||
o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
|
||||
|
|
Loading…
Reference in New Issue