From 53f13461b921f861d5b07f2de28892e4704b96e1 Mon Sep 17 00:00:00 2001
From: erogol
Date: Wed, 10 Jun 2020 13:45:43 +0200
Subject: [PATCH] better stats logging for TTS training

---
 train.py | 132 +++++++++++++++++++------------------------------------
 1 file changed, 45 insertions(+), 87 deletions(-)

diff --git a/train.py b/train.py
index 869557e6..51c73a9f 100644
--- a/train.py
+++ b/train.py
@@ -119,21 +119,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
                                verbose=(epoch == 0))
     model.train()
     epoch_time = 0
-    train_values = {
-        'avg_postnet_loss': 0,
-        'avg_decoder_loss': 0,
-        'avg_stopnet_loss': 0,
-        'avg_align_error': 0,
-        'avg_step_time': 0,
-        'avg_loader_time': 0
-    }
-    if c.bidirectional_decoder:
-        train_values['avg_decoder_b_loss'] = 0  # decoder backward loss
-        train_values['avg_decoder_c_loss'] = 0  # decoder consistency loss
-    if c.ga_alpha > 0:
-        train_values['avg_ga_loss'] = 0  # guidede attention loss
     keep_avg = KeepAverage()
-    keep_avg.add_values(train_values)
     if use_cuda:
         batch_n_iter = int(
             len(data_loader.dataset) / (c.batch_size * num_gpus))
@@ -179,11 +165,6 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
                                   mel_lengths, decoder_backward_output,
                                   alignments, alignment_lengths, alignments_backward,
                                   text_lengths)
-            if c.bidirectional_decoder:
-                keep_avg.update_values({'avg_decoder_b_loss': loss_dict['decoder_backward_loss'].item(),
-                                        'avg_decoder_c_loss': loss_dict['decoder_c_loss'].item()})
-            if c.ga_alpha > 0:
-                keep_avg.update_values({'avg_ga_loss': loss_dict['ga_loss'].item()})
 
             # backward pass
             loss_dict['loss'].backward()
@@ -193,7 +174,6 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
 
             # compute alignment error (the lower the better )
             align_error = 1 - alignment_diagonal_score(alignments)
-            keep_avg.update_value('avg_align_error', align_error)
             loss_dict['align_error'] = align_error
 
             # backpass and check the grad norm for stop loss
@@ -208,23 +188,6 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
             step_time = time.time() - start_time
             epoch_time += step_time
 
-            # update avg stats
-            update_train_values = {
-                'avg_postnet_loss': float(loss_dict['postnet_loss'].item()),
-                'avg_decoder_loss': float(loss_dict['decoder_loss'].item()),
-                'avg_stopnet_loss': loss_dict['stopnet_loss'].item() \
-                    if isinstance(loss_dict['stopnet_loss'], float) else float(loss_dict['stopnet_loss'].item()),
-                'avg_step_time': step_time,
-                'avg_loader_time': loader_time
-            }
-            keep_avg.update_values(update_train_values)
-
-            if global_step % c.print_step == 0:
-                c_logger.print_train_step(batch_n_iter, num_iter, global_step,
-                                          avg_spec_length, avg_text_length,
-                                          step_time, loader_time, current_lr,
-                                          loss_dict, keep_avg.avg_values)
-
             # aggregate losses from processes
             if num_gpus > 1:
                 loss_dict['postnet_loss'] = reduce_tensor(loss_dict['postnet_loss'].data, num_gpus)
@@ -232,6 +195,30 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
                 loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus)
                 loss_dict['stopnet_loss'] = reduce_tensor(loss_dict['stopnet_loss'].data, num_gpus) if c.stopnet else loss_dict['stopnet_loss']
 
+            # detach loss values
+            loss_dict_new = dict()
+            for key, value in loss_dict.items():
+                if isinstance(value, int) or isinstance(value, float):
+                    loss_dict_new[key] = value
+                else:
+                    loss_dict_new[key] = value.item()
+            loss_dict = loss_dict_new
+
+            # update avg stats
+            update_train_values = dict()
+            for key, value in loss_dict.items():
+                update_train_values['avg_' + key] = value
+            update_train_values['avg_loader_time'] = loader_time
+            update_train_values['avg_step_time'] = step_time
+            keep_avg.update_values(update_train_values)
+
+            # print training progress
+            if global_step % c.print_step == 0:
+                c_logger.print_train_step(batch_n_iter, num_iter, global_step,
+                                          avg_spec_length, avg_text_length,
+                                          step_time, loader_time, current_lr,
+                                          loss_dict, keep_avg.avg_values)
+
             if args.rank == 0:
                 # Plot Training Iter Stats
                 # reduce TB load
@@ -266,7 +253,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
                         "alignment": plot_alignment(align_img),
                     }
 
-                    if c.bidirectional_decoder:
+                    if c.bidirectional_decoder or c.double_decoder_consistency:
                         figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy())
 
                     tb_logger.tb_train_figures(global_step, figures)
@@ -286,16 +273,8 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
 
     # Plot Epoch Stats
     if args.rank == 0:
-        # Plot Training Epoch Stats
-        epoch_stats = {
-            "loss_postnet": keep_avg['avg_postnet_loss'],
-            "loss_decoder": keep_avg['avg_decoder_loss'],
-            "stopnet_loss": keep_avg['avg_stopnet_loss'],
-            "alignment_error": keep_avg['avg_align_error'],
-            "epoch_time": epoch_time
-        }
-        if c.ga_alpha > 0:
-            epoch_stats['guided_attention_loss'] = keep_avg['avg_ga_loss']
+        epoch_stats = {"epoch_time": epoch_time}
+        epoch_stats.update(keep_avg.avg_values)
         tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
         if c.tb_model_param_stats:
             tb_logger.tb_model_weights(model, global_step)
@@ -307,20 +286,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
     data_loader = setup_loader(ap, model.decoder.r, is_val=True)
     model.eval()
     epoch_time = 0
-    eval_values_dict = {
-        'avg_postnet_loss': 0,
-        'avg_decoder_loss': 0,
-        'avg_stopnet_loss': 0,
-        'avg_align_error': 0
-    }
-    if c.bidirectional_decoder:
-        eval_values_dict['avg_decoder_b_loss'] = 0  # decoder backward loss
-        eval_values_dict['avg_decoder_c_loss'] = 0  # decoder consistency loss
-    if c.ga_alpha > 0:
-        eval_values_dict['avg_ga_loss'] = 0  # guidede attention loss
     keep_avg = KeepAverage()
-    keep_avg.add_values(eval_values_dict)
-
     c_logger.print_eval_start()
     if data_loader is not None:
         for num_iter, data in enumerate(data_loader):
@@ -352,11 +318,6 @@ def evaluate(model, criterion, ap, global_step, epoch):
                                      mel_lengths, decoder_backward_output,
                                      alignments, alignment_lengths, alignments_backward,
                                      text_lengths)
-                if c.bidirectional_decoder:
-                    keep_avg.update_values({'avg_decoder_b_loss': loss_dict['decoder_b_loss'].item(),
-                                            'avg_decoder_c_loss': loss_dict['decoder_c_loss'].item()})
-                if c.ga_alpha > 0:
-                    keep_avg.update_values({'avg_ga_loss': loss_dict['ga_loss'].item()})
 
                 # step time
                 step_time = time.time() - start_time
@@ -364,7 +325,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
 
                 # compute alignment score
                 align_error = 1 - alignment_diagonal_score(alignments)
-                keep_avg.update_value('avg_align_error', align_error)
+                loss_dict['align_error'] = align_error
 
                 # aggregate losses from processes
                 if num_gpus > 1:
@@ -373,14 +334,20 @@ def evaluate(model, criterion, ap, global_step, epoch):
                     if c.stopnet:
                         loss_dict['stopnet_loss'] = reduce_tensor(loss_dict['stopnet_loss'].data, num_gpus)
 
-                keep_avg.update_values({
-                    'avg_postnet_loss':
-                    float(loss_dict['postnet_loss'].item()),
-                    'avg_decoder_loss':
-                    float(loss_dict['decoder_loss'].item()),
-                    'avg_stopnet_loss':
-                    float(loss_dict['stopnet_loss'].item()),
-                })
+                # detach loss values
+                loss_dict_new = dict()
+                for key, value in loss_dict.items():
+                    if isinstance(value, int) or isinstance(value, float):
+                        loss_dict_new[key] = value
+                    else:
+                        loss_dict_new[key] = value.item()
+                loss_dict = loss_dict_new
+
+                # update avg stats
+                update_train_values = dict()
+                for key, value in loss_dict.items():
+                    update_train_values['avg_' + key] = value
+                keep_avg.update_values(update_train_values)
 
                 if c.print_eval:
                     c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
@@ -409,20 +376,11 @@ def evaluate(model, criterion, ap, global_step, epoch):
                                      c.audio["sample_rate"])
 
             # Plot Validation Stats
-            epoch_stats = {
-                "loss_postnet": keep_avg['avg_postnet_loss'],
-                "loss_decoder": keep_avg['avg_decoder_loss'],
-                "stopnet_loss": keep_avg['avg_stopnet_loss'],
-                "alignment_error": keep_avg['avg_align_error'],
-            }
-            if c.bidirectional_decoder:
-                epoch_stats['loss_decoder_backward'] = keep_avg['avg_decoder_b_loss']
+            if c.bidirectional_decoder or c.double_decoder_consistency:
                 align_b_img = alignments_backward[idx].data.cpu().numpy()
-                eval_figures['alignment_backward'] = plot_alignment(align_b_img)
-            if c.ga_alpha > 0:
-                epoch_stats['guided_attention_loss'] = keep_avg['avg_ga_loss']
-            tb_logger.tb_eval_stats(global_step, epoch_stats)
+                eval_figures['alignment2'] = plot_alignment(align_b_img)
+            tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
             tb_logger.tb_eval_figures(global_step, eval_figures)
 
     if args.rank == 0 and epoch > c.test_delay_epochs:
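
Note (not part of the patch): below is a minimal, illustrative sketch of how the new generic stats logging behaves, assuming a KeepAverage-like interface (update_values() accumulating into an avg_values dict, as used in the hunks above). Tensor losses are detached to plain floats with .item(), every key is prefixed with 'avg_', and a running mean is kept per key; this is why the hand-maintained avg_* dictionaries and per-loss special cases could be deleted. RunningAverages and detach_losses are hypothetical helper names for this sketch only, not identifiers from train.py.

    import torch


    class RunningAverages:
        """Hypothetical stand-in with a KeepAverage-like interface:
        update_values() accumulates, avg_values exposes running means."""

        def __init__(self):
            self.sums = {}
            self.counts = {}

        def update_values(self, values):
            for name, value in values.items():
                self.sums[name] = self.sums.get(name, 0.0) + value
                self.counts[name] = self.counts.get(name, 0) + 1

        @property
        def avg_values(self):
            return {name: self.sums[name] / self.counts[name] for name in self.sums}


    def detach_losses(loss_dict):
        """Mirror the '# detach loss values' loop: tensors become plain floats."""
        return {key: value if isinstance(value, (int, float)) else value.item()
                for key, value in loss_dict.items()}


    keep_avg = RunningAverages()
    for step in range(3):
        # fake per-step losses, shaped like what criterion() returns
        loss_dict = {'postnet_loss': torch.tensor(0.5 / (step + 1)),
                     'decoder_loss': torch.tensor(0.3 / (step + 1)),
                     'align_error': 0.1 * step}
        loss_dict = detach_losses(loss_dict)
        # mirror the '# update avg stats' loop: prefix every key with 'avg_'
        keep_avg.update_values({'avg_' + key: value for key, value in loss_dict.items()})

    # running means over all steps, keyed exactly as the TensorBoard epoch stats now are
    print(keep_avg.avg_values)

With this pattern, any new entry added to loss_dict (e.g. the backward-decoder or guided-attention terms) is averaged, printed, and logged to TensorBoard automatically, without touching the logging code again.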