mirror of https://github.com/coqui-ai/TTS.git
glow_tts_config.py and train test on python
parent
c6df8de80a
commit
51a7e06945
|
@ -0,0 +1,50 @@
|
|||
from dataclasses import dataclass, field
|
||||
|
||||
from .shared_configs import BaseTTSConfig
|
||||
|
||||
|
||||
@dataclass
class GlowTTSConfig(BaseTTSConfig):
    """Defines parameters for GlowTTS model.

    Every field is a plain dataclass attribute with a default, so an instance
    can be created with no arguments and individual values overridden by
    keyword. Fields not listed here come from ``BaseTTSConfig``.
    """

    model: str = "glow_tts"

    # model params
    encoder_type: str = "rel_pos_transformer"
    # A default_factory is used so each instance gets its own dict rather
    # than sharing one mutable default across instances.
    encoder_params: dict = field(
        default_factory=lambda: {
            "kernel_size": 3,
            "dropout_p": 0.1,
            "num_layers": 6,
            "num_heads": 2,
            "hidden_channels_ffn": 768,
        }
    )
    use_encoder_prenet: bool = True
    hidden_channels_encoder: int = 192
    hidden_channels_decoder: int = 192
    hidden_channels_duration_predictor: int = 256

    # training params
    data_dep_init_steps: int = 10

    # inference params
    style_wav_for_test: str = None
    inference_noise_scale: float = 0.0

    # multi-speaker settings
    use_speaker_embedding: bool = False
    use_external_speaker_embedding_file: bool = False
    # Path-like string field: "not set" is None (it was the boolean False,
    # which contradicted the ``str`` annotation). None is still falsy, so
    # truthiness checks on this field behave as before.
    external_speaker_embedding_file: str = None

    # optimizer params
    noam_schedule: bool = True
    warmup_steps: int = 4000
    grad_clip: float = 5.0
    lr: float = 1e-3
    wd: float = 0.000001

    # overrides
    # NOTE(review): presumably these replace defaults declared on
    # BaseTTSConfig -- confirm against shared_configs.
    min_seq_len: int = 3
    max_seq_len: int = 500
    r: int = 1
|
|
@ -0,0 +1,49 @@
|
|||
import glob
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from tests import get_tests_output_path, run_cli
|
||||
from TTS.tts.configs import GlowTTSConfig
|
||||
|
||||
# Files produced by this test live under the shared tests output directory.
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")

# Minimal configuration: small batches, no loader workers, a single epoch,
# with evaluation and per-step printing enabled.
config = GlowTTSConfig(
    batch_size=8,
    eval_batch_size=8,
    num_loader_workers=0,
    num_val_loader_workers=0,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path=os.path.join(get_tests_output_path(), "train_outputs/phoneme_cache/"),
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1,
    print_step=1,
    print_eval=True,
)
config.audio.trim_db = 60
config.audio.do_trim_silence = True
config.save_json(config_path)

# Both invocations share this prefix; CUDA_VISIBLE_DEVICES is set to an empty
# string for the spawned training process.
train_script = "CUDA_VISIBLE_DEVICES='' python TTS/bin/train_glow_tts.py"

# train the model for one epoch
command_train = " ".join(
    [
        f"{train_script} --config_path {config_path}",
        f"--coqpit.output_path {output_path}",
        "--coqpit.datasets.0.name ljspeech",
        "--coqpit.datasets.0.meta_file_train metadata.csv",
        "--coqpit.datasets.0.meta_file_val metadata.csv",
        "--coqpit.datasets.0.path tests/data/ljspeech",
        "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt",
    ]
)
run_cli(command_train)

# Find latest folder: the most recently modified sub-directory of the output path
run_dirs = glob.glob(os.path.join(output_path, "*/"))
continue_path = max(run_dirs, key=os.path.getmtime)

# restore the model and continue training for one more epoch
command_train = f"{train_script} --continue_path {continue_path} "
run_cli(command_train)

# Remove the run folder created by this test.
shutil.rmtree(continue_path)
|
Loading…
Reference in New Issue