pull/10/head
Eren Golge 2018-03-05 08:48:17 -08:00
parent 8517187511
commit 3888b31b3c
1 changed files with 2 additions and 2 deletions

View File

@ -345,14 +345,14 @@ def main(args):
model.load_state_dict(checkpoint['model']) model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer']) optimizer.load_state_dict(checkpoint['optimizer'])
print("\n > Model restored from step %d\n" % args.restore_step) print("\n > Model restored from step %d\n" % args.restore_step)
start_epoch = checkpoint['step'] // len(dataloader) start_epoch = checkpoint['step'] // len(train_loader)
best_loss = checkpoint['linear_loss'] best_loss = checkpoint['linear_loss']
elif args.restore_path: elif args.restore_path:
checkpoint = torch.load(args.restore_path) checkpoint = torch.load(args.restore_path)
model.load_state_dict(checkpoint['model']) model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer']) optimizer.load_state_dict(checkpoint['optimizer'])
print("\n > Model restored from step %d\n" % checkpoint['step']) print("\n > Model restored from step %d\n" % checkpoint['step'])
start_epoch = checkpoint['step'] // len(dataloader) start_epoch = checkpoint['step'] // len(train_loader)
best_loss = checkpoint['linear_loss'] best_loss = checkpoint['linear_loss']
start_epoch = 0 start_epoch = 0
args.restore_step = checkpoint['step'] args.restore_step = checkpoint['step']