mirror of https://github.com/coqui-ai/TTS.git

bug fix for synthesis.py

parent 002991ca15
commit 60b6ec18fe
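This commit makes three related fixes: the Tacotron2 decoder's three-flag stopping criterion is replaced by a single stop-token threshold (and the now-unused stop_flags initialization is removed), the tuples returned by decoder.inference and decoder.inference_truncated are unpacked in the correct order in the Tacotron2 model, and run_model in synthesis.py selects the GST path via CONFIG.use_gst instead of checking CONFIG.model == "TacotronGST", dropping its truncated sub-branch.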
@@ -274,7 +274,6 @@ class Decoder(nn.Module):
         self.attention.init_states(inputs)
 
         outputs, stop_tokens, alignments, t = [], [], [], 0
-        stop_flags = [True, False, False]
         while True:
             memory = self.prenet(memory)
             if speaker_embeddings is not None:
@@ -285,11 +284,7 @@ class Decoder(nn.Module):
             stop_tokens += [stop_token]
             alignments += [alignment]
 
-            stop_flags[0] = stop_flags[0] or stop_token > 0.5
-            stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.8
-                                              and t > inputs.shape[1])
-            stop_flags[2] = t > inputs.shape[1] * 2
-            if all(stop_flags):
+            if stop_token > 0.7:
                 break
             if len(outputs) == self.max_decoder_steps:
                 print(" | > Decoder stopped with 'max_decoder_steps")
@@ -325,11 +320,7 @@ class Decoder(nn.Module):
             stop_tokens += [stop_token]
             alignments += [alignment]
 
-            stop_flags[0] = stop_flags[0] or stop_token > 0.5
-            stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.8
-                                              and t > inputs.shape[1])
-            stop_flags[2] = t > inputs.shape[1] * 2
-            if all(stop_flags):
+            if stop_token > 0.7:
                 break
             if len(outputs) == self.max_decoder_steps:
                 print(" | > Decoder stopped with 'max_decoder_steps")
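The two hunks above apply the same change to Decoder.inference and Decoder.inference_truncated: a three-flag stopping criterion becomes a single threshold on the stopnet output. A minimal sketch of the two criteria, with illustrative names only (alignment_tail_mass stands in for alignment[0, -2:].sum(), n_inputs for inputs.shape[1]; this is not the real API):

    def should_stop_old(stop_token, alignment_tail_mass, t, n_inputs, stop_flags):
        # Old criterion: three flags that must all be True at once.
        stop_flags[0] = stop_flags[0] or stop_token > 0.5
        stop_flags[1] = stop_flags[1] or (alignment_tail_mass > 0.8 and t > n_inputs)
        stop_flags[2] = t > n_inputs * 2
        return all(stop_flags)

    def should_stop_new(stop_token):
        # New criterion: trust the stopnet alone, with a higher threshold.
        return stop_token > 0.7

    # With the old flags (note flag 0 started as True in the removed code),
    # a confident stop token at an early step still does not stop decoding;
    # the new criterion does.
    flags = [True, False, False]
    print(should_stop_old(0.9, 0.1, 5, 50, flags))  # False
    print(should_stop_new(0.9))                     # True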
@@ -84,7 +84,7 @@ class Tacotron2(nn.Module):
         encoder_outputs = self.encoder.inference(embedded_inputs)
         encoder_outputs = self._add_speaker_embedding(encoder_outputs,
                                                       speaker_ids)
-        mel_outputs, stop_tokens, alignments = self.decoder.inference(
+        mel_outputs, alignments, stop_tokens = self.decoder.inference(
             encoder_outputs)
         mel_outputs_postnet = self.postnet(mel_outputs)
         mel_outputs_postnet = mel_outputs + mel_outputs_postnet
@@ -100,7 +100,7 @@ class Tacotron2(nn.Module):
         encoder_outputs = self.encoder.inference_truncated(embedded_inputs)
         encoder_outputs = self._add_speaker_embedding(encoder_outputs,
                                                       speaker_ids)
-        mel_outputs, stop_tokens, alignments = self.decoder.inference_truncated(
+        mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated(
             encoder_outputs)
         mel_outputs_postnet = self.postnet(mel_outputs)
         mel_outputs_postnet = mel_outputs + mel_outputs_postnet
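These two hunks fix tuple-unpacking order: the fixed lines imply that decoder.inference and decoder.inference_truncated return alignments before stop_tokens, so the old left-hand side silently bound each name to the wrong tensor. A small stand-alone illustration, using string stand-ins for tensors with hypothetical shapes:

    def fake_decoder_inference():
        # Stand-ins for the three returned tensors.
        mel_outputs = "mel [B, n_mel, T_dec]"
        alignments = "align [B, T_dec, T_enc]"
        stop_tokens = "stop [B, T_dec]"
        return mel_outputs, alignments, stop_tokens

    # Old (buggy) order: stop_tokens receives the alignments and vice versa.
    mel, stop_tokens, alignments = fake_decoder_inference()
    print(stop_tokens)  # align [B, T_dec, T_enc] -- wrong tensor, no error raised

    # Fixed order matches the return value.
    mel, alignments, stop_tokens = fake_decoder_inference()
    print(stop_tokens)  # stop [B, T_dec]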
@@ -30,16 +30,12 @@ def compute_style_mel(style_wav, ap, use_cuda):
 
 
 def run_model(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
-    if CONFIG.model == "TacotronGST" and style_mel is not None:
+    if CONFIG.use_gst:
         decoder_output, postnet_output, alignments, stop_tokens = model.inference(
             inputs, style_mel=style_mel, speaker_ids=speaker_id)
     else:
-        if truncated:
-            decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
-                inputs, speaker_ids=speaker_id)
-        else:
-            decoder_output, postnet_output, alignments, stop_tokens = model.inference(
-                inputs, speaker_ids=speaker_id)
+        decoder_output, postnet_output, alignments, stop_tokens = model.inference(
+            inputs, speaker_ids=speaker_id)
     return decoder_output, postnet_output, alignments, stop_tokens
 
 
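After this hunk, run_model dispatches only on CONFIG.use_gst; the truncated argument stays in the signature but no longer selects model.inference_truncated inside this function. A runnable sketch of the new control flow, using dummy stand-ins for the model and config (not the real TTS classes):

    from types import SimpleNamespace

    def run_model(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
        # Body as it reads after the hunk above.
        if CONFIG.use_gst:
            decoder_output, postnet_output, alignments, stop_tokens = model.inference(
                inputs, style_mel=style_mel, speaker_ids=speaker_id)
        else:
            decoder_output, postnet_output, alignments, stop_tokens = model.inference(
                inputs, speaker_ids=speaker_id)
        return decoder_output, postnet_output, alignments, stop_tokens

    class DummyModel:
        def inference(self, inputs, style_mel=None, speaker_ids=None):
            return "decoder_out", "postnet_out", "alignments", "stop_tokens"

    # With use_gst=False the plain inference path is taken regardless of
    # `truncated`; with use_gst=True the style mel is forwarded instead.
    CONFIG = SimpleNamespace(use_gst=False)
    print(run_model(DummyModel(), "chars", CONFIG, truncated=True))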