diff --git a/train.py b/train.py index 94ccfedb..1444e103 100644 --- a/train.py +++ b/train.py @@ -356,7 +356,7 @@ def evaluate(model, criterion, ap, global_step, epoch): mel_lengths, decoder_backward_output, alignments, alignment_lengths, text_lengths) if c.bidirectional_decoder: - keep_avg.update_values({'avg_decoder_b_loss': loss_dict['decoder_backward_loss'].item(), + keep_avg.update_values({'avg_decoder_b_loss': loss_dict['decoder_b_loss'].item(), 'avg_decoder_c_loss': loss_dict['decoder_c_loss'].item()}) if c.ga_alpha > 0: keep_avg.update_values({'avg_ga_loss': loss_dict['ga_loss'].item()})