mirror of https://github.com/coqui-ai/TTS.git

update `synthesis.py` for the trainer

parent 130781dab6
commit f4f83b6379
TTS/tts/utils/speakers.py

@@ -34,10 +34,6 @@ def save_speaker_mapping(out_path, speaker_mapping):
         json.dump(speaker_mapping, f, indent=4)


-def get_speakers(items):


 def parse_speakers(c, args, meta_data_train, OUT_PATH):
     """Returns number of speakers, speaker embedding shape and speaker mapping"""
     if c.use_speaker_embedding:
@@ -135,7 +131,7 @@ class SpeakerManager:
     ):
         self.data_items = []
-        self.x_vectors = []
+        self.x_vectors = {}
         self.speaker_ids = []
         self.clip_ids = []
         self.speaker_encoder = None
@@ -171,7 +167,7 @@ class SpeakerManager:
     def x_vector_dim(self):
         return len(self.x_vectors[list(self.x_vectors.keys())[0]]["embedding"])

-    def parser_speakers_from_items(self, items: list):
+    def parse_speakers_from_items(self, items: list):
         speakers = sorted({item[2] for item in items})
         self.speaker_ids = {name: i for i, name in enumerate(speakers)}
         num_speakers = len(self.speaker_ids)
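Editor's note: the two `SpeakerManager` hunks above switch `x_vectors` from a list to a dict and fix the `parser_speakers_from_items` typo. A minimal sketch of the resulting bookkeeping; the dataset items, clip key, and 256-dim embedding below are invented for illustration, assuming items are `[text, wav_path, speaker_name]` triples as elsewhere in the repo (`item[2]` is the speaker name, matching the set comprehension in the diff):

```python
# Invented example items; real ones come from the TTS dataset loaders.
items = [
    ["hello world", "clips/a.wav", "speaker_b"],
    ["good morning", "clips/b.wav", "speaker_a"],
    ["good night", "clips/c.wav", "speaker_a"],
]

# What parse_speakers_from_items computes:
speakers = sorted({item[2] for item in items})          # ['speaker_a', 'speaker_b']
speaker_ids = {name: i for i, name in enumerate(speakers)}
print(speaker_ids)                                      # {'speaker_a': 0, 'speaker_b': 1}

# x_vectors is now a dict keyed by clip id rather than a flat list, so
# per-clip lookups work; key and field names mirror the x_vector_dim
# property above, the values here are placeholders.
x_vectors = {"a.wav": {"name": "speaker_b", "embedding": [0.1] * 256}}
x_vector_dim = len(x_vectors[list(x_vectors.keys())[0]]["embedding"])  # 256
```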
TTS/tts/utils/synthesis.py

@@ -13,7 +13,7 @@ if "tensorflow" in installed or "tensorflow-gpu" in installed:
     import tensorflow as tf


-def text_to_seqvec(text, CONFIG):
+def text_to_seq(text, CONFIG):
     text_cleaner = [CONFIG.text_cleaner]
     # text ot phonemes to sequence vector
     if CONFIG.use_phonemes:
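Editor's note: the `text_to_seqvec` → `text_to_seq` rename keeps the contract: clean the text, map it (or its phonemes) to integer IDs, and return an int32 numpy array. A rough self-contained sketch of that contract, with a toy symbol table standing in for the real cleaner and `text_to_sequence`/`phoneme_to_sequence` helpers:

```python
import numpy as np

# Toy symbol table; the real one lives in the TTS text utilities.
SYMBOLS = {c: i for i, c in enumerate(" abcdefghijklmnopqrstuvwxyz")}

def toy_text_to_seq(text: str) -> np.ndarray:
    """Mimics text_to_seq(): text -> cleaned symbols -> int32 id vector."""
    cleaned = text.lower()  # stand-in for the configured cleaner pass
    ids = [SYMBOLS[c] for c in cleaned if c in SYMBOLS]
    return np.asarray(ids, dtype=np.int32)

print(toy_text_to_seq("Hello"))  # [ 8  5 12 12 15]
```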
@@ -59,81 +59,82 @@ def numpy_to_tf(np_array, dtype):


 def compute_style_mel(style_wav, ap, cuda=False):
-    style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0)
+    style_mel = torch.FloatTensor(
+        ap.melspectrogram(ap.load_wav(style_wav,
+                                      sr=ap.sample_rate))).unsqueeze(0)
     if cuda:
         return style_mel.cuda()
     return style_mel


-def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None):
-    if "tacotron" in CONFIG.model.lower():
-        if CONFIG.gst:
-            decoder_output, postnet_output, alignments, stop_tokens = model.inference(
-                inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
-            )
-        else:
-            if truncated:
-                decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
-                    inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
-                )
-            else:
-                decoder_output, postnet_output, alignments, stop_tokens = model.inference(
-                    inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
-                )
-    elif "glow" in CONFIG.model.lower():
-        inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
-        if hasattr(model, "module"):
-            # distributed model
-            postnet_output, _, _, _, alignments, _, _ = model.module.inference(
-                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
-            )
-        else:
-            postnet_output, _, _, _, alignments, _, _ = model.inference(
-                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
-            )
-        postnet_output = postnet_output.permute(0, 2, 1)
-        # these only belong to tacotron models.
-        decoder_output = None
-        stop_tokens = None
-    elif CONFIG.model.lower() in ["speedy_speech", "align_tts"]:
-        inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
-        if hasattr(model, "module"):
-            # distributed model
-            postnet_output, alignments = model.module.inference(
-                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
-            )
-        else:
-            postnet_output, alignments = model.inference(
-                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
-            )
-        postnet_output = postnet_output.permute(0, 2, 1)
-        # these only belong to tacotron models.
-        decoder_output = None
-        stop_tokens = None
-    else:
-        raise ValueError("[!] Unknown model name.")
-    return decoder_output, postnet_output, alignments, stop_tokens
+def run_model_torch(model,
+                    inputs,
+                    speaker_id=None,
+                    style_mel=None,
+                    x_vector=None):
+    outputs = model.inference(inputs,
+                              cond_input={
+                                  'speaker_ids': speaker_id,
+                                  'x_vector': x_vector,
+                                  'style_mel': style_mel
+                              })
+    # elif "glow" in CONFIG.model.lower():
+    #     inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
+    #     if hasattr(model, "module"):
+    #         # distributed model
+    #         postnet_output, _, _, _, alignments, _, _ = model.module.inference(
+    #             inputs,
+    #             inputs_lengths,
+    #             g=speaker_id if speaker_id is not None else speaker_embeddings)
+    #     else:
+    #         postnet_output, _, _, _, alignments, _, _ = model.inference(
+    #             inputs,
+    #             inputs_lengths,
+    #             g=speaker_id if speaker_id is not None else speaker_embeddings)
+    #     postnet_output = postnet_output.permute(0, 2, 1)
+    #     # these only belong to tacotron models.
+    #     decoder_output = None
+    #     stop_tokens = None
+    # elif CONFIG.model.lower() in ["speedy_speech", "align_tts"]:
+    #     inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
+    #     if hasattr(model, "module"):
+    #         # distributed model
+    #         postnet_output, alignments = model.module.inference(
+    #             inputs,
+    #             inputs_lengths,
+    #             g=speaker_id if speaker_id is not None else speaker_embeddings)
+    #     else:
+    #         postnet_output, alignments = model.inference(
+    #             inputs,
+    #             inputs_lengths,
+    #             g=speaker_id if speaker_id is not None else speaker_embeddings)
+    #     postnet_output = postnet_output.permute(0, 2, 1)
+    #     # these only belong to tacotron models.
+    #     decoder_output = None
+    #     stop_tokens = None
+    # else:
+    #     raise ValueError("[!] Unknown model name.")
+    return outputs
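Editor's note: the rewritten `run_model_torch` drops the per-architecture branching and delegates to the model's own `inference()`, passing all conditioning through a single `cond_input` dict and returning the model's output dict. A rough sketch of that calling convention with a stub model; the output keys mirror the unpacking done later in `synthesis()`, and everything else (shapes, names) is invented for illustration:

```python
import torch

class StubTTSModel:
    """Stand-in for a trainer-API model: inference() takes a cond_input
    dict and returns a dict of outputs, as the new run_model_torch expects."""

    def inference(self, inputs, cond_input=None):
        cond_input = cond_input or {}
        batch, t_in = inputs.shape
        t_out, n_mels = 4 * t_in, 80  # made-up decoder length and mel size
        return {
            "postnet_outputs": torch.zeros(batch, t_out, n_mels),
            "decoder_outputs": torch.zeros(batch, t_out, n_mels),
            "alignments": torch.zeros(batch, t_out, t_in),
            "stop_tokens": torch.zeros(batch, t_out),
        }

model = StubTTSModel()
inputs = torch.randint(0, 30, (1, 12))  # a fake token-id sequence
outputs = model.inference(inputs,
                          cond_input={"speaker_ids": None,
                                      "x_vector": None,
                                      "style_mel": None})
print(outputs["postnet_outputs"].shape)  # torch.Size([1, 48, 80])
```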


-def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
+def run_model_tf(model, inputs, CONFIG, speaker_id=None, style_mel=None):
     if CONFIG.gst and style_mel is not None:
         raise NotImplementedError(" [!] GST inference not implemented for TF")
-    if truncated:
-        raise NotImplementedError(" [!] Truncated inference not implemented for TF")
     if speaker_id is not None:
         raise NotImplementedError(" [!] Multi-Speaker not implemented for TF")
     # TODO: handle multispeaker case
-    decoder_output, postnet_output, alignments, stop_tokens = model(inputs, training=False)
+    decoder_output, postnet_output, alignments, stop_tokens = model(
+        inputs, training=False)
     return decoder_output, postnet_output, alignments, stop_tokens


-def run_model_tflite(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
+def run_model_tflite(model, inputs, CONFIG, speaker_id=None, style_mel=None):
     if CONFIG.gst and style_mel is not None:
-        raise NotImplementedError(" [!] GST inference not implemented for TfLite")
-    if truncated:
-        raise NotImplementedError(" [!] Truncated inference not implemented for TfLite")
+        raise NotImplementedError(
+            " [!] GST inference not implemented for TfLite")
     if speaker_id is not None:
-        raise NotImplementedError(" [!] Multi-Speaker not implemented for TfLite")
+        raise NotImplementedError(
+            " [!] Multi-Speaker not implemented for TfLite")
     # get input and output details
     input_details = model.get_input_details()
     output_details = model.get_output_details()
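Editor's note: `get_input_details()`/`get_output_details()` are the standard `tf.lite.Interpreter` calls. For context, a generic (not TTS-specific) invocation loop looks roughly like this; the model path is a made-up placeholder, and dynamic input lengths would additionally need `resize_tensor_input`:

```python
import numpy as np
import tensorflow as tf

# Hypothetical model path; any TFLite flatbuffer works the same way.
interpreter = tf.lite.Interpreter(model_path="tts_model.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Feed a dummy input matching the declared shape/dtype, run, read back.
dummy = np.zeros(input_details[0]["shape"], dtype=input_details[0]["dtype"])
interpreter.set_tensor(input_details[0]["index"], dummy)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])
print(result.shape)
```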
@@ -152,9 +153,11 @@ def run_model_tflite(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
     return decoder_output, postnet_output, None, None


-def parse_outputs_torch(postnet_output, decoder_output, alignments, stop_tokens):
+def parse_outputs_torch(postnet_output, decoder_output, alignments,
+                        stop_tokens):
     postnet_output = postnet_output[0].data.cpu().numpy()
-    decoder_output = None if decoder_output is None else decoder_output[0].data.cpu().numpy()
+    decoder_output = None if decoder_output is None else decoder_output[
+        0].data.cpu().numpy()
     alignment = alignments[0].cpu().data.numpy()
     stop_tokens = None if stop_tokens is None else stop_tokens[0].cpu().numpy()
     return postnet_output, decoder_output, alignment, stop_tokens
@@ -175,7 +178,7 @@ def parse_outputs_tflite(postnet_output, decoder_output):


 def trim_silence(wav, ap):
-    return wav[: ap.find_endpoint(wav)]
+    return wav[:ap.find_endpoint(wav)]


 def inv_spectrogram(postnet_output, ap, CONFIG):
@@ -186,23 +189,23 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
     return wav


-def id_to_torch(speaker_id, cuda=False):
+def speaker_id_to_torch(speaker_id, cuda=False):
     if speaker_id is not None:
         speaker_id = np.asarray(speaker_id)
         # TODO: test this for tacotron models
         speaker_id = torch.from_numpy(speaker_id)
     if cuda:
         return speaker_id.cuda()
     return speaker_id


-def embedding_to_torch(speaker_embedding, cuda=False):
-    if speaker_embedding is not None:
-        speaker_embedding = np.asarray(speaker_embedding)
-        speaker_embedding = torch.from_numpy(speaker_embedding).unsqueeze(0).type(torch.FloatTensor)
+def embedding_to_torch(x_vector, cuda=False):
+    if x_vector is not None:
+        x_vector = np.asarray(x_vector)
+        x_vector = torch.from_numpy(x_vector).unsqueeze(
+            0).type(torch.FloatTensor)
     if cuda:
-        return speaker_embedding.cuda()
-    return speaker_embedding
+        return x_vector.cuda()
+    return x_vector


 # TODO: perform GL with pytorch for batching
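Editor's note: the `embedding_to_torch` rename is mechanical, but the shape contract is worth pinning down: a flat embedding becomes a `(1, dim)` float32 batch. A quick self-contained check (the 256-dim size is just an example):

```python
import numpy as np
import torch

x_vector = [0.1] * 256  # example speaker embedding (d-vector/x-vector)
t = torch.from_numpy(np.asarray(x_vector)).unsqueeze(0).type(torch.FloatTensor)
print(t.shape, t.dtype)  # torch.Size([1, 256]) torch.float32
```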
@@ -216,7 +219,8 @@ def apply_griffin_lim(inputs, input_lens, CONFIG, ap):
     """
     wavs = []
     for idx, spec in enumerate(inputs):
-        wav_len = (input_lens[idx] * ap.hop_length) - ap.hop_length  # inverse librosa padding
+        wav_len = (input_lens[idx] *
+                   ap.hop_length) - ap.hop_length  # inverse librosa padding
         wav = inv_spectrogram(spec, ap, CONFIG)
         # assert len(wav) == wav_len, f" [!] wav lenght: {len(wav)} vs expected: {wav_len}"
         wavs.append(wav[:wav_len])
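Editor's note: the `wav_len` line (only rewrapped here) undoes librosa's centered-STFT padding: n spectrogram frames at hop h correspond to (n - 1) * h usable samples. A worked instance with invented numbers:

```python
hop_length = 256  # example analysis hop in samples
n_frames = 100    # example spectrogram length for one batch item

wav_len = (n_frames * hop_length) - hop_length  # == (n_frames - 1) * hop_length
print(wav_len)  # 25344
```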
@@ -231,11 +235,10 @@ def synthesis(
     ap,
     speaker_id=None,
     style_wav=None,
-    truncated=False,
     enable_eos_bos_chars=False,  # pylint: disable=unused-argument
     use_griffin_lim=False,
     do_trim_silence=False,
-    speaker_embedding=None,
+    x_vector=None,
     backend="torch",
 ):
     """Synthesize voice for the given text.
@@ -249,8 +252,6 @@ def synthesis(
             model outputs.
         speaker_id (int): id of speaker
         style_wav (str | Dict[str, float]): Uses for style embedding of GST.
-        truncated (bool): keep model states after inference. It can be used
-            for continuous inference at long texts.
         enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence.
         do_trim_silence (bool): trim silence after synthesis.
         backend (str): tf or torch
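Editor's note: a hedged usage sketch of the updated `synthesis()` entry point. The leading positional parameters (`model`, `text`, `CONFIG`, `use_cuda`) are not visible in this hunk and are assumed from the surrounding code; `model`, `CONFIG`, `ap`, `manager`, and the clip key are placeholders you must supply, so this is illustrative rather than runnable:

```python
from TTS.tts.utils.synthesis import synthesis

# Pull a per-clip embedding from the new dict-shaped x_vectors store.
x_vector = manager.x_vectors["a.wav"]["embedding"]  # hypothetical clip key

result = synthesis(
    model,                 # a loaded TTS model (placeholder)
    "Hello there!",
    CONFIG,                # the model config (placeholder)
    use_cuda=False,
    ap=ap,                 # an AudioProcessor instance (placeholder)
    x_vector=x_vector,     # replaces the old speaker_embedding argument
    use_griffin_lim=True,
    do_trim_silence=True,
    backend="torch",
)
```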
@@ -263,14 +264,15 @@ def synthesis(
     else:
         style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
     # preprocess the given text
-    inputs = text_to_seqvec(text, CONFIG)
+    inputs = text_to_seq(text, CONFIG)
     # pass tensors to backend
     if backend == "torch":
         if speaker_id is not None:
-            speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
+            speaker_id = speaker_id_to_torch(speaker_id, cuda=use_cuda)

-        if speaker_embedding is not None:
-            speaker_embedding = embedding_to_torch(speaker_embedding, cuda=use_cuda)
+        if x_vector is not None:
+            x_vector = embedding_to_torch(x_vector,
+                                          cuda=use_cuda)

         if not isinstance(style_mel, dict):
             style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
@@ -287,24 +289,26 @@ def synthesis(
         inputs = tf.expand_dims(inputs, 0)
     # synthesize voice
     if backend == "torch":
-        decoder_output, postnet_output, alignments, stop_tokens = run_model_torch(
-            model, inputs, CONFIG, truncated, speaker_id, style_mel, speaker_embeddings=speaker_embedding
-        )
+        outputs = run_model_torch(model,
+                                  inputs,
+                                  speaker_id,
+                                  style_mel,
+                                  x_vector=x_vector)
+        postnet_output, decoder_output, alignments, stop_tokens = \
+            outputs['postnet_outputs'], outputs['decoder_outputs'],\
+            outputs['alignments'], outputs['stop_tokens']
         postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_torch(
-            postnet_output, decoder_output, alignments, stop_tokens
-        )
+            postnet_output, decoder_output, alignments, stop_tokens)
     elif backend == "tf":
         decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
-            model, inputs, CONFIG, truncated, speaker_id, style_mel
-        )
+            model, inputs, CONFIG, speaker_id, style_mel)
         postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_tf(
-            postnet_output, decoder_output, alignments, stop_tokens
-        )
+            postnet_output, decoder_output, alignments, stop_tokens)
     elif backend == "tflite":
         decoder_output, postnet_output, alignment, stop_tokens = run_model_tflite(
-            model, inputs, CONFIG, truncated, speaker_id, style_mel
-        )
-        postnet_output, decoder_output = parse_outputs_tflite(postnet_output, decoder_output)
+            model, inputs, CONFIG, speaker_id, style_mel)
+        postnet_output, decoder_output = parse_outputs_tflite(
+            postnet_output, decoder_output)
     # convert outputs to numpy
     # plot results
     wav = None