Update `synthesis.py` to be more generic

pull/602/head
Eren Gölge 2021-05-26 16:03:56 +02:00
parent f121b0ff5d
commit fb9289d365
1 changed files with 25 additions and 33 deletions

View File

@ -152,16 +152,6 @@ def run_model_tflite(model, inputs, CONFIG, speaker_id=None, style_mel=None):
return decoder_output, postnet_output, None, None
def parse_outputs_torch(postnet_output, decoder_output, alignments,
                        stop_tokens):
    """Convert the first item of each torch model output to a numpy array.

    Each non-None argument is indexed with ``[0]`` (presumably the batch
    dimension -- confirm against the caller), detached from the autograd
    graph, moved to the CPU and converted with ``.numpy()``.

    Args:
        postnet_output: torch tensor; required.
        decoder_output: torch tensor, or None.
        alignments: torch tensor; required.
        stop_tokens: torch tensor, or None.

    Returns:
        Tuple ``(postnet_output, decoder_output, alignment, stop_tokens)``
        of numpy arrays; the ``decoder_output`` and ``stop_tokens`` slots
        are None when the matching argument was None.
    """

    def _to_numpy(tensor):
        # detach() first: calling .numpy() on a tensor that requires grad
        # raises a RuntimeError, and .data is the legacy way to drop grad.
        return tensor.detach().cpu().numpy()

    postnet_output = _to_numpy(postnet_output[0])
    decoder_output = (None if decoder_output is None else
                      _to_numpy(decoder_output[0]))
    alignment = _to_numpy(alignments[0])
    stop_tokens = None if stop_tokens is None else _to_numpy(stop_tokens[0])
    return postnet_output, decoder_output, alignment, stop_tokens
def parse_outputs_tf(postnet_output, decoder_output, alignments, stop_tokens):
postnet_output = postnet_output[0].numpy()
decoder_output = decoder_output[0].numpy()
@ -200,8 +190,8 @@ def speaker_id_to_torch(speaker_id, cuda=False):
def embedding_to_torch(x_vector, cuda=False):
    """Convert an embedding vector to a float torch tensor with a leading axis.

    Args:
        x_vector: array-like accepted by ``np.asarray``, or None.
        cuda: when True, the resulting tensor is moved with ``.cuda()``.

    Returns:
        None when ``x_vector`` is None; otherwise a ``torch.FloatTensor``
        with one extra dimension inserted at position 0.
    """
    if x_vector is not None:
        x_vector = np.asarray(x_vector)
        # Convert exactly once: torch.from_numpy only accepts numpy arrays,
        # so repeating this statement on the resulting tensor would raise.
        x_vector = torch.from_numpy(x_vector).unsqueeze(0).type(
            torch.FloatTensor)
        if cuda:
            return x_vector.cuda()
    return x_vector
@ -263,57 +253,59 @@ def synthesis(
else:
style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
# preprocess the given text
inputs = text_to_seq(text, CONFIG)
text_inputs = text_to_seq(text, CONFIG)
# pass tensors to backend
if backend == "torch":
if speaker_id is not None:
speaker_id = speaker_id_to_torch(speaker_id, cuda=use_cuda)
if x_vector is not None:
x_vector = embedding_to_torch(x_vector,
cuda=use_cuda)
x_vector = embedding_to_torch(x_vector, cuda=use_cuda)
if not isinstance(style_mel, dict):
style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda)
inputs = inputs.unsqueeze(0)
text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
text_inputs = text_inputs.unsqueeze(0)
elif backend == "tf":
# TODO: handle speaker id for tf model
style_mel = numpy_to_tf(style_mel, tf.float32)
inputs = numpy_to_tf(inputs, tf.int32)
inputs = tf.expand_dims(inputs, 0)
text_inputs = numpy_to_tf(text_inputs, tf.int32)
text_inputs = tf.expand_dims(text_inputs, 0)
elif backend == "tflite":
style_mel = numpy_to_tf(style_mel, tf.float32)
inputs = numpy_to_tf(inputs, tf.int32)
inputs = tf.expand_dims(inputs, 0)
text_inputs = numpy_to_tf(text_inputs, tf.int32)
text_inputs = tf.expand_dims(text_inputs, 0)
# synthesize voice
if backend == "torch":
outputs = run_model_torch(model,
inputs,
text_inputs,
speaker_id,
style_mel,
x_vector=x_vector)
postnet_output, decoder_output, alignments, stop_tokens = \
outputs['postnet_outputs'], outputs['decoder_outputs'],\
outputs['alignments'], outputs['stop_tokens']
postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_torch(
postnet_output, decoder_output, alignments, stop_tokens)
model_outputs = outputs['model_outputs']
model_outputs = model_outputs[0].data.cpu().numpy()
elif backend == "tf":
decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
model, inputs, CONFIG, speaker_id, style_mel)
postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_tf(
model, text_inputs, CONFIG, speaker_id, style_mel)
model_outputs, decoder_output, alignment, stop_tokens = parse_outputs_tf(
postnet_output, decoder_output, alignments, stop_tokens)
elif backend == "tflite":
decoder_output, postnet_output, alignment, stop_tokens = run_model_tflite(
model, inputs, CONFIG, speaker_id, style_mel)
postnet_output, decoder_output = parse_outputs_tflite(
model, text_inputs, CONFIG, speaker_id, style_mel)
model_outputs, decoder_output = parse_outputs_tflite(
postnet_output, decoder_output)
# convert outputs to numpy
# plot results
wav = None
if use_griffin_lim:
wav = inv_spectrogram(postnet_output, ap, CONFIG)
wav = inv_spectrogram(model_outputs, ap, CONFIG)
# trim silence
if do_trim_silence:
wav = trim_silence(wav, ap)
return wav, alignment, decoder_output, postnet_output, stop_tokens, inputs
return_dict = {
'wav': wav,
'alignments': outputs['alignments'],
'model_outputs': model_outputs,
'text_inputs': text_inputs
}
return return_dict