mirror of https://github.com/coqui-ai/TTS.git
Fixup `trainer.py` 🛠️
parent
9cb1062736
commit
01c4b22a2f
|
@ -462,12 +462,12 @@ class Trainer:
|
|||
update_lr_scheduler = True
|
||||
if self.use_amp_scaler:
|
||||
if self.use_apex:
|
||||
with amp.scale_loss(loss_dict["loss"], self.optimizer) as scaled_loss:
|
||||
with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
amp.master_params(self.optimizer),
|
||||
self.config.grad_clip,
|
||||
)
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
amp.master_params(optimizer),
|
||||
grad_clip,
|
||||
)
|
||||
else:
|
||||
# model optimizer step in mixed precision mode
|
||||
scaler.scale(loss_dict["loss"]).backward()
|
||||
|
@ -739,6 +739,7 @@ class Trainer:
|
|||
self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
|
||||
if audios is not None:
|
||||
self.tb_logger.tb_eval_audios(self.total_steps_done, audios, self.ap.sample_rate)
|
||||
self.tb_logger.tb_eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values)
|
||||
|
||||
def test_run(self) -> None:
|
||||
"""Run test and log the results. Test run must be defined by the model.
|
||||
|
|
Loading…
Reference in New Issue