diff --git a/train.py b/train.py index 53c5698d..fc6680a7 100644 --- a/train.py +++ b/train.py @@ -345,14 +345,14 @@ def main(args): model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) print("\n > Model restored from step %d\n" % args.restore_step) - start_epoch = checkpoint['step'] // len(dataloader) + start_epoch = checkpoint['step'] // len(train_loader) best_loss = checkpoint['linear_loss'] elif args.restore_path: checkpoint = torch.load(args.restore_path) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) print("\n > Model restored from step %d\n" % checkpoint['step']) - start_epoch = checkpoint['step'] // len(dataloader) + start_epoch = checkpoint['step'] // len(train_loader) best_loss = checkpoint['linear_loss'] start_epoch = 0 args.restore_step = checkpoint['step']