linter fixes

pull/367/head
Eren Gölge 2021-02-09 11:43:17 +00:00 committed by Eren Gölge
parent 2b5cb24db7
commit 3c961370e7
3 changed files with 4 additions and 5 deletions

View File

@ -515,7 +515,7 @@ def main(args): # pylint: disable=redefined-outer-name
train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer, train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer,
scheduler, ap, global_step, scheduler, ap, global_step,
epoch) epoch)
eval_avg_loss_dict = evaluate(eval_loader , model, criterion, ap, eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap,
global_step, epoch) global_step, epoch)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict) c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = train_avg_loss_dict['avg_loss'] target_loss = train_avg_loss_dict['avg_loss']

View File

@ -104,8 +104,7 @@ def get_last_checkpoint(path):
pass pass
if last_checkpoint is None: if last_checkpoint is None:
raise ValueError(f"No checkpoints in {path}!") raise ValueError(f"No checkpoints in {path}!")
else: return last_checkpoint
return last_checkpoint
def process_args(args, model_type): def process_args(args, model_type):
@ -193,7 +192,7 @@ def process_args(args, model_type):
if args.restore_path: if args.restore_path:
new_fields["restore_path"] = args.restore_path new_fields["restore_path"] = args.restore_path
new_fields["github_branch"] = get_git_branch() new_fields["github_branch"] = get_git_branch()
copy_model_files(c, args.config_path, copy_model_files(c, args.config_path,
out_path, new_fields) out_path, new_fields)
os.chmod(audio_path, 0o775) os.chmod(audio_path, 0o775)
os.chmod(out_path, 0o775) os.chmod(out_path, 0o775)

View File

@ -21,7 +21,7 @@ def test_phoneme_to_sequence():
text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters) text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
gt = "ɹiːsənt ɹɪːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!" gt = "ɹiːsənt ɹɪːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!"
assert text_hat == text_hat_with_params == gt assert text_hat == text_hat_with_params == gt
# multiple punctuations # multiple punctuations
text = "Be a voice, not an! echo?" text = "Be a voice, not an! echo?"
sequence = phoneme_to_sequence(text, text_cleaner, lang) sequence = phoneme_to_sequence(text, text_cleaner, lang)