mirror of https://github.com/coqui-ai/TTS.git
save used model characters to the checkpoints
parent
8b6fd76ad2
commit
e774f68aee
|
@ -268,7 +268,7 @@ def train(data_loader, model, criterion, optimizer, scheduler,
|
|||
if global_step % c.save_step == 0:
|
||||
if c.checkpoint:
|
||||
# save model
|
||||
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
|
||||
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, model_characters,
|
||||
model_loss=loss_dict['loss'])
|
||||
|
||||
# wait all kernels to be completed
|
||||
|
@ -467,7 +467,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
|
|||
|
||||
def main(args): # pylint: disable=redefined-outer-name
|
||||
# pylint: disable=global-variable-undefined
|
||||
global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
|
||||
global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping
|
||||
# Audio processor
|
||||
ap = AudioProcessor(**c.audio)
|
||||
if 'characters' in c.keys():
|
||||
|
@ -477,7 +477,10 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
if num_gpus > 1:
|
||||
init_distributed(args.rank, num_gpus, args.group_id,
|
||||
c.distributed["backend"], c.distributed["url"])
|
||||
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
|
||||
|
||||
# set model characters
|
||||
model_characters = phonemes if c.use_phonemes else symbols
|
||||
num_chars = len(model_characters)
|
||||
|
||||
# load data instances
|
||||
meta_data_train, meta_data_eval = load_meta_data(c.datasets)
|
||||
|
@ -559,7 +562,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
if c.run_eval:
|
||||
target_loss = eval_avg_loss_dict['avg_loss']
|
||||
best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
|
||||
OUT_PATH)
|
||||
OUT_PATH, model_characters)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -247,7 +247,7 @@ def train(data_loader, model, criterion, optimizer, scheduler,
|
|||
if global_step % c.save_step == 0:
|
||||
if c.checkpoint:
|
||||
# save model
|
||||
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
|
||||
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, model_characters,
|
||||
model_loss=loss_dict['loss'])
|
||||
|
||||
# wait all kernels to be completed
|
||||
|
@ -431,7 +431,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
|
|||
# FIXME: move args definition/parsing inside of main?
|
||||
def main(args): # pylint: disable=redefined-outer-name
|
||||
# pylint: disable=global-variable-undefined
|
||||
global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
|
||||
global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping
|
||||
# Audio processor
|
||||
ap = AudioProcessor(**c.audio)
|
||||
if 'characters' in c.keys():
|
||||
|
@ -441,7 +441,10 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
if num_gpus > 1:
|
||||
init_distributed(args.rank, num_gpus, args.group_id,
|
||||
c.distributed["backend"], c.distributed["url"])
|
||||
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
|
||||
|
||||
# set model characters
|
||||
model_characters = phonemes if c.use_phonemes else symbols
|
||||
num_chars = len(model_characters)
|
||||
|
||||
# load data instances
|
||||
meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=True)
|
||||
|
@ -523,7 +526,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
target_loss = eval_avg_loss_dict['avg_loss']
|
||||
best_loss = save_best_model(target_loss, best_loss, model, optimizer,
|
||||
global_step, epoch, c.r,
|
||||
OUT_PATH)
|
||||
OUT_PATH, model_characters)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -284,6 +284,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
|
|||
save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH,
|
||||
optimizer_st=optimizer_st,
|
||||
model_loss=loss_dict['postnet_loss'],
|
||||
characters=model_characters,
|
||||
scaler=scaler.state_dict() if c.mixed_precision else None)
|
||||
|
||||
# Diagnostic visualizations
|
||||
|
@ -492,9 +493,11 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
|
|||
|
||||
def main(args): # pylint: disable=redefined-outer-name
|
||||
# pylint: disable=global-variable-undefined
|
||||
global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
|
||||
global meta_data_train, meta_data_eval, speaker_mapping, symbols, phonemes, model_characters
|
||||
# Audio processor
|
||||
ap = AudioProcessor(**c.audio)
|
||||
|
||||
# setup custom characters if set in config file.
|
||||
if 'characters' in c.keys():
|
||||
symbols, phonemes = make_symbols(**c.characters)
|
||||
|
||||
|
@ -503,6 +506,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
init_distributed(args.rank, num_gpus, args.group_id,
|
||||
c.distributed["backend"], c.distributed["url"])
|
||||
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
|
||||
model_characters = phonemes if c.use_phonemes else symbols
|
||||
|
||||
# load data instances
|
||||
meta_data_train, meta_data_eval = load_meta_data(c.datasets)
|
||||
|
@ -634,6 +638,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
epoch,
|
||||
c.r,
|
||||
OUT_PATH,
|
||||
model_characters,
|
||||
scaler=scaler.state_dict() if c.mixed_precision else None
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue