Update multi-speaker init in BaseTTS

pull/887/head
Eren Gölge 2021-10-18 08:54:41 +00:00
parent a0a5d580e9
commit 127571423c
1 changed file with 23 additions and 22 deletions


@@ -72,35 +72,21 @@ class BaseTTS(BaseModel):
     def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
         return get_speaker_manager(config, restore_path, data, out_path)
 
-    def init_multispeaker(self, config: Coqpit, data: List = None):
-        """Initialize a speaker embedding layer if needed and define the expected embedding channel size for defining
-        `in_channels` size of the connected layers.
-
-        This implementation yields 3 possible outcomes:
-
-        1. If `config.use_speaker_embedding` and `config.use_d_vector_file` are False, do nothing.
-        2. If `config.use_d_vector_file` is True, set the expected embedding channel size to `config.d_vector_dim` or 512.
-        3. If `config.use_speaker_embedding`, initialize a speaker embedding layer with channel size of
-        `config.d_vector_dim` or 512.
-
-        You can override this function for new models.
+    def init_multispeaker(self, config: Coqpit):
+        """Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding
+        vector dimension in the network. If the model uses d-vectors, then it only sets the expected dimension.
 
         Args:
             config (Coqpit): Model configuration.
-            data (List, optional): Dataset items to infer number of speakers. Defaults to None.
         """
         # init speaker manager
-        self.speaker_manager = get_speaker_manager(config, data=data)
+        if self.speaker_manager is None:
+            raise ValueError(" > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model.")
+        print(f" > Number of speakers : {len(self.speaker_manager.speaker_ids)}")
 
-        # set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager
-        if data is not None or self.speaker_manager.speaker_ids:
-            self.num_speakers = self.speaker_manager.num_speakers
-        else:
-            self.num_speakers = (
-                config.num_speakers
-                if "num_speakers" in config and config.num_speakers != 0
-                else self.speaker_manager.num_speakers
-            )
+        self.num_speakers = self.speaker_manager.num_speakers
 
         # set ultimate speaker embedding size
         if config.use_speaker_embedding or config.use_d_vector_file:
@@ -109,6 +95,7 @@ class BaseTTS(BaseModel):
             )
         # init speaker embedding layer
         if config.use_speaker_embedding and not config.use_d_vector_file:
+            print(" > Init speaker_embedding layer.")
             self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
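
For readers skimming the diff: after this change, `init_multispeaker` no longer builds a `SpeakerManager` from `data`; it expects the manager to be attached to the model beforehand and only reads the speaker count and embedding size from it. Below is a minimal, self-contained sketch of that contract using toy stand-ins (`ToySpeakerManager` and `ToyModel` are illustrative only, not the real TTS classes):

import torch.nn as nn


class ToySpeakerManager:
    """Toy stand-in for SpeakerManager: only tracks speaker ids."""

    def __init__(self, speaker_ids):
        self.speaker_ids = speaker_ids

    @property
    def num_speakers(self):
        return len(self.speaker_ids)


class ToyModel:
    """Toy stand-in for a BaseTTS subclass following the new contract."""

    def __init__(self, speaker_manager=None):
        # the caller must attach the manager; the model no longer creates one itself
        self.speaker_manager = speaker_manager

    def init_multispeaker(self, use_speaker_embedding=True, d_vector_dim=None):
        if self.speaker_manager is None:
            raise ValueError("SpeakerManager must be provided before initialization.")
        self.num_speakers = self.speaker_manager.num_speakers
        self.embedded_speaker_dim = d_vector_dim or 512
        if use_speaker_embedding:
            self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
            self.speaker_embedding.weight.data.normal_(0, 0.3)


manager = ToySpeakerManager(["spk_a", "spk_b", "spk_c"])
model = ToyModel(speaker_manager=manager)
model.init_multispeaker()
print(model.speaker_embedding)  # Embedding(3, 512)

The failure mode also changes in the same spirit as the ValueError above: forgetting to attach a manager now fails fast instead of the model quietly constructing one from the dataset.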
@@ -345,3 +332,17 @@ class BaseTTS(BaseModel):
             outputs_dict["outputs"]["alignments"], output_fig=False
         )
         return test_figures, test_audios
+
+    def on_init_start(self, trainer):
+        """Save `speakers.json` at the beginning of the training and update the `config.json` with the
+        `speakers.json` file path."""
+        if self.speaker_manager is not None:
+            output_path = os.path.join(trainer.output_path, "speakers.json")
+            self.speaker_manager.save_speaker_ids_to_file(output_path)
+            trainer.config.speakers_file = output_path
+            # some models don't have `model_args` set
+            if hasattr(trainer.config, "model_args"):
+                trainer.config.model_args.speakers_file = output_path
+            trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
+            print(f" > `speakers.json` is saved to {output_path}.")
+            print(" > `speakers_file` is updated in the config.json.")