From 546f43cb254793366f996deab33eb1cc88e915bd Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Fri, 29 Nov 2024 16:27:14 +0100 Subject: [PATCH] refactor: only use keyword args in Synthesizer --- TTS/bin/synthesize.py | 24 +++++++++++------------ TTS/utils/synthesizer.py | 1 + tests/inference_tests/test_synthesizer.py | 2 +- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index 20e429df..454f528a 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -407,18 +407,18 @@ def main(): # load models synthesizer = Synthesizer( - tts_path, - tts_config_path, - speakers_file_path, - language_ids_file_path, - vocoder_path, - vocoder_config_path, - encoder_path, - encoder_config_path, - vc_path, - vc_config_path, - model_dir, - args.voice_dir, + tts_checkpoint=tts_path, + tts_config_path=tts_config_path, + tts_speakers_file=speakers_file_path, + tts_languages_file=language_ids_file_path, + vocoder_checkpoint=vocoder_path, + vocoder_config=vocoder_config_path, + encoder_checkpoint=encoder_path, + encoder_config=encoder_config_path, + vc_checkpoint=vc_path, + vc_config=vc_config_path, + model_dir=model_dir, + voice_dir=args.voice_dir, ).to(device) # query speaker ids of a multi-speaker model. diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index a158df60..73f596d1 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -28,6 +28,7 @@ logger = logging.getLogger(__name__) class Synthesizer(nn.Module): def __init__( self, + *, tts_checkpoint: str = "", tts_config_path: str = "", tts_speakers_file: str = "", diff --git a/tests/inference_tests/test_synthesizer.py b/tests/inference_tests/test_synthesizer.py index ce4fc751..21cc1941 100644 --- a/tests/inference_tests/test_synthesizer.py +++ b/tests/inference_tests/test_synthesizer.py @@ -23,7 +23,7 @@ class SynthesizerTest(unittest.TestCase): tts_root_path = get_tests_input_path() tts_checkpoint = os.path.join(tts_root_path, "checkpoint_10.pth") tts_config = os.path.join(tts_root_path, "dummy_model_config.json") - synthesizer = Synthesizer(tts_checkpoint, tts_config, None, None) + synthesizer = Synthesizer(tts_checkpoint=tts_checkpoint, tts_config_path=tts_config) synthesizer.tts("Better this test works!!") def test_split_into_sentences(self):