duration predictor fix 2

pull/373/head
Eren Gölge 2021-03-16 17:07:15 +01:00
parent 07269e639b
commit aec0b78aff
1 changed files with 2 additions and 3 deletions

View File

@ -196,7 +196,6 @@ class AlignTTS(nn.Module):
o_en_ex = self._sum_speaker_embedding(o_en_ex, g)
# decoder pass
o_de = self.decoder(o_en_ex, y_mask, g=g)
return o_de, attn.transpose(1, 2)
# def _forward_mas(self, o_en, y, y_lengths, x_mask):
@ -225,8 +224,8 @@ class AlignTTS(nn.Module):
g: [B, C]
"""
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
o_dr_log = self.duration_predictor(x, x_mask)
dr_mas, mu, log_sigma, logp_max_path = self._forward_mdn(o_en, y, y_lengths, x_mask)
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
# TODO: compute attn once
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
dr_mas_log = torch.log(1 + dr_mas).squeeze(1)
@ -242,8 +241,8 @@ class AlignTTS(nn.Module):
# pad input to prevent dropping the last word
x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0)
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
o_dr_log = self.duration_predictor(x, x_mask)
# duration predictor pass
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
y_lengths = o_dr.sum(1)
o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)