diff --git a/TTS/bin/train_align_tts.py b/TTS/bin/train_align_tts.py
index f5658dd2..d231484a 100644
--- a/TTS/bin/train_align_tts.py
+++ b/TTS/bin/train_align_tts.py
@@ -229,7 +229,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step,
         if global_step % config.tb_plot_step == 0:
             iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
             iter_stats.update(loss_dict)
-            tb_logger.tb_train_iter_stats(global_step, iter_stats)
+            tb_logger.tb_train_step_stats(global_step, iter_stats)

         if global_step % config.save_step == 0:
             if config.checkpoint:
diff --git a/TTS/bin/train_glow_tts.py b/TTS/bin/train_glow_tts.py
index 50e95a2b..9a455a1b 100644
--- a/TTS/bin/train_glow_tts.py
+++ b/TTS/bin/train_glow_tts.py
@@ -270,7 +270,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step,
         if global_step % config.tb_plot_step == 0:
             iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
             iter_stats.update(loss_dict)
-            tb_logger.tb_train_iter_stats(global_step, iter_stats)
+            tb_logger.tb_train_step_stats(global_step, iter_stats)

         if global_step % config.save_step == 0:
             if config.checkpoint:
diff --git a/TTS/bin/train_speedy_speech.py b/TTS/bin/train_speedy_speech.py
index 4ab0c899..742a27d8 100644
--- a/TTS/bin/train_speedy_speech.py
+++ b/TTS/bin/train_speedy_speech.py
@@ -256,7 +256,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step,
         if global_step % config.tb_plot_step == 0:
             iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
             iter_stats.update(loss_dict)
-            tb_logger.tb_train_iter_stats(global_step, iter_stats)
+            tb_logger.tb_train_step_stats(global_step, iter_stats)

         if global_step % config.save_step == 0:
             if config.checkpoint:
diff --git a/TTS/bin/train_tacotron.py b/TTS/bin/train_tacotron.py
index 098a8d3f..b5e38b80 100755
--- a/TTS/bin/train_tacotron.py
+++ b/TTS/bin/train_tacotron.py
@@ -327,7 +327,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,
                 "step_time": step_time,
             }
             iter_stats.update(loss_dict)
-            tb_logger.tb_train_iter_stats(global_step, iter_stats)
+            tb_logger.tb_train_step_stats(global_step, iter_stats)

         if global_step % config.save_step == 0:
             if config.checkpoint:
diff --git a/TTS/bin/train_vocoder_gan.py b/TTS/bin/train_vocoder_gan.py
index 123d5a43..ea317ef6 100755
--- a/TTS/bin/train_vocoder_gan.py
+++ b/TTS/bin/train_vocoder_gan.py
@@ -265,7 +265,7 @@ def train(
         if global_step % 10 == 0:
             iter_stats = {"lr_G": current_lr_G, "lr_D": current_lr_D, "step_time": step_time}
             iter_stats.update(loss_dict)
-            tb_logger.tb_train_iter_stats(global_step, iter_stats)
+            tb_logger.tb_train_step_stats(global_step, iter_stats)

         # save checkpoint
         if global_step % c.save_step == 0:
diff --git a/TTS/bin/train_vocoder_wavegrad.py b/TTS/bin/train_vocoder_wavegrad.py
index c0fcff51..c8f067ee 100644
--- a/TTS/bin/train_vocoder_wavegrad.py
+++ b/TTS/bin/train_vocoder_wavegrad.py
@@ -181,7 +181,7 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, epoch
         if global_step % 10 == 0:
             iter_stats = {"lr": current_lr, "grad_norm": grad_norm.item(), "step_time": step_time}
             iter_stats.update(loss_dict)
-            tb_logger.tb_train_iter_stats(global_step, iter_stats)
+            tb_logger.tb_train_step_stats(global_step, iter_stats)

         # save checkpoint
         if global_step % c.save_step == 0:
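Note: the rename above is mechanical. Every trainer builds an `iter_stats` dict once per plot step and hands it to the logger; only the hook name changes. A minimal sketch of the pattern (`MiniLogger` is a hypothetical stand-in for illustration, not the library class; the real `TensorboardLogger` is patched in the `tensorboard_logger.py` hunk near the end of this patch):

```python
from torch.utils.tensorboard import SummaryWriter


class MiniLogger:
    """Stand-in for TTS.utils.tensorboard_logger.TensorboardLogger."""

    def __init__(self, log_dir: str, model_name: str):
        self.writer = SummaryWriter(log_dir)
        self.model_name = model_name

    def tb_train_step_stats(self, step: int, stats: dict):
        # mirrors dict_to_tb_scalar: one scalar per entry, grouped under the model name
        for key, value in stats.items():
            self.writer.add_scalar(f"{self.model_name}_TrainIterStats/{key}", value, step)


# usage, matching the patched call sites above
tb_logger = MiniLogger("runs/debug", "tacotron")
iter_stats = {"lr": 1e-4, "grad_norm": 0.5, "step_time": 0.12}
iter_stats.update({"decoder_loss": 1.23})  # the loss_dict returned by the criterion
tb_logger.tb_train_step_stats(100, iter_stats)
```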
diff --git a/TTS/bin/train_vocoder_wavernn.py b/TTS/bin/train_vocoder_wavernn.py
index bcad9493..86a1506a 100644
--- a/TTS/bin/train_vocoder_wavernn.py
+++ b/TTS/bin/train_vocoder_wavernn.py
@@ -163,7 +163,7 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch
         if global_step % 10 == 0:
             iter_stats = {"lr": cur_lr, "step_time": step_time}
             iter_stats.update(loss_dict)
-            tb_logger.tb_train_iter_stats(global_step, iter_stats)
+            tb_logger.tb_train_step_stats(global_step, iter_stats)

         # save checkpoint
         if global_step % c.save_step == 0:
diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py
index a501a880..a2d935c7 100644
--- a/TTS/tts/configs/shared_configs.py
+++ b/TTS/tts/configs/shared_configs.py
@@ -133,6 +133,18 @@ class BaseTTSConfig(BaseTrainingConfig):
         datasets (List[BaseDatasetConfig]):
             List of datasets used for training. If multiple datasets are provided, they are merged and used together
             for training.
+        optimizer (str):
+            Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`.
+            It has no default value (`MISSING`) and must be set by the child config class.
+        optimizer_params (dict):
+            Optimizer kwargs. It has no default value (`MISSING`) and must be set by the child config class.
+        lr_scheduler (str):
+            Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or
+            `TTS.utils.training`. Defaults to `""` (no scheduler).
+        lr_scheduler_params (dict):
+            Parameters for the learning rate scheduler. Defaults to `{}`.
+        test_sentences (List[str]):
+            List of sentences to be used at testing. Defaults to `[]`.
     """

     audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
@@ -158,3 +170,11 @@ class BaseTTSConfig(BaseTrainingConfig):
     add_blank: bool = False
     # dataset
     datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
+    # optimizer
+    optimizer: str = MISSING
+    optimizer_params: dict = MISSING
+    # scheduler
+    lr_scheduler: str = ""
+    lr_scheduler_params: dict = field(default_factory=lambda: {})
+    # testing
+    test_sentences: List[str] = field(default_factory=lambda: [])
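Note: these new string-valued fields (here and in the `TacotronConfig` hunk that follows) mean the trainer resolves optimizer and scheduler classes by name at runtime. A rough sketch of that resolution under stated assumptions: it uses a plain `torch.optim` lookup, whereas the real code can also pull `RAdam`/`NoamLR` from `TTS.utils.training`, and `torch.optim.RAdam` only exists in newer PyTorch releases:

```python
import torch

# values mirror the TacotronConfig defaults introduced below
optimizer_name = "RAdam"
optimizer_params = {"betas": [0.9, 0.998], "weight_decay": 1e-6}
lr = 1e-4
warmup_steps = 4000  # lr_scheduler_params["warmup_steps"]

model = torch.nn.Linear(80, 80)  # stand-in for the TTS model
optim_class = getattr(torch.optim, optimizer_name)
optimizer = optim_class(model.parameters(), lr=lr, **optimizer_params)


def noam_decay(step: int) -> float:
    # Noam warmup: linear ramp for `warmup_steps` steps, then inverse-sqrt decay
    step = max(step, 1)
    return warmup_steps**0.5 * min(step * warmup_steps**-1.5, step**-0.5)


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, noam_decay)
```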
diff --git a/TTS/tts/configs/tacotron_config.py b/TTS/tts/configs/tacotron_config.py
index a567cd88..ff8d89bb 100644
--- a/TTS/tts/configs/tacotron_config.py
+++ b/TTS/tts/configs/tacotron_config.py
@@ -78,10 +78,16 @@ class TacotronConfig(BaseTTSConfig):
             enable / disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
         external_speaker_embedding_file (str):
             Path to the file including pre-computed speaker embeddings. Defaults to None.
-        noam_schedule (bool):
-            enable / disable the use of Noam LR scheduler. Defaults to False.
-        warmup_steps (int):
-            Number of warm-up steps for the Noam scheduler. Defaults 4000.
+        optimizer (str):
+            Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`.
+            Defaults to `RAdam`.
+        optimizer_params (dict):
+            Optimizer kwargs. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`.
+        lr_scheduler (str):
+            Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or
+            `TTS.utils.training`. Defaults to `NoamLR`.
+        lr_scheduler_params (dict):
+            Parameters for the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`.
         lr (float):
             Initial learning rate. Defaults to `1e-4`.
         wd (float):
@@ -152,10 +158,11 @@ class TacotronConfig(BaseTTSConfig):
     external_speaker_embedding_file: str = False

     # optimizer parameters
-    noam_schedule: bool = False
-    warmup_steps: int = 4000
+    optimizer: str = "RAdam"
+    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
+    lr_scheduler: str = "NoamLR"
+    lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
     lr: float = 1e-4
-    wd: float = 1e-6
     grad_clip: float = 5.0
     seq_len_norm: bool = False
     loss_masking: bool = True
diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py
index 84da1f72..4ab78f88 100755
--- a/TTS/tts/utils/speakers.py
+++ b/TTS/tts/utils/speakers.py
@@ -1,7 +1,7 @@
 import json
 import os
 import random
-from typing import Union
+from typing import Any, List, Union

 import numpy as np
 import torch
@@ -35,9 +35,3 @@ def save_speaker_mapping(out_path, speaker_mapping):


-def get_speakers(items):
-    """Returns a sorted, unique list of speakers in a given dataset."""
-    speakers = {e[2] for e in items}
-    return sorted(speakers)
-
-
 def parse_speakers(c, args, meta_data_train, OUT_PATH):
@@ -121,26 +119,31 @@ class SpeakerManager:

     Args:
         x_vectors_file_path (str, optional): Path to the metafile including x vectors. Defaults to "".
-        speaker_id_file_path (str, optional): Path to the metafile that maps speaker names to ids used by the
-            TTS model. Defaults to "".
+        speaker_id_file_path (str, optional): Path to the metafile that maps speaker names to ids used by
+            TTS models. Defaults to "".
         encoder_model_path (str, optional): Path to the speaker encoder model file. Defaults to "".
         encoder_config_path (str, optional): Path to the speaker encoder config file. Defaults to "".
     """

     def __init__(
         self,
+        data_items: List[List[Any]] = None,
         x_vectors_file_path: str = "",
         speaker_id_file_path: str = "",
         encoder_model_path: str = "",
         encoder_config_path: str = "",
     ):
-        self.x_vectors = None
-        self.speaker_ids = None
-        self.clip_ids = None
+        self.data_items = data_items or []
+        self.x_vectors = []
+        self.speaker_ids = []
+        self.clip_ids = []
         self.speaker_encoder = None
         self.speaker_encoder_ap = None

+        if data_items:
+            self.speaker_ids, _ = self.parser_speakers_from_items(data_items)
+
         if x_vectors_file_path:
             self.load_x_vectors_file(x_vectors_file_path)
@@ -169,10 +172,10 @@ class SpeakerManager:
         return len(self.x_vectors[list(self.x_vectors.keys())[0]]["embedding"])

     def parser_speakers_from_items(self, items: list):
-        speaker_ids = sorted({item[2] for item in items})
-        self.speaker_ids = speaker_ids
-        num_speakers = len(speaker_ids)
-        return speaker_ids, num_speakers
+        speakers = sorted({item[2] for item in items})
+        self.speaker_ids = {name: i for i, name in enumerate(speakers)}
+        num_speakers = len(self.speaker_ids)
+        return self.speaker_ids, num_speakers

     def save_ids_file(self, file_path: str):
         self._save_json(file_path, self.speaker_ids)
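Note: the behavioural change in `parser_speakers_from_items` is worth calling out: it used to return a sorted list of names, and it now returns a deterministic name-to-id mapping. A self-contained sketch with made-up dataset items (the `[text, wav_path, speaker_name]` layout is an assumption implied by `item[2]`):

```python
# hypothetical dataset items; only item[2] (the speaker name) matters here
items = [
    ["hello there", "wavs/a.wav", "spk2"],
    ["general kenobi", "wavs/b.wav", "spk1"],
    ["you are bold", "wavs/c.wav", "spk1"],
]

# same logic as the patched method: sort the unique names, then enumerate
speakers = sorted({item[2] for item in items})
speaker_ids = {name: i for i, name in enumerate(speakers)}

assert speaker_ids == {"spk1": 0, "spk2": 1}
num_speakers = len(speaker_ids)  # 2
```

Sorting before enumerating keeps the ids stable across runs as long as the speaker set does not change.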
diff --git a/TTS/tts/utils/text/cleaners.py b/TTS/tts/utils/text/cleaners.py
index 3d2caa97..4b041ed8 100644
--- a/TTS/tts/utils/text/cleaners.py
+++ b/TTS/tts/utils/text/cleaners.py
@@ -65,7 +65,7 @@ def basic_cleaners(text):

 def transliteration_cleaners(text):
     """Pipeline for non-English text that transliterates to ASCII."""
-    text = convert_to_ascii(text)
+    # text = convert_to_ascii(text)
     text = lowercase(text)
     text = collapse_whitespace(text)
     return text
@@ -89,7 +89,7 @@ def basic_turkish_cleaners(text):

 def english_cleaners(text):
     """Pipeline for English text, including number and abbreviation expansion."""
-    text = convert_to_ascii(text)
+    # text = convert_to_ascii(text)
     text = lowercase(text)
     text = expand_time_english(text)
     text = expand_numbers(text)
@@ -129,7 +129,7 @@ def chinese_mandarin_cleaners(text: str) -> str:
 def phoneme_cleaners(text):
     """Pipeline for phonemes mode, including number and abbreviation expansion."""
     text = expand_numbers(text)
-    text = convert_to_ascii(text)
+    # text = convert_to_ascii(text)
     text = expand_abbreviations(text)
     text = replace_symbols(text)
     text = remove_aux_symbols(text)
diff --git a/TTS/utils/tensorboard_logger.py b/TTS/utils/tensorboard_logger.py
index 3874a42b..657deb5b 100644
--- a/TTS/utils/tensorboard_logger.py
+++ b/TTS/utils/tensorboard_logger.py
@@ -39,7 +39,7 @@ class TensorboardLogger(object):
         except RuntimeError:
             traceback.print_exc()

-    def tb_train_iter_stats(self, step, stats):
+    def tb_train_step_stats(self, step, stats):
         self.dict_to_tb_scalar(f"{self.model_name}_TrainIterStats", stats, step)

     def tb_train_epoch_stats(self, step, stats):
diff --git a/tests/vocoder_tests/test_melgan_train.py b/tests/vocoder_tests/test_melgan_train.py
index 3ff65b5a..e3004db7 100644
--- a/tests/vocoder_tests/test_melgan_train.py
+++ b/tests/vocoder_tests/test_melgan_train.py
@@ -21,6 +21,7 @@ config = MelganConfig(
     print_step=1,
     print_eval=True,
+    discriminator_model_params={"base_channels": 16, "max_channels": 256, "downsample_factors": [4, 4, 4]},
     data_path="tests/data/ljspeech",
     output_path=output_path,
 )
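Note: for reference, here is what disabling `convert_to_ascii` in the cleaner pipelines above changes in practice. This sketch assumes the helper wraps `unidecode`, as elsewhere in the module; with the call commented out, accented characters now survive to the phonemizer instead of being flattened to ASCII:

```python
from unidecode import unidecode  # assumed backend of convert_to_ascii


def convert_to_ascii(text):
    return unidecode(text)


print(convert_to_ascii("café żółć"))  # -> "cafe zolc": the diacritics are destroyed
# With the call disabled, "café żółć" reaches the phonemizer unchanged,
# so non-English graphemes keep their pronunciation cues.
```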