diff --git a/models/tacotron2.py b/models/tacotron2.py index bbce4be9..bce21e9e 100644 --- a/models/tacotron2.py +++ b/models/tacotron2.py @@ -107,6 +107,7 @@ class Tacotron2(TacotronAbstract): decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs) # B x mel_dim x T_out postnet_outputs = self.postnet(decoder_outputs) + postnet_outputs = decoder_outputs + postnet_outputs # sequence masking if output_mask is not None: postnet_outputs = postnet_outputs * output_mask.unsqueeze(1).expand_as(postnet_outputs) @@ -130,13 +131,13 @@ class Tacotron2(TacotronAbstract): if self.num_speakers > 1: encoder_outputs = self._add_speaker_embedding(encoder_outputs, self.speaker_embeddings) - mel_outputs, alignments, stop_tokens = self.decoder.inference( + decoder_outputs, alignments, stop_tokens = self.decoder.inference( encoder_outputs) - mel_outputs_postnet = self.postnet(mel_outputs) - mel_outputs_postnet = mel_outputs + mel_outputs_postnet - mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs( - mel_outputs, mel_outputs_postnet, alignments) - return mel_outputs, mel_outputs_postnet, alignments, stop_tokens + postnet_outputs = self.postnet(decoder_outputs) + postnet_outputs = decoder_outputs + postnet_outputs + decoder_outputs, postnet_outputs, alignments = self.shape_outputs( + decoder_outputs, postnet_outputs, alignments) + return decoder_outputs, postnet_outputs, alignments, stop_tokens def inference_truncated(self, text, speaker_ids=None): """