bug fix for synthesis.py

pull/10/head
Eren Golge 2019-10-29 17:38:59 +01:00
parent 002991ca15
commit 60b6ec18fe
3 changed files with 7 additions and 20 deletions

View File

@@ -274,7 +274,6 @@ class Decoder(nn.Module):
self.attention.init_states(inputs)
outputs, stop_tokens, alignments, t = [], [], [], 0
-stop_flags = [True, False, False]
while True:
memory = self.prenet(memory)
if speaker_embeddings is not None:
@@ -285,11 +284,7 @@ class Decoder(nn.Module):
stop_tokens += [stop_token]
alignments += [alignment]
-stop_flags[0] = stop_flags[0] or stop_token > 0.5
-stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.8
-and t > inputs.shape[1])
-stop_flags[2] = t > inputs.shape[1] * 2
-if all(stop_flags):
+if stop_token > 0.7:
break
if len(outputs) == self.max_decoder_steps:
print(" | > Decoder stopped with 'max_decoder_steps")
@@ -325,11 +320,7 @@ class Decoder(nn.Module):
stop_tokens += [stop_token]
alignments += [alignment]
-stop_flags[0] = stop_flags[0] or stop_token > 0.5
-stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.8
-and t > inputs.shape[1])
-stop_flags[2] = t > inputs.shape[1] * 2
-if all(stop_flags):
+if stop_token > 0.7:
break
if len(outputs) == self.max_decoder_steps:
print(" | > Decoder stopped with 'max_decoder_steps")

View File

@@ -84,7 +84,7 @@ class Tacotron2(nn.Module):
encoder_outputs = self.encoder.inference(embedded_inputs)
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
speaker_ids)
-mel_outputs, stop_tokens, alignments = self.decoder.inference(
+mel_outputs, alignments, stop_tokens = self.decoder.inference(
encoder_outputs)
mel_outputs_postnet = self.postnet(mel_outputs)
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
@@ -100,7 +100,7 @@ class Tacotron2(nn.Module):
encoder_outputs = self.encoder.inference_truncated(embedded_inputs)
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
speaker_ids)
-mel_outputs, stop_tokens, alignments = self.decoder.inference_truncated(
+mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated(
encoder_outputs)
mel_outputs_postnet = self.postnet(mel_outputs)
mel_outputs_postnet = mel_outputs + mel_outputs_postnet

View File

@@ -30,16 +30,12 @@ def compute_style_mel(style_wav, ap, use_cuda):
def run_model(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
-if CONFIG.model == "TacotronGST" and style_mel is not None:
+if CONFIG.use_gst:
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
inputs, style_mel=style_mel, speaker_ids=speaker_id)
else:
-if truncated:
-decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
-inputs, speaker_ids=speaker_id)
-else:
-decoder_output, postnet_output, alignments, stop_tokens = model.inference(
-inputs, speaker_ids=speaker_id)
+decoder_output, postnet_output, alignments, stop_tokens = model.inference(
+inputs, speaker_ids=speaker_id)
return decoder_output, postnet_output, alignments, stop_tokens