From 588da1a24e3f828496bcc78510df5b3128f5fcc1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Tue, 19 Oct 2021 16:33:04 +0000
Subject: [PATCH] Simplify grad_norm handling in trainer

Initialize `grad_norm` once before the optimizer step and run a single
NaN/inf guard after it, instead of special-casing `grad_clip <= 0` and
resetting the norm inside the AMP branch. Also log the AMP scaler value
and detach logged loss values to plain Python numbers via `.item()`.

---
 TTS/trainer.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/TTS/trainer.py b/TTS/trainer.py
index 7a38616d..40d1ab6f 100644
--- a/TTS/trainer.py
+++ b/TTS/trainer.py
@@ -362,7 +362,7 @@ class Trainer:
         # override config values from command-line args
         # TODO: Maybe it is better to do it outside
         if len(coqpit_overrides) > 0:
-            config.parse_known_args(coqpit_overrides, relaxed_parser=True)
+            config.parse_known_args(coqpit_overrides, arg_prefix="coqpit", relaxed_parser=True)
         experiment_path = args.continue_path

     # update the config.json fields and copy it to the output folder
@@ -618,10 +618,8 @@ class Trainer:
         else:
             grad_clip = 0.0  # meaning no gradient clipping

-        if grad_clip <= 0:
-            grad_norm = 0
-
         # optimizer step
+        grad_norm = 0
         update_lr_scheduler = True
         if self.use_amp_scaler:
             if self.use_apex:
@@ -636,13 +634,11 @@ class Trainer:
                 if grad_clip > 0:
                     scaler.unscale_(optimizer)
                     grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params(optimizer), grad_clip)
-                # pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
-                if torch.isnan(grad_norm) or torch.isinf(grad_norm):
-                    grad_norm = 0
                 scale_prev = scaler.get_scale()
                 scaler.step(optimizer)
                 scaler.update()
                 update_lr_scheduler = scale_prev <= scaler.get_scale()
+                loss_dict["amp_scaler"] = scaler.get_scale()  # for logging
         else:
             # main model optimizer step
             loss_dict["loss"].backward()
@@ -650,6 +646,10 @@ class Trainer:
                 grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
             optimizer.step()

+        # the clipped grad norm can come back NaN/inf (e.g. when the AMP scaler skips the step); zero it out so logging stays valid
+        if isinstance(grad_norm, torch.Tensor) and (torch.isnan(grad_norm) or torch.isinf(grad_norm)):
+            grad_norm = 0
+
         step_time = time.time() - step_start_time

         # setup lr
@@ -1147,7 +1147,7 @@ class Trainer:
             if isinstance(value, (int, float)):
                 loss_dict_detached[key] = value
             else:
-                loss_dict_detached[key] = value.detach()
+                loss_dict_detached[key] = value.detach().item()
         return loss_dict_detached

     def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
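
For context, here is a minimal standalone sketch of the grad-norm pattern this
patch converges on (non-AMP path only). The function name `clip_and_step` and
its arguments are illustrative stand-ins, not names from TTS/trainer.py:

    import torch

    def clip_and_step(model: torch.nn.Module,
                      optimizer: torch.optim.Optimizer,
                      loss: torch.Tensor,
                      grad_clip: float = 0.0):
        """Backprop, optionally clip, step, and return a log-safe grad norm."""
        grad_norm = 0  # single initialization, as in the patch
        loss.backward()
        if grad_clip > 0:
            # clip_grad_norm_ returns the total norm of the gradients as a tensor
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        # a non-finite norm (e.g. exploding gradients) would poison running
        # averages; zero it out, mirroring the guard the patch adds after
        # optimizer.step()
        if isinstance(grad_norm, torch.Tensor) and (torch.isnan(grad_norm) or torch.isinf(grad_norm)):
            grad_norm = 0
        return grad_norm

Note also that `value.detach().item()` in the last hunk assumes each logged
loss is a scalar tensor; `.item()` raises on multi-element tensors.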