brushed up printing model load path and best loss path

pull/367/head
gerazov 2021-02-12 10:55:45 +01:00 committed by Eren Gölge
parent f2e474cd37
commit 2db40457e8
7 changed files with 25 additions and 15 deletions

View File

@ -500,6 +500,7 @@ def main(args): # pylint: disable=redefined-outer-name
criterion = GlowTTSLoss()
if args.restore_path:
print(f" > Restoring from {os.path.basename(args.restore_path)} ...")
checkpoint = torch.load(args.restore_path, map_location='cpu')
try:
# TODO: fix optimizer init, model.cuda() needs to be called before
@ -517,7 +518,7 @@ def main(args): # pylint: disable=redefined-outer-name
for group in optimizer.param_groups:
group['initial_lr'] = c.lr
print(" > Model restored from step %d" % checkpoint['step'],
print(f" > Model restored from step {checkpoint['step']:d}",
flush=True)
args.restore_step = checkpoint['step']
else:
@ -545,7 +546,8 @@ def main(args): # pylint: disable=redefined-outer-name
best_loss = float('inf')
print(" > Starting with inf best loss.")
else:
print(args.best_path)
print(" > Restoring best loss from "
f"{os.path.basename(args.best_path)} ...")
best_loss = torch.load(args.best_path,
map_location='cpu')['model_loss']
print(f" > Starting with loaded last best loss {best_loss}.")

View File

@ -464,6 +464,7 @@ def main(args): # pylint: disable=redefined-outer-name
criterion = SpeedySpeechLoss(c)
if args.restore_path:
print(f" > Restoring from {os.path.basename(args.restore_path)} ...")
checkpoint = torch.load(args.restore_path, map_location='cpu')
try:
# TODO: fix optimizer init, model.cuda() needs to be called before
@ -509,7 +510,8 @@ def main(args): # pylint: disable=redefined-outer-name
best_loss = float('inf')
print(" > Starting with inf best loss.")
else:
print(args.best_path)
print(" > Restoring best loss from "
f"{os.path.basename(args.best_path)} ...")
best_loss = torch.load(args.best_path,
map_location='cpu')['model_loss']
print(f" > Starting with loaded last best loss {best_loss}.")

View File

@ -538,12 +538,13 @@ def main(args): # pylint: disable=redefined-outer-name
# setup criterion
criterion = TacotronLoss(c, stopnet_pos_weight=c.stopnet_pos_weight, ga_sigma=0.4)
if args.restore_path:
print(f" > Restoring from {os.path.basename(args.restore_path)}...")
checkpoint = torch.load(args.restore_path, map_location='cpu')
try:
print(" > Restoring Model.")
print(" > Restoring Model...")
model.load_state_dict(checkpoint['model'])
# optimizer restore
print(" > Restoring Optimizer.")
print(" > Restoring Optimizer...")
optimizer.load_state_dict(checkpoint['optimizer'])
if "scaler" in checkpoint and c.mixed_precision:
print(" > Restoring AMP Scaler...")
@ -551,7 +552,7 @@ def main(args): # pylint: disable=redefined-outer-name
if c.reinit_layers:
raise RuntimeError
except (KeyError, RuntimeError):
print(" > Partial model initialization.")
print(" > Partial model initialization...")
model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
# torch.save(model_dict, os.path.join(OUT_PATH, 'state_dict.pt'))
@ -589,7 +590,8 @@ def main(args): # pylint: disable=redefined-outer-name
best_loss = float('inf')
print(" > Starting with inf best loss.")
else:
print(args.best_path)
print(" > Restoring best loss from "
f"{os.path.basename(args.best_path)} ...")
best_loss = torch.load(args.best_path,
map_location='cpu')['model_loss']
print(f" > Starting with loaded last best loss {best_loss}.")

View File

@ -485,6 +485,7 @@ def main(args): # pylint: disable=redefined-outer-name
criterion_disc = DiscriminatorLoss(c)
if args.restore_path:
print(f" > Restoring from {os.path.basename(args.restore_path)}...")
checkpoint = torch.load(args.restore_path, map_location='cpu')
try:
print(" > Restoring Generator Model...")
@ -523,7 +524,7 @@ def main(args): # pylint: disable=redefined-outer-name
for group in optimizer_disc.param_groups:
group['lr'] = c.lr_disc
print(" > Model restored from step %d" % checkpoint['step'],
print(f" > Model restored from step {checkpoint['step']:d}",
flush=True)
args.restore_step = checkpoint['step']
else:
@ -549,10 +550,11 @@ def main(args): # pylint: disable=redefined-outer-name
best_loss = float('inf')
print(" > Starting with inf best loss.")
else:
print(args.best_path)
print(" > Restoring best loss from "
f"{os.path.basename(args.best_path)} ...")
best_loss = torch.load(args.best_path,
map_location='cpu')['model_loss']
print(f" > Starting with loaded last best loss {best_loss}.")
print(f" > Starting with best loss of {best_loss}.")
keep_best = c.get('keep_best', False)
keep_after = c.get('keep_after', 10000) # void if keep_best False

View File

@ -354,6 +354,7 @@ def main(args): # pylint: disable=redefined-outer-name
criterion.cuda()
if args.restore_path:
print(f" > Restoring from {os.path.basename(args.restore_path)}...")
checkpoint = torch.load(args.restore_path, map_location='cpu')
try:
print(" > Restoring Model...")
@ -397,7 +398,8 @@ def main(args): # pylint: disable=redefined-outer-name
best_loss = float('inf')
print(" > Starting with inf best loss.")
else:
print(args.best_path)
print(" > Restoring best loss from "
f"{os.path.basename(args.best_path)} ...")
best_loss = torch.load(args.best_path,
map_location='cpu')['model_loss']
print(f" > Starting with loaded last best loss {best_loss}.")

View File

@ -383,6 +383,7 @@ def main(args): # pylint: disable=redefined-outer-name
# restore any checkpoint
if args.restore_path:
print(f" > Restoring from {os.path.basename(args.restore_path)}...")
checkpoint = torch.load(args.restore_path, map_location="cpu")
try:
print(" > Restoring Model...")
@ -420,7 +421,8 @@ def main(args): # pylint: disable=redefined-outer-name
best_loss = float('inf')
print(" > Starting with inf best loss.")
else:
print(args.best_path)
print(" > Restoring best loss from "
f"{os.path.basename(args.best_path)} ...")
best_loss = torch.load(args.best_path,
map_location='cpu')['model_loss']
print(f" > Starting with loaded last best loss {best_loss}.")

View File

@ -157,7 +157,6 @@ def process_args(args, model_type):
args.restore_path, best_model = get_last_models(args.continue_path)
if not args.best_path:
args.best_path = best_model
print(f" > Training continues for {args.restore_path}")
# setup output paths and read configs
c = load_config(args.config_path)
@ -171,8 +170,7 @@ def process_args(args, model_type):
if model_class == "TTS":
check_config_tts(c)
elif model_class == "VOCODER":
print("Vocoder config checker not implemented, "
"skipping ...")
print("Vocoder config checker not implemented, skipping ...")
else:
raise ValueError(f"model type {model_type} not recognized!")