diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index 48607aea..ecc44a25 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -274,7 +274,6 @@ class Decoder(nn.Module):
         self.attention.init_states(inputs)
 
         outputs, stop_tokens, alignments, t = [], [], [], 0
-        stop_flags = [True, False, False]
         while True:
             memory = self.prenet(memory)
             if speaker_embeddings is not None:
@@ -285,11 +284,7 @@ class Decoder(nn.Module):
             stop_tokens += [stop_token]
             alignments += [alignment]
 
-            stop_flags[0] = stop_flags[0] or stop_token > 0.5
-            stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.8
-                                              and t > inputs.shape[1])
-            stop_flags[2] = t > inputs.shape[1] * 2
-            if all(stop_flags):
+            if stop_token > 0.7:
                 break
             if len(outputs) == self.max_decoder_steps:
                 print(" | > Decoder stopped with 'max_decoder_steps")
@@ -325,11 +320,7 @@ class Decoder(nn.Module):
             stop_tokens += [stop_token]
             alignments += [alignment]
 
-            stop_flags[0] = stop_flags[0] or stop_token > 0.5
-            stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.8
-                                              and t > inputs.shape[1])
-            stop_flags[2] = t > inputs.shape[1] * 2
-            if all(stop_flags):
+            if stop_token > 0.7:
                 break
             if len(outputs) == self.max_decoder_steps:
                 print(" | > Decoder stopped with 'max_decoder_steps")
diff --git a/models/tacotron2.py b/models/tacotron2.py
index c885b8ed..70dd31a5 100644
--- a/models/tacotron2.py
+++ b/models/tacotron2.py
@@ -84,7 +84,7 @@ class Tacotron2(nn.Module):
         encoder_outputs = self.encoder.inference(embedded_inputs)
         encoder_outputs = self._add_speaker_embedding(encoder_outputs,
                                                       speaker_ids)
-        mel_outputs, stop_tokens, alignments = self.decoder.inference(
+        mel_outputs, alignments, stop_tokens = self.decoder.inference(
             encoder_outputs)
         mel_outputs_postnet = self.postnet(mel_outputs)
         mel_outputs_postnet = mel_outputs + mel_outputs_postnet
@@ -100,7 +100,7 @@ class Tacotron2(nn.Module):
         encoder_outputs = self.encoder.inference_truncated(embedded_inputs)
         encoder_outputs = self._add_speaker_embedding(encoder_outputs,
                                                       speaker_ids)
-        mel_outputs, stop_tokens, alignments = self.decoder.inference_truncated(
+        mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated(
             encoder_outputs)
         mel_outputs_postnet = self.postnet(mel_outputs)
         mel_outputs_postnet = mel_outputs + mel_outputs_postnet
diff --git a/utils/synthesis.py b/utils/synthesis.py
index f657eb4d..80d5272f 100644
--- a/utils/synthesis.py
+++ b/utils/synthesis.py
@@ -30,16 +30,12 @@ def compute_style_mel(style_wav, ap, use_cuda):
 
 
 
 
 def run_model(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
-    if CONFIG.model == "TacotronGST" and style_mel is not None:
+    if CONFIG.use_gst:
         decoder_output, postnet_output, alignments, stop_tokens = model.inference(
             inputs, style_mel=style_mel, speaker_ids=speaker_id)
     else:
-        if truncated:
-            decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
-                inputs, speaker_ids=speaker_id)
-            decoder_output, postnet_output, alignments, stop_tokens = model.inference(
-                inputs, speaker_ids=speaker_id)
-        else:
+        decoder_output, postnet_output, alignments, stop_tokens = model.inference(
+            inputs, speaker_ids=speaker_id)
     return decoder_output, postnet_output, alignments, stop_tokens