From 9f2d2d2081efb6b6aa3c023e4f7b4ef2beb84736 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 7 May 2021 17:27:05 +0200 Subject: [PATCH] add speaker encoder train test --- tests/test_speaker_encoder_train.py | 46 +++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 tests/test_speaker_encoder_train.py diff --git a/tests/test_speaker_encoder_train.py b/tests/test_speaker_encoder_train.py new file mode 100644 index 00000000..0bf04966 --- /dev/null +++ b/tests/test_speaker_encoder_train.py @@ -0,0 +1,46 @@ +import glob +import os +import shutil + +from tests import get_tests_output_path, run_cli +from TTS.speaker_encoder.speaker_encoder_config import SpeakerEncoderConfig +from TTS.config.shared_configs import BaseAudioConfig + +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") +output_path = os.path.join(get_tests_output_path(), "train_outputs") + + +config = SpeakerEncoderConfig( + batch_size=4, + num_speakers_in_batch=1, + num_utters_per_speaker=10, + num_loader_workers=0, + max_train_step=10, + print_step=1, + save_step=10, + print_eval=True, + audio=BaseAudioConfig(num_mels=40) +) +config.audio.do_trim_silence = True +config.audio.trim_db = 60 +config.save_json(config_path) + +# train the model for one epoch +command_train = ( + f"CUDA_VISIBLE_DEVICES='' python TTS/bin/train_encoder.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.datasets.0.name ljspeech " + "--coqpit.datasets.0.meta_file_train metadata.csv " + "--coqpit.datasets.0.meta_file_val metadata.csv " + "--coqpit.datasets.0.path tests/data/ljspeech " + "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt" +) +run_cli(command_train) + +# Find latest folder +continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) + +# restore the model and continue training for one more epoch +command_train = f"CUDA_VISIBLE_DEVICES='' python TTS/bin/train_encoder.py --continue_path {continue_path} " +run_cli(command_train) +shutil.rmtree(continue_path)