diff --git a/train.py b/train.py index e238757e..87751306 100644 --- a/train.py +++ b/train.py @@ -154,7 +154,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, if current_step % c.save_step == 0: if c.checkpoint: # save model - save_checkpoint(model, optimizer, linear_loss.item(), + save_checkpoint(model, optimizer, optimizer_st, linear_loss.item(), OUT_PATH, current_step, epoch) # Diagnostic visualizations @@ -379,8 +379,8 @@ def main(args): if args.restore_path: checkpoint = torch.load(args.restore_path) model.load_state_dict(checkpoint['model']) - optimizer = optim.Adam(model.parameters(), lr=c.lr) optimizer.load_state_dict(checkpoint['optimizer']) + optimizer_st.load_state_dict(checkpoint['optimizer_st']) for state in optimizer.state.values(): for k, v in state.items(): if torch.is_tensor(v): @@ -388,9 +388,7 @@ def main(args): print(" > Model restored from step %d" % checkpoint['step']) start_epoch = checkpoint['step'] // len(train_loader) best_loss = checkpoint['linear_loss'] - start_epoch = 0 args.restore_step = checkpoint['step'] - optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr) else: args.restore_step = 0 print("\n > Starting a new training") diff --git a/utils/generic_utils.py b/utils/generic_utils.py index b16c6944..5e4487d5 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -78,7 +78,7 @@ def _trim_model_state_dict(state_dict): return new_state_dict -def save_checkpoint(model, optimizer, model_loss, out_path, +def save_checkpoint(model, optimizer, optimizer_st, model_loss, out_path, current_step, epoch): checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step) checkpoint_path = os.path.join(out_path, checkpoint_path) @@ -87,6 +87,7 @@ def save_checkpoint(model, optimizer, model_loss, out_path, new_state_dict = _trim_model_state_dict(model.state_dict()) state = {'model': new_state_dict, 'optimizer': optimizer.state_dict(), + 'optimizer_st': optimizer_st.state_dict(), 'step': current_step, 'epoch': epoch, 'linear_loss': model_loss,