use get_speaker_manager in Trainer and save speakers.json file when

needed
pull/506/head
Eren Gölge 2021-06-05 11:46:53 +02:00
parent d6b2b6add6
commit 2c38ef8441
2 changed files with 14 additions and 24 deletions

View File

@ -21,7 +21,7 @@ from TTS.tts.datasets import TTSDataset, load_meta_data
from TTS.tts.layers import setup_loss
from TTS.tts.models import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@ -186,25 +186,7 @@ class TrainerTTS:
def get_speaker_manager(
config: Coqpit, restore_path: str = "", out_path: str = "", data_train: List = None
) -> SpeakerManager:
speaker_manager = SpeakerManager()
if restore_path:
speakers_file = os.path.join(os.path.dirname(restore_path), "speaker.json")
if not os.path.exists(speakers_file):
print(
"WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
)
speakers_file = config.external_speaker_embedding_file
if config.use_external_speaker_embedding_file:
speaker_manager.load_d_vectors_file(speakers_file)
else:
speaker_manager.load_ids_file(speakers_file)
elif config.use_external_speaker_embedding_file and config.external_speaker_embedding_file:
speaker_manager.load_d_vectors_file(config.external_speaker_embedding_file)
else:
speaker_manager.parse_speakers_from_items(data_train)
file_path = os.path.join(out_path, "speakers.json")
speaker_manager.save_ids_file(file_path)
speaker_manager = get_speaker_manager(config, restore_path, data_train, out_path)
return speaker_manager
@staticmethod

View File

@ -34,16 +34,16 @@ def save_speaker_mapping(out_path, speaker_mapping):
json.dump(speaker_mapping, f, indent=4)
def get_speaker_manager(c, args, meta_data_train):
def get_speaker_manager(c, restore_path, meta_data_train, out_path=None):
"""Inititalize and return a `SpeakerManager` based on config values"""
speaker_manager = SpeakerManager()
if c.use_speaker_embedding:
speaker_manager.set_speaker_ids_from_data(meta_data_train)
if args.restore_path:
if restore_path:
# restoring speaker manager from a previous run.
if c.use_external_speaker_embedding_file:
# restore speaker manager with the embedding file
speakers_file = os.path.dirname(args.restore_path)
speakers_file = os.path.dirname(restore_path)
if not os.path.exists(speakers_file):
print(
"WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
@ -55,7 +55,7 @@ def get_speaker_manager(c, args, meta_data_train):
speaker_manager.load_d_vectors_file(c.external_speaker_embedding_file)
speaker_manager.set_d_vectors_from_file(speakers_file)
elif not c.use_external_speaker_embedding_file: # restor speaker manager with speaker ID file.
speakers_file = os.path.dirname(args.restore_path)
speakers_file = os.path.dirname(restore_path)
speaker_ids_from_data = speaker_manager.speaker_ids
speaker_manager.set_speaker_ids_from_file(speakers_file)
assert all(
@ -73,6 +73,14 @@ def get_speaker_manager(c, args, meta_data_train):
speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
)
)
# save file if path is defined
if out_path:
out_file_path = os.path.join(out_path, "speaker.json")
print(" > Saving `speaker.json` to {out_file_path}.")
if c.use_external_speaker_embedding_file and c.external_speaker_embedding_file:
speaker_manager.save_d_vectors_to_file(out_file_path)
else:
speaker_manager.save_speaker_ids_to_file(out_file_path)
return speaker_manager