Add glow_tts_config.py and a Python training test for GlowTTS

pull/476/head
Eren Gölge 2021-05-06 03:41:06 +02:00
parent c6df8de80a
commit 51a7e06945
2 changed files with 99 additions and 0 deletions

View File

@ -0,0 +1,50 @@
from dataclasses import dataclass, field
from .shared_configs import BaseTTSConfig
@dataclass
class GlowTTSConfig(BaseTTSConfig):
    """Defines parameters for GlowTTS model.

    Extends ``BaseTTSConfig`` with the fields declared below. Every field
    declared here has a default value, so callers only need to pass the
    values they want to change.

    Optional string fields (``style_wav_for_test`` and
    ``external_speaker_embedding_file``) default to ``None`` when unset.
    """

    model: str = "glow_tts"

    # model params
    encoder_type: str = "rel_pos_transformer"
    # Built through ``default_factory`` so each config instance gets its own
    # dict instead of sharing one mutable default.
    # NOTE(review): presumably these keys are forwarded to the encoder chosen
    # by ``encoder_type`` -- confirm against the model code.
    encoder_params: dict = field(
        default_factory=lambda: {
            "kernel_size": 3,
            "dropout_p": 0.1,
            "num_layers": 6,
            "num_heads": 2,
            "hidden_channels_ffn": 768,
        }
    )
    use_encoder_prenet: bool = True
    hidden_channels_encoder: int = 192
    hidden_channels_decoder: int = 192
    hidden_channels_duration_predictor: int = 256

    # training params
    data_dep_init_steps: int = 10

    # inference params
    style_wav_for_test: str = None
    inference_noise_scale: float = 0.0

    # multi-speaker settings
    use_speaker_embedding: bool = False
    use_external_speaker_embedding_file: bool = False
    # Unset is represented by ``None`` (not ``False``) so the default agrees
    # with the declared ``str`` type and with ``style_wav_for_test``; it is
    # still falsy, so truthiness checks behave as before.
    external_speaker_embedding_file: str = None

    # optimizer params
    noam_schedule: bool = True
    warmup_steps: int = 4000
    grad_clip: float = 5.0
    lr: float = 1e-3
    wd: float = 0.000001

    # overrides
    # NOTE(review): presumably these replace defaults inherited from
    # ``BaseTTSConfig`` (not visible here) -- confirm.
    min_seq_len: int = 3
    max_seq_len: int = 500
    r: int = 1

View File

@ -0,0 +1,49 @@
import glob
import os
import shutil
from tests import get_tests_output_path, run_cli
from TTS.tts.configs import GlowTTSConfig
# Where the serialized config and the training artifacts are written.
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")

# Keyword arguments for a minimal single-epoch run with evaluation enabled.
config_kwargs = {
    "batch_size": 8,
    "eval_batch_size": 8,
    "num_loader_workers": 0,
    "num_val_loader_workers": 0,
    "text_cleaner": "english_cleaners",
    "use_phonemes": True,
    "phoneme_language": "en-us",
    "phoneme_cache_path": os.path.join(get_tests_output_path(), "train_outputs/phoneme_cache/"),
    "run_eval": True,
    "test_delay_epochs": -1,
    "epochs": 1,
    "print_step": 1,
    "print_eval": True,
}
config = GlowTTSConfig(**config_kwargs)
config.audio.do_trim_silence = True
config.audio.trim_db = 60
config.save_json(config_path)

# Both runs share the same launcher prefix; CUDA_VISIBLE_DEVICES is set to an
# empty string for the spawned training process.
train_script = "CUDA_VISIBLE_DEVICES='' python TTS/bin/train_glow_tts.py"

# First run: train for one epoch from the saved config, pointing dataset 0
# at the LJSpeech fixture under tests/data via command-line overrides.
first_run_args = [
    f"--config_path {config_path}",
    f"--coqpit.output_path {output_path}",
    "--coqpit.datasets.0.name ljspeech",
    "--coqpit.datasets.0.meta_file_train metadata.csv",
    "--coqpit.datasets.0.meta_file_val metadata.csv",
    "--coqpit.datasets.0.path tests/data/ljspeech",
    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt",
]
command_train = " ".join([train_script, *first_run_args])
run_cli(command_train)

# Pick the most recently modified run directory under the output path.
run_dirs = glob.glob(os.path.join(output_path, "*/"))
continue_path = max(run_dirs, key=os.path.getmtime)

# Second run: resume from that directory and train for one more epoch.
command_train = f"{train_script} --continue_path {continue_path} "
run_cli(command_train)

# Remove the run directory once both runs have completed.
shutil.rmtree(continue_path)