mirror of https://github.com/coqui-ai/TTS.git
Fix restoring best_loss

Keep the default value if the model checkpoint has no `model_loss`.

pull/725/head
parent c8bbcdfd07
commit c5d1dd9d1b
@@ -78,7 +78,7 @@ class TrainingArgs(Coqpit):
     best_path: str = field(
         default="",
         metadata={
-            "help": "Best model file to be used for extracting best loss. If not specified, the latest best model in continue path is used"
+            "help": "Best model file to be used for extracting the best loss. If not specified, the latest best model in continue path is used"
         },
     )
     config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
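For context on the field(default=..., metadata={"help": ...}) pattern above: TrainingArgs is a Coqpit config, which builds on Python dataclasses. Below is a minimal standalone sketch using only the standard library (illustrative, not the Coqpit or TrainingArgs API itself) showing how such help metadata can be read back, e.g. to build CLI help text.

# Minimal sketch with plain dataclasses; illustrative only, not TrainingArgs/Coqpit.
from dataclasses import dataclass, field, fields


@dataclass
class ExampleArgs:
    best_path: str = field(
        default="",
        metadata={"help": "Best model file to be used for extracting the best loss."},
    )


for f in fields(ExampleArgs):
    # Each field carries its default and help string, which an argparse layer can surface.
    print(f.name, repr(f.default), "-", f.metadata.get("help", ""))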
@@ -149,6 +149,7 @@ class Trainer:
         >>> trainer.fit()

     TODO:
         - Wrap model for not calling .module in DDP.
         - Accumulate gradients b/w batches.
         - Deepspeed integration
         - Profiler integration.
@@ -331,7 +332,7 @@ class Trainer:
             print(" > Restoring Optimizer...")
             optimizer = _restore_list_objs(checkpoint["optimizer"], optimizer)
             if "scaler" in checkpoint and self.use_amp_scaler and checkpoint["scaler"]:
-                print(" > Restoring AMP Scaler...")
+                print(" > Restoring Scaler...")
                 scaler = _restore_list_objs(checkpoint["scaler"], scaler)
         except (KeyError, RuntimeError):
             print(" > Partial model initialization...")
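The hunk above shows the Trainer's guarded restore of optimizer and AMP scaler state inside a try/except. As a rough standalone sketch of the same pattern with plain PyTorch calls (the checkpoint layout and the restore_training_state helper are assumptions for illustration; the Trainer's _restore_list_objs and load_fsspec are not reproduced here):

# Illustrative sketch, not the Trainer's restore_model(): restore optimizer/scaler
# state from a checkpoint dict and fall back to partial init on mismatch.
import torch


def restore_training_state(checkpoint_path, model, optimizer, scaler=None):
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    try:
        model.load_state_dict(checkpoint["model"])
        print(" > Restoring Optimizer...")
        optimizer.load_state_dict(checkpoint["optimizer"])
        if scaler is not None and checkpoint.get("scaler"):
            print(" > Restoring Scaler...")
            scaler.load_state_dict(checkpoint["scaler"])
    except (KeyError, RuntimeError):
        # Keys or tensor shapes did not match: load what fits, keep the rest as initialized.
        print(" > Partial model initialization...")
        model.load_state_dict(checkpoint.get("model", {}), strict=False)
    return model, optimizer, scaler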
@@ -477,7 +478,7 @@ class Trainer:

         # check nan loss
         if torch.isnan(loss_dict["loss"]).any():
-            raise RuntimeError(f" > Detected NaN loss - {loss_dict}.")
+            raise RuntimeError(f" > NaN loss detected - {loss_dict}")

         # set gradient clipping threshold
         if "grad_clip" in config and config.grad_clip is not None:
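The NaN guard above makes a diverged run fail loudly instead of letting a NaN loss silently corrupt the weights. A minimal standalone illustration of the same check (not the Trainer's train step itself):

import torch


def check_nan_loss(loss_dict):
    # Abort early if the loss has diverged (e.g. bad batch or too-high learning rate).
    if torch.isnan(loss_dict["loss"]).any():
        raise RuntimeError(f" > NaN loss detected - {loss_dict}")


check_nan_loss({"loss": torch.tensor(0.73)})            # passes silently
# check_nan_loss({"loss": torch.tensor(float("nan"))})  # would raise RuntimeError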
@@ -819,7 +820,7 @@ class Trainer:
     def test_run(self) -> None:
         """Run test and log the results. Test run must be defined by the model.
         Model must return figures and audios to be logged by the Tensorboard."""
-        if hasattr(self.model, "test_run"):
+        if hasattr(self.model, "test_run") or (self.num_gpus > 1 and hasattr(self.model.module, "test_run")):
             if self.eval_loader is None:
                 self.eval_loader = self.get_eval_dataloader(
                     self.ap,
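The widened hasattr check above accounts for multi-GPU runs, where DistributedDataParallel wraps the user model and model-specific hooks such as test_run live on self.model.module. A small illustrative helper for that pattern (hypothetical, not part of the Trainer API):

def unwrap_model(model):
    # DDP/DataParallel wrappers expose the original model as .module.
    return model.module if hasattr(model, "module") else model


def has_test_run(model, num_gpus=1):
    # Mirrors the check in the diff: look on the wrapper first, then on the wrapped model.
    return hasattr(model, "test_run") or (num_gpus > 1 and hasattr(unwrap_model(model), "test_run"))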
@@ -841,13 +842,20 @@ class Trainer:
             self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
             self.dashboard_logger.test_figures(self.total_steps_done, figures)

-    def _fit(self) -> None:
-        """🏃 train -> evaluate -> test for the number of epochs."""
+    def _restore_best_loss(self):
+        """Restore the best loss from the args.best_path if provided else
+        from the model (`args.restore_path` or `args.continue_path`) used for resuming the training"""
         if self.restore_step != 0 or self.args.best_path:
             print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
-            self.best_loss = load_fsspec(self.args.restore_path, map_location="cpu")["model_loss"]
+            ch = load_fsspec(self.args.restore_path, map_location="cpu")
+            if "model_loss" in ch:
+                self.best_loss = ch["model_loss"]
             print(f" > Starting with loaded last best loss {self.best_loss}.")

+    def _fit(self) -> None:
+        """🏃 train -> evaluate -> test for the number of epochs."""
+        self._restore_best_loss()
+
         self.total_steps_done = self.restore_step

         for epoch in range(0, self.config.epochs):
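This last hunk is the fix named in the commit message: best_loss is now only overwritten when the checkpoint actually contains a model_loss entry, so older checkpoints without it keep the default. A standalone sketch of that behavior (plain dicts instead of the Trainer's load_fsspec; the default value and function name here are illustrative, not the Trainer._restore_best_loss implementation):

def restore_best_loss(checkpoint: dict, default_best_loss: float = float("inf")) -> float:
    # Only overwrite the default when the checkpoint stores a "model_loss" entry.
    best_loss = default_best_loss
    if "model_loss" in checkpoint:
        best_loss = checkpoint["model_loss"]
    print(f" > Starting with loaded last best loss {best_loss}.")
    return best_loss


restore_best_loss({"model_loss": 0.42})  # -> 0.42
restore_best_loss({"step": 1000})        # keeps the default instead of raising KeyError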