diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 8bb7f02e..01f4a1de 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -14,6 +14,7 @@ from TTS.tts.datasets.dataset import TTSDataset from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.text.symbols import Graphemes, make_symbols from TTS.tts.utils.visual import plot_alignment, plot_spectrogram # pylint: skip-file @@ -32,9 +33,7 @@ class BaseTTS(BaseModel): - 1D tensors `batch x 1` """ - def __init__( - self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None - ): + def __init__(self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None): super().__init__(config) self.config = config self.ap = ap @@ -292,7 +291,7 @@ class BaseTTS(BaseModel): verbose=verbose, speaker_id_mapping=speaker_id_mapping, d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, - tokenizer=self.tokenizer, + tokenizer=self.tokenizer ) # pre-compute phonemes diff --git a/recipes/ljspeech/glow_tts/train_glowtts.py b/recipes/ljspeech/glow_tts/train_glowtts.py index 4762a77a..fe4a9d9b 100644 --- a/recipes/ljspeech/glow_tts/train_glowtts.py +++ b/recipes/ljspeech/glow_tts/train_glowtts.py @@ -71,7 +71,12 @@ model = GlowTTS(config, ap, tokenizer, speaker_manager=None) # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # distributed training, etc. trainer = Trainer( - TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples + TrainingArgs(), + config, + output_path, + model=model, + train_samples=train_samples, + eval_samples=eval_samples ) # AND... 3,2,1... 🚀