diff --git a/config.json b/config.json index 38d865f9..2a171ad1 100644 --- a/config.json +++ b/config.json @@ -79,6 +79,7 @@ "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages "text_cleaner": "phoneme_cleaners", "use_speaker_embedding": false, // use speaker embedding to enable multi-speaker learning. - "style_wav_for_test": null // path to style wav file to be used in TacotronGST inference. + "style_wav_for_test": null, // path to style wav file to be used in TacotronGST inference. + "use_gst": false // TACOTRON ONLY: use global style tokens } diff --git a/utils/generic_utils.py b/utils/generic_utils.py index 1053d221..bfa72a35 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -253,13 +253,14 @@ def setup_model(num_chars, num_speakers, c): print(" > Using model: {}".format(c.model)) MyModel = importlib.import_module('TTS.models.' + c.model.lower()) MyModel = getattr(MyModel, c.model) - if c.model.lower() in ["tacotron", "tacotrongst"]: + if c.model.lower() in "tacotron": model = MyModel( num_chars=num_chars, num_speakers=num_speakers, r=c.r, linear_dim=1025, mel_dim=80, + gst=c.use_gst, memory_size=c.memory_size, attn_win=c.windowing, attn_norm=c.attention_norm,