diff --git a/TTS/model.py b/TTS/model.py
index 532d05a6..a7c64dde 100644
--- a/TTS/model.py
+++ b/TTS/model.py
@@ -22,10 +22,13 @@ class BaseModel(nn.Module, ABC):
 
     def __init__(self, config: Coqpit):
         super().__init__()
-        self._set_model_args(config)
 
-    def _set_model_args(self, config: Coqpit):
-        """Set model arguments from the config. Override this."""
+    @staticmethod
+    def init_from_config(config: Coqpit):
+        """Init the model from the given config.
+
+        Override this depending on your model.
+        """
         pass
 
     @abstractmethod
diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index e52cd765..45cae79e 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -15,7 +15,7 @@ from TTS.tts.datasets.dataset import TTSDataset
 from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
 from TTS.tts.utils.synthesis import synthesis
-from TTS.tts.utils.text import make_symbols
+from TTS.tts.utils.text.symbols import Graphemes, make_symbols
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 
 # pylint: skip-file
@@ -34,8 +34,20 @@ class BaseTTS(BaseModel):
         - 1D tensors `batch x 1`
     """
 
+    def __init__(self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None):
+        super().__init__(config)
+        self.config = config
+        self.ap = ap
+        self.tokenizer = tokenizer
+        self.speaker_manager = speaker_manager
+        self._set_model_args(config)
+
     def _set_model_args(self, config: Coqpit):
-        """Setup model args based on the config type.
+        """Set up model args based on the config type (`ModelConfig` or `ModelArgs`).
+
+        `ModelArgs` has all the fields required to initialize the model architecture.
+
+        `ModelConfig` has all the fields required for training and inference, and contains `ModelArgs`.
 
         If the config is for training with a name like "*Config", then the model args are embeded in the
         config.model_args
@@ -44,8 +56,8 @@ class BaseTTS(BaseModel):
         """
         # don't use isintance not to import recursively
         if "Config" in config.__class__.__name__:
+            num_chars = self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars
             if "characters" in config:
-                _, self.config, num_chars = self.get_characters(config)
                 self.config.num_chars = num_chars
                 if hasattr(self.config, "model_args"):
                     config.model_args.num_chars = num_chars
@@ -58,18 +70,21 @@ class BaseTTS(BaseModel):
         else:
             raise ValueError("config must be either a *Config or *Args")
 
-    @staticmethod
-    def get_characters(config: Coqpit) -> str:
-        # TODO: implement CharacterProcessor
-        if config.characters is not None:
-            symbols, phonemes = make_symbols(**config.characters)
-        else:
-            from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols
+    # @staticmethod
+    # def get_characters(config: Coqpit) -> str:
+    #     # TODO: implement CharacterProcessor
+    #     if config.characters is not None:
+    #         symbols, phonemes = make_symbols(**config.characters)
+    #     else:
+    #         from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols
 
-            config.characters = CharactersConfig(**parse_symbols())
-        model_characters = phonemes if config.use_phonemes else symbols
-        num_chars = len(model_characters) + getattr(config, "add_blank", False)
-        return model_characters, config, num_chars
+    #     if config.use_phonemes:
+
+    #         config.characters = Graphemes()
+
+    #     model_characters = phonemes if config.use_phonemes else symbols
+    #     num_chars = len(model_characters) + getattr(config, "add_blank", False)
+    #     return model_characters, config, num_chars
 
     def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
         return get_speaker_manager(config, restore_path, data, out_path)
@@ -247,8 +262,6 @@ class BaseTTS(BaseModel):
         if is_eval and not config.run_eval:
             loader = None
         else:
-            ap = assets["audio_processor"]
-
             # setup multi-speaker attributes
             if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
                 if hasattr(config, "model_args"):
@@ -279,28 +292,21 @@ class BaseTTS(BaseModel):
             # init dataloader
             dataset = TTSDataset(
                 outputs_per_step=config.r if "r" in config else 1,
-                text_cleaner=config.text_cleaner,
                 compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
                 compute_f0=config.get("compute_f0", False),
                 f0_cache_path=config.get("f0_cache_path", None),
                 meta_data=data_items,
-                ap=ap,
-                characters=config.characters,
-                custom_symbols=custom_symbols,
-                add_blank=config["add_blank"],
+                ap=self.ap,
                 return_wav=config.return_wav if "return_wav" in config else False,
                 batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
                 min_seq_len=config.min_seq_len,
                 max_seq_len=config.max_seq_len,
                 phoneme_cache_path=config.phoneme_cache_path,
-                use_phonemes=config.use_phonemes,
-                phoneme_language=config.phoneme_language,
-                enable_eos_bos=config.enable_eos_bos_chars,
                 use_noise_augment=False if is_eval else config.use_noise_augment,
                 verbose=verbose,
                 speaker_id_mapping=speaker_id_mapping,
-                d_vector_mapping=d_vector_mapping,
-                language_id_mapping=language_id_mapping,
+                d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
+                tokenizer=self.tokenizer,
             )
 
             # pre-compute phonemes
@@ -332,7 +338,7 @@ class BaseTTS(BaseModel):
             if config.compute_f0 and rank in [None, 0]:
                 if not os.path.exists(config.f0_cache_path):
                     dataset.pitch_extractor.compute_pitch(
-                        ap, config.get("f0_cache_path", None), config.num_loader_workers
+                        self.ap, config.get("f0_cache_path", None), config.num_loader_workers
                     )
 
             # halt DDP processes for the main process to finish computing the F0 cache
@@ -404,6 +410,7 @@ class BaseTTS(BaseModel):
             Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
         """
         ap = assets["audio_processor"]
+        tokenizer = assets["tokenizer"]
         print(" | > Synthesizing test sentences.")
         test_audios = {}
         test_figures = {}
@@ -416,6 +423,7 @@ class BaseTTS(BaseModel):
                 self.config,
                 "cuda" in str(next(self.parameters()).device),
                 ap,
+                tokenizer,
                 speaker_id=aux_inputs["speaker_id"],
                 d_vector=aux_inputs["d_vector"],
                 style_wav=aux_inputs["style_wav"],
diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py
index 8f3b3804..907f3846 100644
--- a/TTS/tts/models/glow_tts.py
+++ b/TTS/tts/models/glow_tts.py
@@ -46,11 +46,9 @@ class GlowTTS(BaseTTS):
     """
 
-    def __init__(self, config: GlowTTSConfig, speaker_manager: SpeakerManager = None):
+    def __init__(self, config: GlowTTSConfig, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None):
 
-        super().__init__(config)
-
-        self.speaker_manager = speaker_manager
+        super().__init__(config, ap, tokenizer, speaker_manager)
 
         # pass all config fields to `self`
         # for fewer code change
@@ -58,7 +56,7 @@ class GlowTTS(BaseTTS):
         for key in config:
             setattr(self, key, config[key])
 
-        _, self.config, self.num_chars = self.get_characters(config)
+        self.num_chars = self.tokenizer.characters.num_chars
         self.decoder_output_dim = config.out_channels
 
         # init multi-speaker layers if necessary
@@ -448,7 +446,6 @@ class GlowTTS(BaseTTS):
         Returns:
             Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
         """
-        ap = assets["audio_processor"]
         print(" | > Synthesizing test sentences.")
         test_audios = {}
         test_figures = {}
@@ -463,7 +460,8 @@ class GlowTTS(BaseTTS):
                 sen,
                 self.config,
                 "cuda" in str(next(self.parameters()).device),
-                ap,
+                self.ap,
+                self.tokenizer,
                 speaker_id=aux_inputs["speaker_id"],
                 d_vector=aux_inputs["d_vector"],
                 style_wav=aux_inputs["style_wav"],
diff --git a/recipes/ljspeech/glow_tts/train_glowtts.py b/recipes/ljspeech/glow_tts/train_glowtts.py
index 7bd9ea19..fe4a9d9b 100644
--- a/recipes/ljspeech/glow_tts/train_glowtts.py
+++ b/recipes/ljspeech/glow_tts/train_glowtts.py
@@ -11,6 +11,7 @@ from TTS.tts.configs.glow_tts_config import GlowTTSConfig
 from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.models.glow_tts import GlowTTS
+from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.utils.audio import AudioProcessor
 
 # we use the same path as this script as our training folder.
@@ -47,7 +48,11 @@ config = GlowTTSConfig(
 # INITIALIZE THE AUDIO PROCESSOR
 # Audio processor is used for feature extraction and audio I/O.
 # It mainly serves to the dataloader and the training loggers.
-ap = AudioProcessor(**config.audio.to_dict())
+ap = AudioProcessor.init_from_config(config)
+
+# INITIALIZE THE TOKENIZER
+# Tokenizer is used to convert text to sequences of token IDs.
+tokenizer = TTSTokenizer.init_from_config(config)
 
 # LOAD DATA SAMPLES
 # Each sample is a list of ```[text, audio_file_path, speaker_name]```
@@ -60,7 +65,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
 # Models take a config object and a speaker manager as input
 # Config defines the details of the model like the number of layers, the size of the embedding, etc.
 # Speaker manager is used by multi-speaker models.
-model = GlowTTS(config, speaker_manager=None)
+model = GlowTTS(config, ap, tokenizer, speaker_manager=None)
 
 # INITIALIZE THE TRAINER
 # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
@@ -71,8 +76,7 @@ trainer = Trainer(
     TrainingArgs(),
     config,
     output_path,
     model=model,
     train_samples=train_samples,
-    eval_samples=eval_samples,
-    training_assets={"audio_processor": ap},  # assets are objetcs used by the models but not class members.
+    eval_samples=eval_samples
 )
 
 # AND... 3,2,1... 🚀
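For reference, a minimal sketch of the initialization flow after this change, using only names that appear in the diff (the `GlowTTSConfig` dataset and training fields are elided, and `TTSTokenizer.init_from_config` is assumed to return a ready tokenizer, as in the recipe above). The audio processor and tokenizer are built once from the config and passed into the model, which keeps them as class members, so the Trainer no longer needs the `training_assets` dict:

```python
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.models.glow_tts import GlowTTS
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor

# build the config as in the recipe (dataset/training fields elided here)
config = GlowTTSConfig()

# both helpers are now derived from the single config object ...
ap = AudioProcessor.init_from_config(config)
tokenizer = TTSTokenizer.init_from_config(config)

# ... and handed to the model, which stores them as `self.ap` and
# `self.tokenizer` instead of reading them from the `assets` dict
model = GlowTTS(config, ap, tokenizer, speaker_manager=None)

# per this diff, `num_chars` now comes from the tokenizer's character
# set rather than from the removed `get_characters()` helper
assert model.num_chars == tokenizer.characters.num_chars
```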