Fixup `trainer.py` 🛠️

pull/602/head
Eren Gölge 2021-06-21 16:49:30 +02:00
parent 9cb1062736
commit 01c4b22a2f
1 changed file with 6 additions and 5 deletions

@@ -462,12 +462,12 @@ class Trainer:
         update_lr_scheduler = True
         if self.use_amp_scaler:
             if self.use_apex:
-                with amp.scale_loss(loss_dict["loss"], self.optimizer) as scaled_loss:
+                with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss:
                     scaled_loss.backward()
-                grad_norm = torch.nn.utils.clip_grad_norm_(
-                    amp.master_params(self.optimizer),
-                    self.config.grad_clip,
-                )
+                grad_norm = torch.nn.utils.clip_grad_norm_(
+                    amp.master_params(optimizer),
+                    grad_clip,
+                )
             else:
                 # model optimizer step in mixed precision mode
                 scaler.scale(loss_dict["loss"]).backward()
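For context: this hunk drops the `self.` lookups because the surrounding method now receives `optimizer` and `grad_clip` as arguments instead of reading them from the trainer instance. Below is a minimal runnable sketch of the Apex AMP backward-and-clip pattern the hunk edits, assuming `apex` is installed and the model/optimizer were already wrapped with `amp.initialize(...)`; the helper name `backward_with_apex` is hypothetical, not part of the commit.

```python
import torch
from apex import amp  # assumes NVIDIA Apex is installed


def backward_with_apex(loss, optimizer, grad_clip):
    # Scale the loss so fp16 gradients do not underflow, then backprop.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    # Clip the fp32 master copies of the parameters, not the fp16 ones.
    grad_norm = torch.nn.utils.clip_grad_norm_(
        amp.master_params(optimizer), grad_clip
    )
    return grad_norm
```

Clipping `amp.master_params(optimizer)` rather than `model.parameters()` matters here because Apex keeps fp32 master copies of the weights in mixed-precision mode, and those are what the optimizer actually steps on.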
@@ -739,6 +739,7 @@ class Trainer:
             self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
         if audios is not None:
             self.tb_logger.tb_eval_audios(self.total_steps_done, audios, self.ap.sample_rate)
+        self.tb_logger.tb_eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values)

     def test_run(self) -> None:
         """Run test and log the results. Test run must be defined by the model.