From af2d36faeb535ddfea600ffffa26bf2b3ecef90a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 21 Apr 2021 13:08:25 +0200 Subject: [PATCH] update synthesize.py for multi-speaker setting --- TTS/bin/synthesize.py | 58 +++++++++++++++++++------------------------ 1 file changed, 26 insertions(+), 32 deletions(-) diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index c0ecb1ab..75a167e9 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -102,7 +102,7 @@ def main(): parser.add_argument("--vocoder_config_path", type=str, help="Path to vocoder model config file.", default=None) # args for multi-speaker synthesis - parser.add_argument("--speakers_json", type=str, help="JSON file for multi-speaker model.", default=None) + parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None) parser.add_argument( "--speaker_idx", type=str, @@ -110,7 +110,12 @@ def main(): default=None, ) parser.add_argument("--gst_style", help="Wav path file for GST stylereference.", default=None) - + parser.add_argument( + "--list_speaker_idxs", + help="List available speaker ids for the defined multi-speaker model.", + default=False, + type=str2bool, + ) # aux args parser.add_argument( "--save_spectogram", @@ -131,6 +136,7 @@ def main(): model_path = None config_path = None + speakers_file_path = None vocoder_path = None vocoder_config_path = None @@ -139,54 +145,42 @@ def main(): manager.list_models() sys.exit() - # CASE2: load pre-trained models - if args.model_name is not None: + # CASE2: load pre-trained model paths + if args.model_name is not None and not args.model_path: model_path, config_path, model_item = manager.download_model(args.model_name) args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name - if args.vocoder_name is not None: + if args.vocoder_name is not None and not args.vocoder_path: vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name) - # CASE3: load custome models + # CASE3: set custome model paths if args.model_path is not None: model_path = args.model_path config_path = args.config_path + speakers_file_path = args.speakers_file_path if args.vocoder_path is not None: vocoder_path = args.vocoder_path vocoder_config_path = args.vocoder_config_path - # RUN THE SYNTHESIS # load models - synthesizer = Synthesizer(model_path, config_path, vocoder_path, vocoder_config_path, args.use_cuda) + synthesizer = Synthesizer( + model_path, config_path, speakers_file_path, vocoder_path, vocoder_config_path, args.use_cuda + ) + # query speaker ids of a multi-speaker model. + if args.list_speaker_idxs: + print( + " > Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model." + ) + print(synthesizer.speaker_manager.speaker_ids) + return + + # RUN THE SYNTHESIS print(" > Text: {}".format(args.text)) - # # handle multi-speaker setting - # if not model_config.use_external_speaker_embedding_file and args.speaker_idx is not None: - # if args.speaker_idx.isdigit(): - # args.speaker_idx = int(args.speaker_idx) - # else: - # args.speaker_idx = None - # else: - # args.speaker_idx = None - - # if args.gst_style is None: - # if 'gst' in model_config.keys() and model_config.gst['gst_style_input'] is not None: - # gst_style = model_config.gst['gst_style_input'] - # else: - # gst_style = None - # else: - # # check if gst_style string is a dict, if is dict convert else use string - # try: - # gst_style = json.loads(args.gst_style) - # if max(map(int, gst_style.keys())) >= model_config.gst['gst_style_tokens']: - # raise RuntimeError("The highest value of the gst_style dictionary key must be less than the number of GST Tokens, \n Highest dictionary key value: {} \n Number of GST tokens: {}".format(max(map(int, gst_style.keys())), model_config.gst['gst_style_tokens'])) - # except ValueError: - # gst_style = args.gst_style - # kick it - wav = synthesizer.tts(args.text) + wav = synthesizer.tts(args.text, args.speaker_idx) # save the results print(" > Saving output to {}".format(args.out_path))