diff --git a/TTS/demos/xtts_ft_demo/utils/gpt_train.py b/TTS/demos/xtts_ft_demo/utils/gpt_train.py
index 54fade38..d6e2f313 100644
--- a/TTS/demos/xtts_ft_demo/utils/gpt_train.py
+++ b/TTS/demos/xtts_ft_demo/utils/gpt_train.py
@@ -8,7 +8,7 @@ from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrai
 from TTS.utils.manage import ModelManager
 
 
-def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path):
+def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path, max_audio_length=255995):
     # Logging parameters
     RUN_NAME = "GPT_XTTS_FT"
     PROJECT_NAME = "XTTS_trainer"
@@ -79,7 +79,7 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
         max_conditioning_length=132300,  # 6 secs
         min_conditioning_length=66150,  # 3 secs
         debug_loading_failures=False,
-        max_wav_length=255995,  # ~11.6 seconds
+        max_wav_length=max_audio_length,  # ~11.6 seconds
         max_text_length=200,
         mel_norm_file=MEL_NORM_FILE,
         dvae_checkpoint=DVAE_CHECKPOINT,
diff --git a/TTS/demos/xtts_ft_demo/xtts_demo.py b/TTS/demos/xtts_ft_demo/xtts_demo.py
index 43448adc..8e9a88eb 100644
--- a/TTS/demos/xtts_ft_demo/xtts_demo.py
+++ b/TTS/demos/xtts_ft_demo/xtts_demo.py
@@ -147,6 +147,13 @@ if __name__ == "__main__":
         help="Grad accumulation steps. Default: 1",
         default=1,
     )
+    parser.add_argument(
+        "--max_audio_length",
+        type=int,
+        help="Max permitted audio size in seconds. Default: 11",
+        default=11,
+    )
+
     args = parser.parse_args()
 
     with gr.Blocks() as demo:
@@ -250,6 +257,13 @@ if __name__ == "__main__":
                        step=1,
                        value=args.grad_acumm,
                    )
+                    max_audio_length = gr.Slider(
+                        label="Max permitted audio size in seconds:",
+                        minimum=2,
+                        maximum=20,
+                        step=1,
+                        value=args.max_audio_length,
+                    )
                    progress_train = gr.Label(
                        label="Progress:"
                    )
@@ -260,12 +274,14 @@ if __name__ == "__main__":
            demo.load(read_logs, None, logs_tts_train, every=1)
            train_btn = gr.Button(value="Step 2 - Run the training")
 
-            def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path):
+            def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length):
                clear_gpu_cache()
                if not train_csv or not eval_csv:
                    return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", ""
                try:
-                    config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path)
+                    # convert seconds to waveform frames
+                    max_audio_length = int(max_audio_length * 22050)
+                    config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length)
                except:
                    traceback.print_exc()
                    error = traceback.format_exc()
@@ -280,7 +296,6 @@ if __name__ == "__main__":
                clear_gpu_cache()
                return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav
 
-
        with gr.Tab("3 - Inference"):
            with gr.Row():
                with gr.Column() as col1:
@@ -367,6 +382,7 @@ if __name__ == "__main__":
                    batch_size,
                    grad_acumm,
                    out_path,
+                    max_audio_length,
                ],
                outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio],
            )
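
For context, a minimal sketch of the unit conversion this patch relies on: the `--max_audio_length` flag and the Gradio slider take a value in seconds, and `train_model` multiplies it by 22050 (the sample rate assumed by the hard-coded factor in the diff) before passing it to `train_gpt`, where it becomes `max_wav_length`. The constant and helper names below are illustrative only and are not part of the patch.

```python
# Illustrative sketch of the seconds -> samples conversion used in the patch.
# SAMPLE_RATE mirrors the 22050 factor hard-coded in train_model(); the helper
# name seconds_to_samples() is hypothetical, not an identifier from the repo.
SAMPLE_RATE = 22050


def seconds_to_samples(seconds: float) -> int:
    """Convert an audio length in seconds to a sample count for max_wav_length."""
    return int(seconds * SAMPLE_RATE)


if __name__ == "__main__":
    # The previous hard-coded default of 255995 samples is ~11.6 s at 22050 Hz,
    # which is why the new CLI flag and slider default to 11 seconds.
    print(seconds_to_samples(11))   # 242550
    print(255995 / SAMPLE_RATE)     # ~11.61
```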