diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index 3e69e1ad..fb2e41b4 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -254,7 +254,7 @@ def main(): print(" > Text: {}".format(args.text)) # kick it - wav = synthesizer.tts(args.text, args.speaker_idx, args.speaker_wav) + wav = synthesizer.tts(args.text, args.speaker_idx, args.speaker_wav, args.gst_style) # save the results print(" > Saving output to {}".format(args.out_path)) diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 5185139e..578c26c0 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -250,11 +250,11 @@ def synthesis( # GST processing style_mel = None custom_symbols = None - if CONFIG.has("gst") and CONFIG.gst and style_wav is not None: - if isinstance(style_wav, dict): - style_mel = style_wav - else: - style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda) + if style_wav: + style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda) + elif CONFIG.has("gst") and CONFIG.gst and not style_wav: + if CONFIG.gst.gst_style_input_weights: + style_mel = CONFIG.gst.gst_style_input_weights if hasattr(model, "make_symbols"): custom_symbols = model.make_symbols(CONFIG) # preprocess the given text