diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index ae24a99e..23be6177 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -19,6 +19,7 @@ from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, se
 from TTS.tts.utils.languages import LanguageManager
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.synthesis import synthesis
+from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
@@ -280,19 +281,15 @@ class Vits(BaseTTS):
         language_manager: LanguageManager = None,
     ):
 
-        super().__init__(config)
+        super().__init__(config, ap, tokenizer, speaker_manager)
 
         self.END2END = True
         self.speaker_manager = speaker_manager
         self.language_manager = language_manager
         if config.__class__.__name__ == "VitsConfig":
             # loading from VitsConfig
-            if "num_chars" not in config:
-                _, self.config, num_chars = self.get_characters(config)
-                config.model_args.num_chars = num_chars
-            else:
-                self.config = config
-                config.model_args.num_chars = config.num_chars
+            self.num_chars = self.tokenizer.characters.num_chars
+            self.config = config
             args = self.config.model_args
         elif isinstance(config, VitsArgs):
             # loading from VitsArgs
@@ -1039,3 +1036,25 @@ class Vits(BaseTTS):
         if eval:
             self.eval()
             assert not self.training
+
+    @staticmethod
+    def init_from_config(config: "Coqpit"):
+        """Build a Vits model, with its audio processor, tokenizer and speaker manager, from ``config``."""
+
+        # pick the character set by config.use_phonemes; its size is written to config.num_chars below
+        if config.use_phonemes:
+            from TTS.tts.utils.text.characters import IPAPhonemes
+
+            characters = IPAPhonemes().init_from_config(config)
+        else:
+            from TTS.tts.utils.text.characters import Graphemes
+
+            characters = Graphemes().init_from_config(config)
+        config.num_chars = characters.num_chars  # NOTE: mutates the caller's config in place
+
+        from TTS.utils.audio import AudioProcessor  # local import; presumably avoids an import cycle (TODO confirm)
+
+        ap = AudioProcessor.init_from_config(config)
+        tokenizer = TTSTokenizer.init_from_config(config)
+        speaker_manager = SpeakerManager.init_from_config(config)
+        return Vits(config, ap, tokenizer, speaker_manager)  # language_manager is not passed and keeps its default