From 10524b885c39ab7f286bdbf90014ecbdb5302a0b Mon Sep 17 00:00:00 2001 From: Yves-Noel Weweler Date: Mon, 31 Dec 2018 13:29:39 +0100 Subject: [PATCH] Fix NameError: name 'model_dict' is not defined Closes #91 --- train.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/train.py b/train.py index 6d18e3c0..e32d0294 100644 --- a/train.py +++ b/train.py @@ -102,7 +102,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, mel_lengths = mel_lengths.cuda(non_blocking=True) linear_input = linear_input.cuda(non_blocking=True) stop_targets = stop_targets.cuda(non_blocking=True) - + # compute mask for padding mask = sequence_mask(text_lengths) @@ -400,17 +400,20 @@ def main(args): if args.restore_path: checkpoint = torch.load(args.restore_path) - model.load_state_dict(checkpoint['model']) - # Partial initialization: if there is a mismatch with new and old layer, it is skipped. - # 1. filter out unnecessary keys - pretrained_dict = { - k: v - for k, v in checkpoint['model'].items() if k in model_dict - } - # 2. overwrite entries in the existing state dict - model_dict.update(pretrained_dict) - # 3. load the new state dict - model.load_state_dict(model_dict) + try: + model.load_state_dict(checkpoint['model']) + except: + model_dict = model.state_dict() + # Partial initialization: if there is a mismatch with new and old layer, it is skipped. + # 1. filter out unnecessary keys + pretrained_dict = { + k: v + for k, v in checkpoint['model'].items() if k in model_dict + } + # 2. overwrite entries in the existing state dict + model_dict.update(pretrained_dict) + # 3. load the new state dict + model.load_state_dict(model_dict) if use_cuda: model = model.cuda() criterion.cuda() @@ -418,7 +421,7 @@ def main(args): optimizer.load_state_dict(checkpoint['optimizer']) print( " > Model restored from step %d" % checkpoint['step'], flush=True) - start_epoch = checkpoint['epoch'] + start_epoch = checkpoint['epoch'] best_loss = checkpoint['linear_loss'] args.restore_step = checkpoint['step'] else: