diff --git a/TTS/tts/models/base_tacotron.py b/TTS/tts/models/base_tacotron.py index a537afd6..8cfd750d 100644 --- a/TTS/tts/models/base_tacotron.py +++ b/TTS/tts/models/base_tacotron.py @@ -115,11 +115,16 @@ class BaseTacotron(BaseTTS): ): # pylint: disable=unused-argument, redefined-builtin state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) self.load_state_dict(state["model"]) + # TODO: set r in run-time by taking it from the new config if "r" in state: + # set r from the state (for compatibility with older checkpoints) self.decoder.set_r(state["r"]) - else: - # set the reduction rate from the config values embedded in the checkpoint + elif "config" in state: + # set r from config used at training time (for inference) self.decoder.set_r(state["config"]["r"]) + else: + # set r from the new config (for new-models) + self.decoder.set_r(config.r) if eval: self.eval() print(f" > Model's reduction rate `r` is set to: {self.decoder.r}")