mirror of https://github.com/coqui-ai/TTS.git
Checkpoint stop token optimizer
parent fce6bd27b8
commit dfd0bc1831

train.py (6 changed lines)
@@ -154,7 +154,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
         if current_step % c.save_step == 0:
             if c.checkpoint:
                 # save model
-                save_checkpoint(model, optimizer, linear_loss.item(),
+                save_checkpoint(model, optimizer, optimizer_st, linear_loss.item(),
                                 OUT_PATH, current_step, epoch)

         # Diagnostic visualizations
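For context, the two optimizers named in the hunk header are created earlier in train.py: one for the whole model and one covering only the decoder's stop-token predictor. A minimal self-contained sketch of that setup, mirroring the optim.Adam lines visible in the hunks below — the Tacotron/Decoder stand-in classes, layer shapes, and lr value here are hypothetical:

import torch.nn as nn
import torch.optim as optim

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.stopnet = nn.Linear(256, 1)   # stand-in stop-token predictor

class Tacotron(nn.Module):
    def __init__(self):
        super(Tacotron, self).__init__()
        self.decoder = Decoder()

model = Tacotron()
lr = 0.001  # stands in for c.lr
optimizer = optim.Adam(model.parameters(), lr=lr)                     # whole model
optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=lr)  # stopnet only

The point of the commit is that this second Adam instance now gets checkpointed and restored alongside the first, instead of starting from scratch on every resume.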
@@ -379,8 +379,8 @@ def main(args):
     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
         model.load_state_dict(checkpoint['model'])
-        optimizer = optim.Adam(model.parameters(), lr=c.lr)
         optimizer.load_state_dict(checkpoint['optimizer'])
+        optimizer_st.load_state_dict(checkpoint['optimizer_st'])
         for state in optimizer.state.values():
             for k, v in state.items():
                 if torch.is_tensor(v):
@@ -388,9 +388,7 @@ def main(args):
         print(" > Model restored from step %d" % checkpoint['step'])
         start_epoch = checkpoint['step'] // len(train_loader)
         best_loss = checkpoint['linear_loss']
-        start_epoch = 0
         args.restore_step = checkpoint['step']
-        optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr)
     else:
         args.restore_step = 0
         print("\n > Starting a new training")
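Taken together, the two hunks above change the restore path so that both optimizer states are loaded from the checkpoint, rather than the main optimizer being re-created and the stopnet optimizer being rebuilt inside the branch. A sketch of the resulting block, reconstructed from the context lines; the state[k] = v.cuda() move falls just outside the hunk and is an assumption:

if args.restore_path:
    checkpoint = torch.load(args.restore_path)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    optimizer_st.load_state_dict(checkpoint['optimizer_st'])
    # Adam keeps per-parameter buffers (step, exp_avg, exp_avg_sq); after
    # torch.load they live on the CPU and must follow the model to the GPU.
    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.cuda()  # assumed: just outside the hunk context
    print(" > Model restored from step %d" % checkpoint['step'])
    start_epoch = checkpoint['step'] // len(train_loader)
    best_loss = checkpoint['linear_loss']
    args.restore_step = checkpoint['step']
else:
    args.restore_step = 0
    print("\n > Starting a new training")

Note that in the visible context only optimizer's buffers get the CPU-to-GPU loop; no equivalent loop over optimizer_st.state appears in these hunks.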
(second changed file: the module defining _trim_model_state_dict and save_checkpoint)

@@ -78,7 +78,7 @@ def _trim_model_state_dict(state_dict):
     return new_state_dict


-def save_checkpoint(model, optimizer, model_loss, out_path,
+def save_checkpoint(model, optimizer, optimizer_st, model_loss, out_path,
                     current_step, epoch):
     checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
     checkpoint_path = os.path.join(out_path, checkpoint_path)
@@ -87,6 +87,7 @@ def save_checkpoint(model, optimizer, model_loss, out_path,
     new_state_dict = _trim_model_state_dict(model.state_dict())
     state = {'model': new_state_dict,
              'optimizer': optimizer.state_dict(),
+             'optimizer_st': optimizer_st.state_dict(),
              'step': current_step,
              'epoch': epoch,
              'linear_loss': model_loss,
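Putting the two hunks together, save_checkpoint after this commit reads roughly as follows. This is a sketch, not the verbatim file: the closing of the state dict and the torch.save call fall outside the shown context and are assumptions.

import os
import torch

def save_checkpoint(model, optimizer, optimizer_st, model_loss, out_path,
                    current_step, epoch):
    checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
    checkpoint_path = os.path.join(out_path, checkpoint_path)
    # strip DataParallel-style prefixes before saving (helper defined above)
    new_state_dict = _trim_model_state_dict(model.state_dict())
    state = {'model': new_state_dict,
             'optimizer': optimizer.state_dict(),
             'optimizer_st': optimizer_st.state_dict(),  # new in this commit
             'step': current_step,
             'epoch': epoch,
             'linear_loss': model_loss}  # dict closing assumed
    torch.save(state, checkpoint_path)   # assumed: outside the hunk context

Without the 'optimizer_st' entry, resuming from a checkpoint would reset the stopnet's Adam moment estimates to zero while the rest of training picked up where it left off, which is exactly what this commit fixes.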