From 01c4b22a2fa63d0edd35b2c58e2da5cc663555dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 21 Jun 2021 16:49:30 +0200 Subject: [PATCH] =?UTF-8?q?Fixup=20`trainer.py`=20=F0=9F=9B=A0=EF=B8=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- TTS/trainer.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/TTS/trainer.py b/TTS/trainer.py index 8b7be3d1..ec6d4417 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -462,12 +462,12 @@ class Trainer: update_lr_scheduler = True if self.use_amp_scaler: if self.use_apex: - with amp.scale_loss(loss_dict["loss"], self.optimizer) as scaled_loss: + with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss: scaled_loss.backward() - grad_norm = torch.nn.utils.clip_grad_norm_( - amp.master_params(self.optimizer), - self.config.grad_clip, - ) + grad_norm = torch.nn.utils.clip_grad_norm_( + amp.master_params(optimizer), + grad_clip, + ) else: # model optimizer step in mixed precision mode scaler.scale(loss_dict["loss"]).backward() @@ -739,6 +739,7 @@ class Trainer: self.tb_logger.tb_eval_figures(self.total_steps_done, figures) if audios is not None: self.tb_logger.tb_eval_audios(self.total_steps_done, audios, self.ap.sample_rate) + self.tb_logger.tb_eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values) def test_run(self) -> None: """Run test and log the results. Test run must be defined by the model.