save used model characters to the checkpoints

pull/3/head
Eren Gölge 2021-02-12 12:03:42 +00:00
parent 8b6fd76ad2
commit e774f68aee
3 changed files with 20 additions and 9 deletions

View File

@ -268,7 +268,7 @@ def train(data_loader, model, criterion, optimizer, scheduler,
if global_step % c.save_step == 0:
if c.checkpoint:
# save model
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, model_characters,
model_loss=loss_dict['loss'])
# wait all kernels to be completed
@ -467,7 +467,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=global-variable-undefined
global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping
# Audio processor
ap = AudioProcessor(**c.audio)
if 'characters' in c.keys():
@ -477,7 +477,10 @@ def main(args): # pylint: disable=redefined-outer-name
if num_gpus > 1:
init_distributed(args.rank, num_gpus, args.group_id,
c.distributed["backend"], c.distributed["url"])
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
# set model characters
model_characters = phonemes if c.use_phonemes else symbols
num_chars = len(model_characters)
# load data instances
meta_data_train, meta_data_eval = load_meta_data(c.datasets)
@ -559,7 +562,7 @@ def main(args): # pylint: disable=redefined-outer-name
if c.run_eval:
target_loss = eval_avg_loss_dict['avg_loss']
best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
OUT_PATH)
OUT_PATH, model_characters)
if __name__ == '__main__':

View File

@ -247,7 +247,7 @@ def train(data_loader, model, criterion, optimizer, scheduler,
if global_step % c.save_step == 0:
if c.checkpoint:
# save model
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, model_characters,
model_loss=loss_dict['loss'])
# wait all kernels to be completed
@ -431,7 +431,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
# FIXME: move args definition/parsing inside of main?
def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=global-variable-undefined
global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping
# Audio processor
ap = AudioProcessor(**c.audio)
if 'characters' in c.keys():
@ -441,7 +441,10 @@ def main(args): # pylint: disable=redefined-outer-name
if num_gpus > 1:
init_distributed(args.rank, num_gpus, args.group_id,
c.distributed["backend"], c.distributed["url"])
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
# set model characters
model_characters = phonemes if c.use_phonemes else symbols
num_chars = len(model_characters)
# load data instances
meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=True)
@ -523,7 +526,7 @@ def main(args): # pylint: disable=redefined-outer-name
target_loss = eval_avg_loss_dict['avg_loss']
best_loss = save_best_model(target_loss, best_loss, model, optimizer,
global_step, epoch, c.r,
OUT_PATH)
OUT_PATH, model_characters)
if __name__ == '__main__':

View File

@ -284,6 +284,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH,
optimizer_st=optimizer_st,
model_loss=loss_dict['postnet_loss'],
characters=model_characters,
scaler=scaler.state_dict() if c.mixed_precision else None)
# Diagnostic visualizations
@ -492,9 +493,11 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=global-variable-undefined
global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
global meta_data_train, meta_data_eval, speaker_mapping, symbols, phonemes, model_characters
# Audio processor
ap = AudioProcessor(**c.audio)
# setup custom characters if set in config file.
if 'characters' in c.keys():
symbols, phonemes = make_symbols(**c.characters)
@ -503,6 +506,7 @@ def main(args): # pylint: disable=redefined-outer-name
init_distributed(args.rank, num_gpus, args.group_id,
c.distributed["backend"], c.distributed["url"])
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model_characters = phonemes if c.use_phonemes else symbols
# load data instances
meta_data_train, meta_data_eval = load_meta_data(c.datasets)
@ -634,6 +638,7 @@ def main(args): # pylint: disable=redefined-outer-name
epoch,
c.r,
OUT_PATH,
model_characters,
scaler=scaler.state_dict() if c.mixed_precision else None
)