From fb9289d365ab331aabd50013a8701421ed0fa416 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Wed, 26 May 2021 16:03:56 +0200
Subject: [PATCH] update `synthesis.py` for being more generic

---
 TTS/tts/utils/synthesis.py | 58 ++++++++++++++++----------------------
 1 file changed, 25 insertions(+), 33 deletions(-)

diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py
index 4c3331c8..67432320 100644
--- a/TTS/tts/utils/synthesis.py
+++ b/TTS/tts/utils/synthesis.py
@@ -152,16 +152,6 @@ def run_model_tflite(model, inputs, CONFIG, speaker_id=None, style_mel=None):
     return decoder_output, postnet_output, None, None
 
 
-def parse_outputs_torch(postnet_output, decoder_output, alignments,
-                        stop_tokens):
-    postnet_output = postnet_output[0].data.cpu().numpy()
-    decoder_output = None if decoder_output is None else decoder_output[
-        0].data.cpu().numpy()
-    alignment = alignments[0].cpu().data.numpy()
-    stop_tokens = None if stop_tokens is None else stop_tokens[0].cpu().numpy()
-    return postnet_output, decoder_output, alignment, stop_tokens
-
-
 def parse_outputs_tf(postnet_output, decoder_output, alignments, stop_tokens):
     postnet_output = postnet_output[0].numpy()
     decoder_output = decoder_output[0].numpy()
@@ -200,8 +190,8 @@ def speaker_id_to_torch(speaker_id, cuda=False):
 def embedding_to_torch(x_vector, cuda=False):
     if x_vector is not None:
         x_vector = np.asarray(x_vector)
-        x_vector = torch.from_numpy(x_vector).unsqueeze(
-            0).type(torch.FloatTensor)
+        x_vector = torch.from_numpy(x_vector).unsqueeze(0).type(
+            torch.FloatTensor)
     if cuda:
         return x_vector.cuda()
     return x_vector
@@ -263,57 +253,59 @@ def synthesis(
     else:
         style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
     # preprocess the given text
-    inputs = text_to_seq(text, CONFIG)
+    text_inputs = text_to_seq(text, CONFIG)
     # pass tensors to backend
     if backend == "torch":
         if speaker_id is not None:
             speaker_id = speaker_id_to_torch(speaker_id, cuda=use_cuda)
         if x_vector is not None:
-            x_vector = embedding_to_torch(x_vector,
-                                          cuda=use_cuda)
+            x_vector = embedding_to_torch(x_vector, cuda=use_cuda)
         if not isinstance(style_mel, dict):
             style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
-        inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda)
-        inputs = inputs.unsqueeze(0)
+        text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
+        text_inputs = text_inputs.unsqueeze(0)
     elif backend == "tf":
         # TODO: handle speaker id for tf model
         style_mel = numpy_to_tf(style_mel, tf.float32)
-        inputs = numpy_to_tf(inputs, tf.int32)
-        inputs = tf.expand_dims(inputs, 0)
+        text_inputs = numpy_to_tf(text_inputs, tf.int32)
+        text_inputs = tf.expand_dims(text_inputs, 0)
     elif backend == "tflite":
         style_mel = numpy_to_tf(style_mel, tf.float32)
-        inputs = numpy_to_tf(inputs, tf.int32)
-        inputs = tf.expand_dims(inputs, 0)
+        text_inputs = numpy_to_tf(text_inputs, tf.int32)
+        text_inputs = tf.expand_dims(text_inputs, 0)
     # synthesize voice
     if backend == "torch":
         outputs = run_model_torch(model,
-                                  inputs,
+                                  text_inputs,
                                   speaker_id,
                                   style_mel,
                                   x_vector=x_vector)
-        postnet_output, decoder_output, alignments, stop_tokens = \
-            outputs['postnet_outputs'], outputs['decoder_outputs'],\
-            outputs['alignments'], outputs['stop_tokens']
-        postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_torch(
-            postnet_output, decoder_output, alignments, stop_tokens)
+        model_outputs = outputs['model_outputs']
+        model_outputs = model_outputs[0].data.cpu().numpy()
     elif backend == "tf":
         decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
-            model, inputs, CONFIG, speaker_id, style_mel)
-        postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_tf(
+            model, text_inputs, CONFIG, speaker_id, style_mel)
+        model_outputs, decoder_output, alignment, stop_tokens = parse_outputs_tf(
             postnet_output, decoder_output, alignments, stop_tokens)
     elif backend == "tflite":
         decoder_output, postnet_output, alignment, stop_tokens = run_model_tflite(
-            model, inputs, CONFIG, speaker_id, style_mel)
-        postnet_output, decoder_output = parse_outputs_tflite(
+            model, text_inputs, CONFIG, speaker_id, style_mel)
+        model_outputs, decoder_output = parse_outputs_tflite(
             postnet_output, decoder_output)
     # convert outputs to numpy
     # plot results
     wav = None
     if use_griffin_lim:
-        wav = inv_spectrogram(postnet_output, ap, CONFIG)
+        wav = inv_spectrogram(model_outputs, ap, CONFIG)
         # trim silence
         if do_trim_silence:
             wav = trim_silence(wav, ap)
-    return wav, alignment, decoder_output, postnet_output, stop_tokens, inputs
+    return_dict = {
+        'wav': wav,
+        'alignments': outputs['alignments'],
+        'model_outputs': model_outputs,
+        'text_inputs': text_inputs
+    }
+    return return_dict
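
Usage note (a reviewer sketch, not part of the patch): `synthesis()` now
returns a dict instead of the old 6-tuple, so callers must unpack by key.
Below is a minimal, hypothetical caller for the torch backend (the only
branch where `outputs['alignments']` is populated). The positional order
`(model, text, CONFIG, use_cuda, ap, ...)` and the preloaded `model`,
`CONFIG` and `ap` objects are assumptions, not something this diff shows:

    from TTS.tts.utils.synthesis import synthesis

    use_cuda = False  # assumption: CPU inference

    # model, CONFIG and ap (the AudioProcessor) are assumed to be loaded
    # elsewhere; only the return-value handling is what this patch changes.
    return_dict = synthesis(model, "Hello world.", CONFIG, use_cuda, ap,
                            use_griffin_lim=True, do_trim_silence=True)

    # before: wav, alignment, decoder_output, postnet_output, stop_tokens, inputs = synthesis(...)
    wav = return_dict["wav"]                  # numpy waveform; None unless Griffin-Lim ran
    alignments = return_dict["alignments"]    # attention map, still a torch tensor
    mel_spec = return_dict["model_outputs"]   # postnet spectrogram as a numpy array
    text_inputs = return_dict["text_inputs"]  # tokenized input ids (batch of 1)

    ap.save_wav(wav, "out.wav")  # AudioProcessor's save helper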