Update tacotron `r` init

pull/800/head
Eren Gölge 2021-09-10 17:26:23 +00:00
parent ab37fa9c39
commit d5f256b34c
1 changed files with 7 additions and 2 deletions

View File

@ -115,11 +115,16 @@ class BaseTacotron(BaseTTS):
): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
# TODO: set r in run-time by taking it from the new config
if "r" in state:
# set r from the state (for compatibility with older checkpoints)
self.decoder.set_r(state["r"])
else:
# set the reduction rate from the config values embedded in the checkpoint
elif "config" in state:
# set r from config used at training time (for inference)
self.decoder.set_r(state["config"]["r"])
else:
# set r from the new config (for new-models)
self.decoder.set_r(config.r)
if eval:
self.eval()
print(f" > Model's reduction rate `r` is set to: {self.decoder.r}")