mirror of https://github.com/coqui-ai/TTS.git
Update tacotron `r` init
parent
ab37fa9c39
commit
d5f256b34c
|
@ -115,11 +115,16 @@ class BaseTacotron(BaseTTS):
|
|||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
# TODO: set r in run-time by taking it from the new config
|
||||
if "r" in state:
|
||||
# set r from the state (for compatibility with older checkpoints)
|
||||
self.decoder.set_r(state["r"])
|
||||
else:
|
||||
# set the reduction rate from the config values embedded in the checkpoint
|
||||
elif "config" in state:
|
||||
# set r from config used at training time (for inference)
|
||||
self.decoder.set_r(state["config"]["r"])
|
||||
else:
|
||||
# set r from the new config (for new-models)
|
||||
self.decoder.set_r(config.r)
|
||||
if eval:
|
||||
self.eval()
|
||||
print(f" > Model's reduction rate `r` is set to: {self.decoder.r}")
|
||||
|
|
Loading…
Reference in New Issue