Simplify grad_norm handling in trainer

pull/887/head
Eren Gölge 2021-10-19 16:33:04 +00:00
parent 3c7848e9b1
commit 588da1a24e
1 changed file with 8 additions and 8 deletions

@@ -362,7 +362,7 @@ class Trainer:
         # override config values from command-line args
         # TODO: Maybe it is better to do it outside
         if len(coqpit_overrides) > 0:
-            config.parse_known_args(coqpit_overrides, relaxed_parser=True)
+            config.parse_known_args(coqpit_overrides, arg_prefix="coqpit", relaxed_parser=True)
         experiment_path = args.continue_path
         # update the config.json fields and copy it to the output folder
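
With arg_prefix="coqpit", command-line overrides are expected in the form --coqpit.<field_name>. A minimal sketch of that override flow, assuming a hypothetical Coqpit dataclass with illustrative fields (not the trainer's real config):

from dataclasses import dataclass

from coqpit import Coqpit


@dataclass
class MyTrainerConfig(Coqpit):
    # illustrative fields only
    batch_size: int = 32
    lr: float = 1e-3


config = MyTrainerConfig()
# leftover CLI args, e.g. collected via argparse's parse_known_args()
coqpit_overrides = ["--coqpit.batch_size", "64", "--coqpit.lr", "0.0001"]
# same call as in the diff: prefixed flags are mapped onto config fields in place;
# relaxed_parser=True lets unknown flags pass through instead of raising
config.parse_known_args(coqpit_overrides, arg_prefix="coqpit", relaxed_parser=True)
print(config.batch_size, config.lr)  # expected: 64 0.0001
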
@@ -618,10 +618,8 @@ class Trainer:
         else:
             grad_clip = 0.0  # meaning no gradient clipping
-        if grad_clip <= 0:
-            grad_norm = 0
         # optimizer step
+        grad_norm = 0
         update_lr_scheduler = True
         if self.use_amp_scaler:
             if self.use_apex:
@@ -636,13 +634,11 @@ class Trainer:
                 if grad_clip > 0:
                     scaler.unscale_(optimizer)
                     grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params(optimizer), grad_clip)
-                    # pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
-                    if torch.isnan(grad_norm) or torch.isinf(grad_norm):
-                        grad_norm = 0
                 scale_prev = scaler.get_scale()
                 scaler.step(optimizer)
                 scaler.update()
                 update_lr_scheduler = scale_prev <= scaler.get_scale()
                 loss_dict["amp_scaler"] = scaler.get_scale()  # for logging
         else:
             # main model optimizer step
             loss_dict["loss"].backward()
@@ -650,6 +646,10 @@ class Trainer:
                 grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
             optimizer.step()
+        # pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
+        if isinstance(grad_norm, torch.Tensor) and (torch.isnan(grad_norm) or torch.isinf(grad_norm)):
+            grad_norm = 0
         step_time = time.time() - step_start_time
         # setup lr
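
The consolidated check above relies on torch.nn.utils.clip_grad_norm_ returning the total norm as a tensor, which can come back inf/NaN under AMP when the scaled gradients overflow (the scaler then skips the step), while grad_norm stays the plain int 0 whenever clipping is disabled; that is what the isinstance guard is for. A small stand-alone sketch of the same sanitization, with illustrative values:

import torch


def sanitize_grad_norm(grad_norm):
    # report 0 when the norm is a non-finite tensor (the step was skipped anyway)
    if isinstance(grad_norm, torch.Tensor) and (torch.isnan(grad_norm) or torch.isinf(grad_norm)):
        return 0
    return grad_norm


print(sanitize_grad_norm(0))                            # 0, clipping disabled
print(sanitize_grad_norm(torch.tensor(3.2)))            # tensor(3.2000), normal step
print(sanitize_grad_norm(torch.tensor(float("inf"))))   # 0, overflowed gradients
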
@@ -1147,7 +1147,7 @@ class Trainer:
             if isinstance(value, (int, float)):
                 loss_dict_detached[key] = value
             else:
-                loss_dict_detached[key] = value.detach()
+                loss_dict_detached[key] = value.detach().item()
         return loss_dict_detached

     def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
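
The .item() call converts each detached loss tensor into a plain Python float, so the returned dict holds no graph references or device tensors and can be accumulated or logged directly. A quick sketch of the detach loop with a dummy loss dict:

import torch

# dummy loss dict as a training step might return it
loss_dict = {"loss": torch.tensor(1.5, requires_grad=True), "decoder_loss": 0.25}

loss_dict_detached = {}
for key, value in loss_dict.items():
    if isinstance(value, (int, float)):
        loss_dict_detached[key] = value
    else:
        # .detach() alone still keeps a tensor (and its device memory) around;
        # .item() yields a plain Python float
        loss_dict_detached[key] = value.detach().item()

print(loss_dict_detached)  # {'loss': 1.5, 'decoder_loss': 0.25}
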