From 3b5592abcf41fffcb0c17858167bbd9228fbd970 Mon Sep 17 00:00:00 2001 From: WeberJulian Date: Fri, 29 Oct 2021 17:09:10 +0200 Subject: [PATCH] fix test vits --- TTS/trainer.py | 2 +- TTS/tts/configs/vits_config.py | 22 +++---------------- TTS/tts/datasets/dataset.py | 3 +-- TTS/tts/models/base_tts.py | 9 ++++++-- TTS/tts/models/vits.py | 5 +---- tests/tts_tests/test_vits_d-vectors_train.py | 3 +-- .../tts_tests/test_vits_multilingual_train.py | 3 ++- .../tts_tests/test_vits_speaker_emb_train.py | 2 +- 8 files changed, 17 insertions(+), 32 deletions(-) diff --git a/TTS/trainer.py b/TTS/trainer.py index e8911ba3..665f2589 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -261,7 +261,7 @@ class Trainer: self.run_get_model(self.config, get_model) if hasattr(self.model, "init_multilingual"): - self.model.init_multilingual(self.config, self.data_train + self.data_eval) + self.model.init_multilingual(self.config, self.train_samples + self.eval_samples) config = self.config.model_args if hasattr(self.config, "model_args") else self.config # save speakers json if config.use_language_embedding and self.model.language_manager.num_languages > 1: diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py index ece414a6..a6f2210d 100644 --- a/TTS/tts/configs/vits_config.py +++ b/TTS/tts/configs/vits_config.py @@ -154,22 +154,6 @@ class VitsConfig(BaseTTSConfig): d_vector_dim: int = None def __post_init__(self): - # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there. - if self.num_speakers > 0: - self.model_args.num_speakers = self.num_speakers - - # speaker embedding settings - if self.use_speaker_embedding: - self.model_args.use_speaker_embedding = True - if self.speakers_file: - self.model_args.speakers_file = self.speakers_file - if self.speaker_embedding_channels: - self.model_args.speaker_embedding_channels = self.speaker_embedding_channels - - # d-vector settings - if self.use_d_vector_file: - self.model_args.use_d_vector_file = True - if self.d_vector_dim is not None and self.d_vector_dim > 0: - self.model_args.d_vector_dim = self.d_vector_dim - if self.d_vector_file: - self.model_args.d_vector_file = self.d_vector_file + for key in self.model_args.keys(): + if hasattr(self, key): + self[key] = self.model_args[key] diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index 38af1469..c2818897 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -404,8 +404,7 @@ class TTSDataset(Dataset): # get language ids from language names if self.language_id_mapping is not None: - language_names = [batch[idx]["language_name"] for idx in ids_sorted_decreasing] - language_ids = [self.language_id_mapping[ln] for ln in language_names] + language_ids = [self.language_id_mapping[ln] for ln in batch["language_name"]] else: language_ids = None # get pre-computed d-vectors diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 9d722222..df6c52f3 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -245,8 +245,13 @@ class BaseTTS(BaseModel): # setup multi-speaker attributes if hasattr(self, "speaker_manager") and self.speaker_manager is not None: - speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None - d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None + if hasattr(config, "model_args"): + speaker_id_mapping = self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None + d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None + config.use_d_vector_file = config.model_args.use_d_vector_file + else: + speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None + d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None else: speaker_id_mapping = None d_vector_mapping = None diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 078d4973..bc503cb5 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -376,8 +376,7 @@ class Vits(BaseTTS): data (List, optional): Dataset items to infer number of speakers. Defaults to None. """ self.embedded_speaker_dim = 0 - if hasattr(config, "model_args"): - config = config.model_args + config = config.model_args self.num_speakers = config.num_speakers @@ -1033,7 +1032,6 @@ class Vits(BaseTTS): test_audios = {} test_figures = {} test_sentences = self.config.test_sentences - for idx, s_info in enumerate(test_sentences): try: aux_inputs = self.get_aux_input_from_test_setences(s_info) @@ -1051,7 +1049,6 @@ class Vits(BaseTTS): use_griffin_lim=True, do_trim_silence=False, ).values() - test_audios["{}-audio".format(idx)] = wav test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False) except: # pylint: disable=bare-except diff --git a/tests/tts_tests/test_vits_d-vectors_train.py b/tests/tts_tests/test_vits_d-vectors_train.py index af0e0eba..213669f5 100644 --- a/tests/tts_tests/test_vits_d-vectors_train.py +++ b/tests/tts_tests/test_vits_d-vectors_train.py @@ -3,7 +3,7 @@ import os import shutil from tests import get_device_id, get_tests_output_path, run_cli -from TTS.tts.configs import VitsConfig +from TTS.tts.configs.vits_config import VitsConfig config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -33,7 +33,6 @@ config.audio.do_trim_silence = True config.audio.trim_db = 60 # active multispeaker d-vec mode -config.model_args.use_speaker_embedding = True config.model_args.use_d_vector_file = True config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json" config.model_args.d_vector_dim = 256 diff --git a/tests/tts_tests/test_vits_multilingual_train.py b/tests/tts_tests/test_vits_multilingual_train.py index 10e66b81..664de57e 100644 --- a/tests/tts_tests/test_vits_multilingual_train.py +++ b/tests/tts_tests/test_vits_multilingual_train.py @@ -3,7 +3,8 @@ import os import shutil from tests import get_device_id, get_tests_output_path, run_cli -from TTS.tts.configs import BaseDatasetConfig, VitsConfig +from TTS.tts.configs.vits_config import VitsConfig +from TTS.config.shared_configs import BaseDatasetConfig config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") diff --git a/tests/tts_tests/test_vits_speaker_emb_train.py b/tests/tts_tests/test_vits_speaker_emb_train.py index 7028a983..6cc1dabd 100644 --- a/tests/tts_tests/test_vits_speaker_emb_train.py +++ b/tests/tts_tests/test_vits_speaker_emb_train.py @@ -3,7 +3,7 @@ import os import shutil from tests import get_device_id, get_tests_output_path, run_cli -from TTS.tts.configs import VitsConfig +from TTS.tts.configs.vits_config import VitsConfig config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs")