mirror of https://github.com/coqui-ai/TTS.git
Simplify grad_norm handling in trainer

parent 3c7848e9b1
commit 588da1a24e

@@ -362,7 +362,7 @@ class Trainer:
         # override config values from command-line args
         # TODO: Maybe it is better to do it outside
         if len(coqpit_overrides) > 0:
-            config.parse_known_args(coqpit_overrides, relaxed_parser=True)
+            config.parse_known_args(coqpit_overrides, arg_prefix="coqpit", relaxed_parser=True)
         experiment_path = args.continue_path

         # update the config.json fields and copy it to the output folder

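For context on the changed call above: Coqpit configs can absorb command-line overrides through parse_known_args. Below is a minimal, hypothetical sketch (not the Trainer's actual setup); the TrainingConfig fields and the "--coqpit.<field>" flag shape implied by arg_prefix="coqpit" are assumptions.

from dataclasses import dataclass

from coqpit import Coqpit


@dataclass
class TrainingConfig(Coqpit):
    # hypothetical fields, for illustration only
    lr: float = 1e-3
    batch_size: int = 32


config = TrainingConfig()
# assumed flag shape: with arg_prefix="coqpit", an override looks like "--coqpit.lr"
overrides = ["--coqpit.lr", "0.0001"]
config.parse_known_args(overrides, arg_prefix="coqpit", relaxed_parser=True)
print(config.lr)  # 0.0001 if the override was picked up

The relaxed_parser flag is presumably what lets unrecognized arguments pass through without aborting, so trainer-level and coqpit-level flags can share one command line.
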
@@ -618,10 +618,8 @@ class Trainer:
         else:
             grad_clip = 0.0  # meaning no gradient clipping

-        if grad_clip <= 0:
-            grad_norm = 0
-
         # optimizer step
+        grad_norm = 0
         update_lr_scheduler = True
         if self.use_amp_scaler:
             if self.use_apex:

@@ -636,13 +634,11 @@ class Trainer:
                 if grad_clip > 0:
                     scaler.unscale_(optimizer)
                     grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params(optimizer), grad_clip)
-                # pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
-                if torch.isnan(grad_norm) or torch.isinf(grad_norm):
-                    grad_norm = 0
                 scale_prev = scaler.get_scale()
                 scaler.step(optimizer)
                 scaler.update()
                 update_lr_scheduler = scale_prev <= scaler.get_scale()
                 loss_dict["amp_scaler"] = scaler.get_scale()  # for logging
         else:
             # main model optimizer step
             loss_dict["loss"].backward()

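The mixed-precision branch above follows the standard torch.cuda.amp pattern: unscale the gradients before clipping against a real threshold, then detect a skipped (overflowed) step by comparing the scale before and after update(). A minimal sketch, with assumed model/optimizer names and model.parameters() standing in for self.master_params(optimizer):

import torch

def amp_step(model, optimizer, scaler, loss, grad_clip=5.0):
    grad_norm = 0
    scaler.scale(loss).backward()
    if grad_clip > 0:
        # gradients must be unscaled before clipping against grad_clip
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scale_prev = scaler.get_scale()
    scaler.step(optimizer)  # skipped internally if any gradient is inf/NaN
    scaler.update()         # the scale shrinks after a skipped step
    # advance the LR scheduler only when the step was actually taken
    update_lr_scheduler = scale_prev <= scaler.get_scale()
    return grad_norm, update_lr_scheduler
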
@@ -650,6 +646,10 @@ class Trainer:
                 grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
             optimizer.step()

+        # pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
+        if isinstance(grad_norm, torch.Tensor) and (torch.isnan(grad_norm) or torch.isinf(grad_norm)):
+            grad_norm = 0
+
         step_time = time.time() - step_start_time

         # setup lr

@@ -1147,7 +1147,7 @@ class Trainer:
             if isinstance(value, (int, float)):
                 loss_dict_detached[key] = value
             else:
-                loss_dict_detached[key] = value.detach()
+                loss_dict_detached[key] = value.detach().item()
         return loss_dict_detached

     def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:

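The last hunk converts tensor-valued losses to plain Python floats before they are logged or averaged. A minimal sketch of that helper, assuming a flat dict of numbers and 0-dim tensors:

import torch

def detach_loss_dict(loss_dict):
    loss_dict_detached = {}
    for key, value in loss_dict.items():
        if isinstance(value, (int, float)):
            loss_dict_detached[key] = value
        else:
            # drop graph references and keep only the scalar value
            loss_dict_detached[key] = value.detach().item()
    return loss_dict_detached

print(detach_loss_dict({"loss": torch.tensor(0.5), "step": 1}))  # {'loss': 0.5, 'step': 1}
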