save model r value for checkpoints

pull/10/head
Eren Golge 2019-08-16 13:11:51 +02:00
parent 446cd6fa06
commit 5acd9e82bd
1 changed files with 4 additions and 2 deletions

View File

@ -121,7 +121,8 @@ def save_checkpoint(model, optimizer, optimizer_st, model_loss, out_path,
'step': current_step,
'epoch': epoch,
'linear_loss': model_loss,
'date': datetime.date.today().strftime("%B %d, %Y")
'date': datetime.date.today().strftime("%B %d, %Y"),
'r': model.decoder.r
}
torch.save(state, checkpoint_path)
@ -136,7 +137,8 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path,
'step': current_step,
'epoch': epoch,
'linear_loss': model_loss,
'date': datetime.date.today().strftime("%B %d, %Y")
'date': datetime.date.today().strftime("%B %d, %Y"),
'r': model.decoder.r
}
best_loss = model_loss
bestmodel_path = 'best_model.pth.tar'