refactor: only use keyword args in Synthesizer

pull/4115/head^2
Enno Hermann 2024-11-29 16:27:14 +01:00
parent 6927e0bb89
commit 546f43cb25
3 changed files with 14 additions and 13 deletions

View File

@ -407,18 +407,18 @@ def main():
# load models
synthesizer = Synthesizer(
tts_path,
tts_config_path,
speakers_file_path,
language_ids_file_path,
vocoder_path,
vocoder_config_path,
encoder_path,
encoder_config_path,
vc_path,
vc_config_path,
model_dir,
args.voice_dir,
tts_checkpoint=tts_path,
tts_config_path=tts_config_path,
tts_speakers_file=speakers_file_path,
tts_languages_file=language_ids_file_path,
vocoder_checkpoint=vocoder_path,
vocoder_config=vocoder_config_path,
encoder_checkpoint=encoder_path,
encoder_config=encoder_config_path,
vc_checkpoint=vc_path,
vc_config=vc_config_path,
model_dir=model_dir,
voice_dir=args.voice_dir,
).to(device)
# query speaker ids of a multi-speaker model.

View File

@ -28,6 +28,7 @@ logger = logging.getLogger(__name__)
class Synthesizer(nn.Module):
def __init__(
self,
*,
tts_checkpoint: str = "",
tts_config_path: str = "",
tts_speakers_file: str = "",

View File

@ -23,7 +23,7 @@ class SynthesizerTest(unittest.TestCase):
tts_root_path = get_tests_input_path()
tts_checkpoint = os.path.join(tts_root_path, "checkpoint_10.pth")
tts_config = os.path.join(tts_root_path, "dummy_model_config.json")
synthesizer = Synthesizer(tts_checkpoint, tts_config, None, None)
synthesizer = Synthesizer(tts_checkpoint=tts_checkpoint, tts_config_path=tts_config)
synthesizer.tts("Better this test works!!")
def test_split_into_sentences(self):