diff --git a/TTS/demos/xtts_ft_demo/xtts_demo.py b/TTS/demos/xtts_ft_demo/xtts_demo.py index 662c9c9f..6115c552 100644 --- a/TTS/demos/xtts_ft_demo/xtts_demo.py +++ b/TTS/demos/xtts_ft_demo/xtts_demo.py @@ -10,6 +10,7 @@ import numpy as np import os import torch import torchaudio +import traceback from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt @@ -22,11 +23,12 @@ def clear_gpu_cache(): if torch.cuda.is_available(): torch.cuda.empty_cache() - XTTS_MODEL = None def load_model(xtts_checkpoint, xtts_config, xtts_vocab): - clear_gpu_cache() global XTTS_MODEL + clear_gpu_cache() + if not xtts_checkpoint or not xtts_config or not xtts_vocab: + return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields !!" config = XttsConfig() config.load_json(xtts_config) XTTS_MODEL = Xtts.init_from_config(config) @@ -39,6 +41,9 @@ def load_model(xtts_checkpoint, xtts_config, xtts_vocab): return "Model Loaded!" def run_tts(lang, tts_text, speaker_audio_file): + if XTTS_MODEL is None or not speaker_audio_file: + return "You need to run the previous step to load the model !!", None, None + gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs) out = XTTS_MODEL.inference( text=tts_text, @@ -57,7 +62,7 @@ def run_tts(lang, tts_text, speaker_audio_file): out_path = fp.name torchaudio.save(out_path, out["wav"], 24000) - return out_path, speaker_audio_file + return "Speech generated !", out_path, speaker_audio_file @@ -197,8 +202,7 @@ if __name__ == "__main__": out_path = os.path.join(out_path, "dataset") os.makedirs(out_path, exist_ok=True) if audio_path is None: - # ToDo: raise an error - pass + return "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!", "", "" else: train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress) @@ -208,7 +212,7 @@ if __name__ == "__main__": if audio_total_size < 120: message = "The sum of the duration of the audios that you provided should be at least 2 minutes!" print(message) - return message, " ", " " + return message, "", "" print("Dataset Processed!") return "Dataset Processed!", train_meta, eval_meta @@ -253,8 +257,14 @@ if __name__ == "__main__": def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path): clear_gpu_cache() + if not train_csv or not eval_csv: + return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", "" + try: + config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path) + except Exception as e: + traceback.print_exc() + return f"The training was interrupted due an error !! Please check the console to check the error message! Error summary: {e}", "", "", "", "" - config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path) # copy original files to avoid parameters changes issues os.system(f"cp {config_path} {exp_path}") os.system(f"cp {vocab_file} {exp_path}") @@ -276,8 +286,9 @@ if __name__ == "__main__": label="XTTS config path:", value="", ) + xtts_vocab = gr.Textbox( - label="XTTS config path:", + label="XTTS vocab path:", value="", ) progress_load = gr.Label( @@ -319,6 +330,9 @@ if __name__ == "__main__": tts_btn = gr.Button(value="Step 4 - Inference") with gr.Column() as col3: + progress_gen = gr.Label( + label="Progress:" + ) tts_output_audio = gr.Audio(label="Generated Audio.") reference_audio = gr.Audio(label="Reference audio used.") @@ -368,7 +382,7 @@ if __name__ == "__main__": tts_text, speaker_reference_audio, ], - outputs=[tts_output_audio, reference_audio], + outputs=[progress_gen, tts_output_audio, reference_audio], ) demo.launch(