mirror of https://github.com/coqui-ai/TTS.git
update `synthesis.py` to be more generic
parent
f121b0ff5d
commit
fb9289d365
synthesis.py

@@ -152,16 +152,6 @@ def run_model_tflite(model, inputs, CONFIG, speaker_id=None, style_mel=None):
     return decoder_output, postnet_output, None, None
 
 
-def parse_outputs_torch(postnet_output, decoder_output, alignments,
-                        stop_tokens):
-    postnet_output = postnet_output[0].data.cpu().numpy()
-    decoder_output = None if decoder_output is None else decoder_output[
-        0].data.cpu().numpy()
-    alignment = alignments[0].cpu().data.numpy()
-    stop_tokens = None if stop_tokens is None else stop_tokens[0].cpu().numpy()
-    return postnet_output, decoder_output, alignment, stop_tokens
-
-
 def parse_outputs_tf(postnet_output, decoder_output, alignments, stop_tokens):
     postnet_output = postnet_output[0].numpy()
     decoder_output = decoder_output[0].numpy()
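The hunk above deletes `parse_outputs_torch`: the torch branch of `synthesis` now reads named outputs from the model's output dict and converts only `model_outputs` to numpy inline (see the third hunk below). For reference, a minimal sketch of the per-tensor conversion the deleted helper performed; the helper name `to_numpy` is hypothetical, not part of the commit:

import torch

def to_numpy(batched):
    # Take the first (and only) batch item, detach it from the graph,
    # and move it to CPU as a numpy array, which is what
    # parse_outputs_torch did for each of its four arguments.
    if batched is None:
        return None
    return batched[0].data.cpu().numpy()
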
@@ -200,8 +190,8 @@ def speaker_id_to_torch(speaker_id, cuda=False):
 def embedding_to_torch(x_vector, cuda=False):
     if x_vector is not None:
         x_vector = np.asarray(x_vector)
-        x_vector = torch.from_numpy(x_vector).unsqueeze(
-            0).type(torch.FloatTensor)
+        x_vector = torch.from_numpy(x_vector).unsqueeze(0).type(
+            torch.FloatTensor)
         if cuda:
             return x_vector.cuda()
     return x_vector
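Both sides of this hunk are behaviorally identical; only the line wrapping of the `torch.from_numpy(...)` chain changes. For context, a standalone sketch of what `embedding_to_torch` does to a speaker x-vector; the 256-dim size is an assumption, the real size depends on the speaker encoder config:

import numpy as np
import torch

x_vector = np.random.rand(256).astype(np.float32)  # hypothetical x-vector

# Add a batch dimension and force float32, as embedding_to_torch does.
x_vector = torch.from_numpy(np.asarray(x_vector)).unsqueeze(0).type(
    torch.FloatTensor)
print(x_vector.shape)  # torch.Size([1, 256])
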
@@ -263,57 +253,59 @@ def synthesis(
     else:
         style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
     # preprocess the given text
-    inputs = text_to_seq(text, CONFIG)
+    text_inputs = text_to_seq(text, CONFIG)
     # pass tensors to backend
     if backend == "torch":
         if speaker_id is not None:
             speaker_id = speaker_id_to_torch(speaker_id, cuda=use_cuda)
 
         if x_vector is not None:
-            x_vector = embedding_to_torch(x_vector,
-                                          cuda=use_cuda)
+            x_vector = embedding_to_torch(x_vector, cuda=use_cuda)
 
         if not isinstance(style_mel, dict):
             style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
-        inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda)
-        inputs = inputs.unsqueeze(0)
+        text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
+        text_inputs = text_inputs.unsqueeze(0)
     elif backend == "tf":
         # TODO: handle speaker id for tf model
         style_mel = numpy_to_tf(style_mel, tf.float32)
-        inputs = numpy_to_tf(inputs, tf.int32)
-        inputs = tf.expand_dims(inputs, 0)
+        text_inputs = numpy_to_tf(text_inputs, tf.int32)
+        text_inputs = tf.expand_dims(text_inputs, 0)
     elif backend == "tflite":
         style_mel = numpy_to_tf(style_mel, tf.float32)
-        inputs = numpy_to_tf(inputs, tf.int32)
-        inputs = tf.expand_dims(inputs, 0)
+        text_inputs = numpy_to_tf(text_inputs, tf.int32)
+        text_inputs = tf.expand_dims(text_inputs, 0)
     # synthesize voice
     if backend == "torch":
         outputs = run_model_torch(model,
-                                  inputs,
+                                  text_inputs,
                                   speaker_id,
                                   style_mel,
                                   x_vector=x_vector)
-        postnet_output, decoder_output, alignments, stop_tokens = \
-            outputs['postnet_outputs'], outputs['decoder_outputs'],\
-            outputs['alignments'], outputs['stop_tokens']
-        postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_torch(
-            postnet_output, decoder_output, alignments, stop_tokens)
+        model_outputs = outputs['model_outputs']
+        model_outputs = model_outputs[0].data.cpu().numpy()
     elif backend == "tf":
         decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
-            model, inputs, CONFIG, speaker_id, style_mel)
-        postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_tf(
+            model, text_inputs, CONFIG, speaker_id, style_mel)
+        model_outputs, decoder_output, alignment, stop_tokens = parse_outputs_tf(
             postnet_output, decoder_output, alignments, stop_tokens)
     elif backend == "tflite":
         decoder_output, postnet_output, alignment, stop_tokens = run_model_tflite(
-            model, inputs, CONFIG, speaker_id, style_mel)
-        postnet_output, decoder_output = parse_outputs_tflite(
+            model, text_inputs, CONFIG, speaker_id, style_mel)
+        model_outputs, decoder_output = parse_outputs_tflite(
             postnet_output, decoder_output)
     # convert outputs to numpy
     # plot results
     wav = None
     if use_griffin_lim:
-        wav = inv_spectrogram(postnet_output, ap, CONFIG)
+        wav = inv_spectrogram(model_outputs, ap, CONFIG)
         # trim silence
         if do_trim_silence:
             wav = trim_silence(wav, ap)
-    return wav, alignment, decoder_output, postnet_output, stop_tokens, inputs
+    return_dict = {
+        'wav': wav,
+        'alignments': outputs['alignments'],
+        'model_outputs': model_outputs,
+        'text_inputs': text_inputs
+    }
+    return return_dict
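The user-visible change in this hunk is the return contract: `synthesis` now returns a dict keyed by 'wav', 'alignments', 'model_outputs', and 'text_inputs' instead of a six-element tuple, so callers no longer depend on positional order or on backend-specific outputs like `stop_tokens`. A hedged sketch of a caller updated for the new contract; the argument order of `synthesis` and the model/CONFIG/ap setup are assumptions inferred from the surrounding code, not shown in this diff:

# Before: brittle positional unpacking.
# wav, alignment, decoder_output, postnet_output, stop_tokens, inputs = \
#     synthesis(model, "Hello world.", CONFIG, use_cuda, ap)

# After: key-based access; unused outputs are simply not looked up.
outputs = synthesis(model, "Hello world.", CONFIG, use_cuda, ap,
                    use_griffin_lim=True, do_trim_silence=True)
wav = outputs['wav']                      # numpy waveform (None without Griffin-Lim)
alignments = outputs['alignments']        # attention alignments from the model
model_outputs = outputs['model_outputs']  # output spectrogram frames as numpy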