bug fix and var renaming

pull/1/head
erogol 2020-06-08 03:22:17 +02:00
parent fedb2542be
commit 2404f96cba
1 changed files with 7 additions and 6 deletions

View File

@ -107,6 +107,7 @@ class Tacotron2(TacotronAbstract):
decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs)
# B x mel_dim x T_out
postnet_outputs = self.postnet(decoder_outputs)
postnet_outputs = decoder_outputs + postnet_outputs
# sequence masking
if output_mask is not None:
postnet_outputs = postnet_outputs * output_mask.unsqueeze(1).expand_as(postnet_outputs)
@ -130,13 +131,13 @@ class Tacotron2(TacotronAbstract):
if self.num_speakers > 1:
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
self.speaker_embeddings)
mel_outputs, alignments, stop_tokens = self.decoder.inference(
decoder_outputs, alignments, stop_tokens = self.decoder.inference(
encoder_outputs)
mel_outputs_postnet = self.postnet(mel_outputs)
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
mel_outputs, mel_outputs_postnet, alignments)
return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
postnet_outputs = self.postnet(decoder_outputs)
postnet_outputs = decoder_outputs + postnet_outputs
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(
decoder_outputs, postnet_outputs, alignments)
return decoder_outputs, postnet_outputs, alignments, stop_tokens
def inference_truncated(self, text, speaker_ids=None):
"""