diff --git a/train.py b/train.py index f6d73a4a..8581132a 100644 --- a/train.py +++ b/train.py @@ -237,7 +237,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler, # save model save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH, optimizer_st=optimizer_st, - model_loss=loss_dict['postnet_loss'].item()) + model_loss=loss_dict['postnet_loss']) # Diagnostic visualizations const_spec = postnet_output[0].data.cpu().numpy()