diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 239e2057..928a2a46 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -6,7 +6,7 @@ import pysbd import torch from TTS.tts.utils.generic_utils import setup_model -from TTS.tts.utils.speakers import load_speaker_mapping +from TTS.tts.utils.speakers import SpeakerManager # pylint: disable=unused-wildcard-import # pylint: disable=wildcard-import @@ -49,7 +49,10 @@ class Synthesizer(object): self.vocoder_checkpoint = vocoder_checkpoint self.vocoder_config = vocoder_config self.use_cuda = use_cuda + + self.tts_model = None self.vocoder_model = None + self.speaker_manager = None self.num_speakers = 0 self.tts_speakers = {} self.speaker_embedding_dim = 0 @@ -70,25 +73,10 @@ class Synthesizer(object): def _load_speakers(self, speaker_file: str) -> None: print("Loading speakers ...") - self.tts_speakers = load_speaker_mapping(speaker_file) - self.num_speakers = len(self.tts_speakers) - self.speaker_embedding_dim = len(self.tts_speakers[list(self.tts_speakers.keys())[0]]["embedding"]) - - def _load_speaker_embedding(self, speaker_json_key: str = ""): - - speaker_embedding = None - - if not speaker_json_key: - raise ValueError(" [!] While 'use_external_speaker_embedding_file', you must pass a 'speaker_json_key'") - - if speaker_json_key != "": - assert self.tts_speakers - assert ( - speaker_json_key in self.tts_speakers - ), f" [!] speaker_json_key is not in self.tts_speakers keys : '{speaker_json_key}'" - speaker_embedding = self.tts_speakers[speaker_json_key]["embedding"] - - return speaker_embedding + self.speaker_manager = SpeakerManager() + self.speaker_manager.load_x_vectors_file(self.tts_config.get("external_speaker_embedding_file", speaker_file)) + self.num_speakers = self.speaker_manager.num_speakers + self.speaker_embedding_dim = self.speaker_manager.x_vector_dim def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None: # pylint: disable=global-statement @@ -108,7 +96,7 @@ class Synthesizer(object): self.input_size = len(symbols) if self.tts_config.use_speaker_embedding is True: - self._load_speakers(self.tts_config.get("external_speaker_embedding_file", self.tts_speakers_file)) + self._load_speakers(self.tts_speakers_file) self.tts_model = setup_model( self.input_size, @@ -116,7 +104,6 @@ class Synthesizer(object): c=self.tts_config, speaker_embedding_dim=self.speaker_embedding_dim, ) - self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True) if use_cuda: self.tts_model.cuda() @@ -136,13 +123,17 @@ class Synthesizer(object): wav = np.array(wav) self.ap.save_wav(wav, path, self.output_sample_rate) - def tts(self, text: str, speaker_json_key: str = "", style_wav=None) -> List[int]: + def tts(self, text: str, speaker_idx: str = "", style_wav=None) -> List[int]: start_time = time.time() wavs = [] sens = self._split_into_sentences(text) print(" > Text splitted to sentences.") print(sens) - speaker_embedding = self._load_speaker_embedding(speaker_json_key) + + if speaker_idx and isinstance(speaker_idx, str): + speaker_embedding = self.speaker_manager.get_x_vectors_by_speaker(speaker_idx)[0] + else: + speaker_embedding = None use_gl = self.vocoder_model is None