diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 9f4d70c8..d4044c7e 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -72,35 +72,21 @@ class BaseTTS(BaseModel): def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager: return get_speaker_manager(config, restore_path, data, out_path) - def init_multispeaker(self, config: Coqpit, data: List = None): - """Initialize a speaker embedding layer if needen and define expected embedding channel size for defining - `in_channels` size of the connected layers. - - This implementation yields 3 possible outcomes: - - 1. If `config.use_speaker_embedding` and `config.use_d_vector_file are False, do nothing. - 2. If `config.use_d_vector_file` is True, set expected embedding channel size to `config.d_vector_dim` or 512. - 3. If `config.use_speaker_embedding`, initialize a speaker embedding layer with channel size of - `config.d_vector_dim` or 512. - - You can override this function for new models.0 + def init_multispeaker(self, config: Coqpit): + """Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding + vector dimension in the network. If model uses d-vectors, then it only sets the expected dimension. Args: config (Coqpit): Model configuration. - data (List, optional): Dataset items to infer number of speakers. Defaults to None. """ # init speaker manager - self.speaker_manager = get_speaker_manager(config, data=data) + if self.speaker_manager is None: + raise ValueError(" > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model.") + + print(f" > Number of speakers : {len(self.speaker_manager.speaker_ids)}") # set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager - if data is not None or self.speaker_manager.speaker_ids: - self.num_speakers = self.speaker_manager.num_speakers - else: - self.num_speakers = ( - config.num_speakers - if "num_speakers" in config and config.num_speakers != 0 - else self.speaker_manager.num_speakers - ) + self.num_speakers = self.speaker_manager.num_speakers # set ultimate speaker embedding size if config.use_speaker_embedding or config.use_d_vector_file: @@ -109,6 +95,7 @@ class BaseTTS(BaseModel): ) # init speaker embedding layer if config.use_speaker_embedding and not config.use_d_vector_file: + print(" > Init speaker_embedding layer.") self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) self.speaker_embedding.weight.data.normal_(0, 0.3) @@ -345,3 +332,17 @@ class BaseTTS(BaseModel): outputs_dict["outputs"]["alignments"], output_fig=False ) return test_figures, test_audios + + def on_init_start(self, trainer): + """Save the speaker.json at the beginning of the training. And update the config.json with the + speakers.json file path.""" + if self.speaker_manager is not None: + output_path = os.path.join(trainer.output_path, "speakers.json") + self.speaker_manager.save_speaker_ids_to_file(output_path) + trainer.config.speakers_file = output_path + # some models don't have `model_args` set + if hasattr(trainer.config, "model_args"): + trainer.config.model_args.speakers_file = output_path + trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) + print(f" > `speakers.json` is saved to {output_path}.") + print(" > `speakers_file` is updated in the config.json.")