From 51a7e0694586394c1fb8cd28af6da2eea5a0ab26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 6 May 2021 03:41:06 +0200 Subject: [PATCH] glow_tts_config.py and train test on python --- TTS/tts/configs/glow_tts_config.py | 50 ++++++++++++++++++++++++++ tests/tts_tests/test_glow_tts_train.py | 49 +++++++++++++++++++++++++ 2 files changed, 99 insertions(+) create mode 100644 TTS/tts/configs/glow_tts_config.py create mode 100644 tests/tts_tests/test_glow_tts_train.py diff --git a/TTS/tts/configs/glow_tts_config.py b/TTS/tts/configs/glow_tts_config.py new file mode 100644 index 00000000..8474caae --- /dev/null +++ b/TTS/tts/configs/glow_tts_config.py @@ -0,0 +1,50 @@ +from dataclasses import dataclass, field + +from .shared_configs import BaseTTSConfig + + +@dataclass +class GlowTTSConfig(BaseTTSConfig): + """Defines parameters for GlowTTS model.""" + + model: str = "glow_tts" + + # model params + encoder_type: str = "rel_pos_transformer" + encoder_params: dict = field( + default_factory=lambda: { + "kernel_size": 3, + "dropout_p": 0.1, + "num_layers": 6, + "num_heads": 2, + "hidden_channels_ffn": 768, + } + ) + use_encoder_prenet: bool = True + hidden_channels_encoder: int = 192 + hidden_channels_decoder: int = 192 + hidden_channels_duration_predictor: int = 256 + + # training params + data_dep_init_steps: int = 10 + + # inference params + style_wav_for_test: str = None + inference_noise_scale: float = 0.0 + + # multi-speaker settings + use_speaker_embedding: bool = False + use_external_speaker_embedding_file: bool = False + external_speaker_embedding_file: str = False + + # optimizer params + noam_schedule: bool = True + warmup_steps: int = 4000 + grad_clip: float = 5.0 + lr: float = 1e-3 + wd: float = 0.000001 + + # overrides + min_seq_len: int = 3 + max_seq_len: int = 500 + r: int = 1 diff --git a/tests/tts_tests/test_glow_tts_train.py b/tests/tts_tests/test_glow_tts_train.py new file mode 100644 index 00000000..bb630aef --- /dev/null +++ b/tests/tts_tests/test_glow_tts_train.py @@ -0,0 +1,49 @@ +import glob +import os +import shutil + +from tests import get_tests_output_path, run_cli +from TTS.tts.configs import GlowTTSConfig + +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") +output_path = os.path.join(get_tests_output_path(), "train_outputs") + + +config = GlowTTSConfig( + batch_size=8, + eval_batch_size=8, + num_loader_workers=0, + num_val_loader_workers=0, + text_cleaner="english_cleaners", + use_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path=os.path.join(get_tests_output_path(), "train_outputs/phoneme_cache/"), + run_eval=True, + test_delay_epochs=-1, + epochs=1, + print_step=1, + print_eval=True, +) +config.audio.do_trim_silence = True +config.audio.trim_db = 60 +config.save_json(config_path) + +# train the model for one epoch +command_train = ( + f"CUDA_VISIBLE_DEVICES='' python TTS/bin/train_glow_tts.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.datasets.0.name ljspeech " + "--coqpit.datasets.0.meta_file_train metadata.csv " + "--coqpit.datasets.0.meta_file_val metadata.csv " + "--coqpit.datasets.0.path tests/data/ljspeech " + "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt" +) +run_cli(command_train) + +# Find latest folder +continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) + +# restore the model and continue training for one more epoch +command_train = f"CUDA_VISIBLE_DEVICES='' python TTS/bin/train_glow_tts.py --continue_path {continue_path} " +run_cli(command_train) +shutil.rmtree(continue_path)