mirror of https://github.com/coqui-ai/TTS.git
bug fix for synthesis.py
parent
002991ca15
commit
60b6ec18fe
|
@ -274,7 +274,6 @@ class Decoder(nn.Module):
|
|||
self.attention.init_states(inputs)
|
||||
|
||||
outputs, stop_tokens, alignments, t = [], [], [], 0
|
||||
stop_flags = [True, False, False]
|
||||
while True:
|
||||
memory = self.prenet(memory)
|
||||
if speaker_embeddings is not None:
|
||||
|
@ -285,11 +284,7 @@ class Decoder(nn.Module):
|
|||
stop_tokens += [stop_token]
|
||||
alignments += [alignment]
|
||||
|
||||
stop_flags[0] = stop_flags[0] or stop_token > 0.5
|
||||
stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.8
|
||||
and t > inputs.shape[1])
|
||||
stop_flags[2] = t > inputs.shape[1] * 2
|
||||
if all(stop_flags):
|
||||
if stop_token > 0.7:
|
||||
break
|
||||
if len(outputs) == self.max_decoder_steps:
|
||||
print(" | > Decoder stopped with 'max_decoder_steps")
|
||||
|
@ -325,11 +320,7 @@ class Decoder(nn.Module):
|
|||
stop_tokens += [stop_token]
|
||||
alignments += [alignment]
|
||||
|
||||
stop_flags[0] = stop_flags[0] or stop_token > 0.5
|
||||
stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.8
|
||||
and t > inputs.shape[1])
|
||||
stop_flags[2] = t > inputs.shape[1] * 2
|
||||
if all(stop_flags):
|
||||
if stop_token > 0.7:
|
||||
break
|
||||
if len(outputs) == self.max_decoder_steps:
|
||||
print(" | > Decoder stopped with 'max_decoder_steps")
|
||||
|
|
|
@ -84,7 +84,7 @@ class Tacotron2(nn.Module):
|
|||
encoder_outputs = self.encoder.inference(embedded_inputs)
|
||||
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
|
||||
speaker_ids)
|
||||
mel_outputs, stop_tokens, alignments = self.decoder.inference(
|
||||
mel_outputs, alignments, stop_tokens = self.decoder.inference(
|
||||
encoder_outputs)
|
||||
mel_outputs_postnet = self.postnet(mel_outputs)
|
||||
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||
|
@ -100,7 +100,7 @@ class Tacotron2(nn.Module):
|
|||
encoder_outputs = self.encoder.inference_truncated(embedded_inputs)
|
||||
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
|
||||
speaker_ids)
|
||||
mel_outputs, stop_tokens, alignments = self.decoder.inference_truncated(
|
||||
mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated(
|
||||
encoder_outputs)
|
||||
mel_outputs_postnet = self.postnet(mel_outputs)
|
||||
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||
|
|
|
@ -30,16 +30,12 @@ def compute_style_mel(style_wav, ap, use_cuda):
|
|||
|
||||
|
||||
def run_model(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
|
||||
if CONFIG.model == "TacotronGST" and style_mel is not None:
|
||||
if CONFIG.use_gst:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
|
||||
inputs, style_mel=style_mel, speaker_ids=speaker_id)
|
||||
else:
|
||||
if truncated:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
|
||||
inputs, speaker_ids=speaker_id)
|
||||
else:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
|
||||
inputs, speaker_ids=speaker_id)
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
|
||||
inputs, speaker_ids=speaker_id)
|
||||
return decoder_output, postnet_output, alignments, stop_tokens
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue