diff --git a/train.py b/train.py index f52d24c1..b9f5fefb 100644 --- a/train.py +++ b/train.py @@ -327,6 +327,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, return keep_avg['avg_postnet_loss'], global_step +@torch.no_grad() def evaluate(model, criterion, criterion_st, ap, global_step, epoch): data_loader = setup_loader(ap, model.decoder.r, is_val=True) if c.use_speaker_embedding: @@ -346,125 +347,124 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch): keep_avg.add_values(eval_values_dict) print("\n > Validation") - with torch.no_grad(): - if data_loader is not None: - for num_iter, data in enumerate(data_loader): - start_time = time.time() + if data_loader is not None: + for num_iter, data in enumerate(data_loader): + start_time = time.time() - # format data - text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, _, _ = format_data(data) - assert mel_input.shape[1] % model.decoder.r == 0 + # format data + text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, _, _ = format_data(data) + assert mel_input.shape[1] % model.decoder.r == 0 - # forward pass model - if c.bidirectional_decoder: - decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model( - text_input, text_lengths, mel_input, speaker_ids=speaker_ids) - else: - decoder_output, postnet_output, alignments, stop_tokens = model( - text_input, text_lengths, mel_input, speaker_ids=speaker_ids) + # forward pass model + if c.bidirectional_decoder: + decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model( + text_input, text_lengths, mel_input, speaker_ids=speaker_ids) + else: + decoder_output, postnet_output, alignments, stop_tokens = model( + text_input, text_lengths, mel_input, speaker_ids=speaker_ids) - # loss computation - stop_loss = criterion_st( - stop_tokens, stop_targets) if c.stopnet else torch.zeros(1) - if c.loss_masking: - decoder_loss = criterion(decoder_output, mel_input, - mel_lengths) - if c.model in ["Tacotron", "TacotronGST"]: - postnet_loss = criterion(postnet_output, linear_input, - mel_lengths) - else: - postnet_loss = criterion(postnet_output, mel_input, - mel_lengths) - else: - decoder_loss = criterion(decoder_output, mel_input) - if c.model in ["Tacotron", "TacotronGST"]: - postnet_loss = criterion(postnet_output, linear_input) - else: - postnet_loss = criterion(postnet_output, mel_input) - loss = decoder_loss + postnet_loss + stop_loss - - # backward decoder loss - if c.bidirectional_decoder: - if c.loss_masking: - decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input, mel_lengths) - else: - decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input) - decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_backward_output, dims=(1, )), decoder_output) - loss += decoder_backward_loss + decoder_c_loss - keep_avg.update_values({'avg_decoder_b_loss': decoder_backward_loss.item(), 'avg_decoder_c_loss': decoder_c_loss.item()}) - - step_time = time.time() - start_time - epoch_time += step_time - - # compute alignment score - align_score = alignment_diagonal_score(alignments) - keep_avg.update_value('avg_align_score', align_score) - - # aggregate losses from processes - if num_gpus > 1: - postnet_loss = reduce_tensor(postnet_loss.data, num_gpus) - decoder_loss = reduce_tensor(decoder_loss.data, num_gpus) - if c.stopnet: - stop_loss = reduce_tensor(stop_loss.data, num_gpus) - - keep_avg.update_values({ - 'avg_postnet_loss': - float(postnet_loss.item()), - 'avg_decoder_loss': - float(decoder_loss.item()), - 'avg_stop_loss': - float(stop_loss.item()), - }) - - if num_iter % c.print_step == 0: - print( - " | > TotalLoss: {:.5f} PostnetLoss: {:.5f} - {:.5f} DecoderLoss:{:.5f} - {:.5f} " - "StopLoss: {:.5f} - {:.5f} AlignScore: {:.4f} : {:.4f}" - .format(loss.item(), postnet_loss.item(), - keep_avg['avg_postnet_loss'], - decoder_loss.item(), - keep_avg['avg_decoder_loss'], stop_loss.item(), - keep_avg['avg_stop_loss'], align_score, - keep_avg['avg_align_score']), - flush=True) - - if args.rank == 0: - # Diagnostic visualizations - idx = np.random.randint(mel_input.shape[0]) - const_spec = postnet_output[idx].data.cpu().numpy() - gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [ - "Tacotron", "TacotronGST" - ] else mel_input[idx].data.cpu().numpy() - align_img = alignments[idx].data.cpu().numpy() - - eval_figures = { - "prediction": plot_spectrogram(const_spec, ap), - "ground_truth": plot_spectrogram(gt_spec, ap), - "alignment": plot_alignment(align_img) - } - - # Sample audio + # loss computation + stop_loss = criterion_st( + stop_tokens, stop_targets) if c.stopnet else torch.zeros(1) + if c.loss_masking: + decoder_loss = criterion(decoder_output, mel_input, + mel_lengths) if c.model in ["Tacotron", "TacotronGST"]: - eval_audio = ap.inv_spectrogram(const_spec.T) + postnet_loss = criterion(postnet_output, linear_input, + mel_lengths) else: - eval_audio = ap.inv_mel_spectrogram(const_spec.T) - tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, - c.audio["sample_rate"]) + postnet_loss = criterion(postnet_output, mel_input, + mel_lengths) + else: + decoder_loss = criterion(decoder_output, mel_input) + if c.model in ["Tacotron", "TacotronGST"]: + postnet_loss = criterion(postnet_output, linear_input) + else: + postnet_loss = criterion(postnet_output, mel_input) + loss = decoder_loss + postnet_loss + stop_loss - # Plot Validation Stats - epoch_stats = { - "loss_postnet": keep_avg['avg_postnet_loss'], - "loss_decoder": keep_avg['avg_decoder_loss'], - "stop_loss": keep_avg['avg_stop_loss'], - "alignment_score": keep_avg['avg_align_score'] - } + # backward decoder loss + if c.bidirectional_decoder: + if c.loss_masking: + decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input, mel_lengths) + else: + decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input) + decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_backward_output, dims=(1, )), decoder_output) + loss += decoder_backward_loss + decoder_c_loss + keep_avg.update_values({'avg_decoder_b_loss': decoder_backward_loss.item(), 'avg_decoder_c_loss': decoder_c_loss.item()}) - if c.bidirectional_decoder: - epoch_stats['loss_decoder_backward'] = keep_avg['avg_decoder_b_loss'] - align_b_img = alignments_backward[idx].data.cpu().numpy() - eval_figures['alignment_backward'] = plot_alignment(align_b_img) - tb_logger.tb_eval_stats(global_step, epoch_stats) - tb_logger.tb_eval_figures(global_step, eval_figures) + step_time = time.time() - start_time + epoch_time += step_time + + # compute alignment score + align_score = alignment_diagonal_score(alignments) + keep_avg.update_value('avg_align_score', align_score) + + # aggregate losses from processes + if num_gpus > 1: + postnet_loss = reduce_tensor(postnet_loss.data, num_gpus) + decoder_loss = reduce_tensor(decoder_loss.data, num_gpus) + if c.stopnet: + stop_loss = reduce_tensor(stop_loss.data, num_gpus) + + keep_avg.update_values({ + 'avg_postnet_loss': + float(postnet_loss.item()), + 'avg_decoder_loss': + float(decoder_loss.item()), + 'avg_stop_loss': + float(stop_loss.item()), + }) + + if num_iter % c.print_step == 0: + print( + " | > TotalLoss: {:.5f} PostnetLoss: {:.5f} - {:.5f} DecoderLoss:{:.5f} - {:.5f} " + "StopLoss: {:.5f} - {:.5f} AlignScore: {:.4f} : {:.4f}" + .format(loss.item(), postnet_loss.item(), + keep_avg['avg_postnet_loss'], + decoder_loss.item(), + keep_avg['avg_decoder_loss'], stop_loss.item(), + keep_avg['avg_stop_loss'], align_score, + keep_avg['avg_align_score']), + flush=True) + + if args.rank == 0: + # Diagnostic visualizations + idx = np.random.randint(mel_input.shape[0]) + const_spec = postnet_output[idx].data.cpu().numpy() + gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [ + "Tacotron", "TacotronGST" + ] else mel_input[idx].data.cpu().numpy() + align_img = alignments[idx].data.cpu().numpy() + + eval_figures = { + "prediction": plot_spectrogram(const_spec, ap), + "ground_truth": plot_spectrogram(gt_spec, ap), + "alignment": plot_alignment(align_img) + } + + # Sample audio + if c.model in ["Tacotron", "TacotronGST"]: + eval_audio = ap.inv_spectrogram(const_spec.T) + else: + eval_audio = ap.inv_mel_spectrogram(const_spec.T) + tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, + c.audio["sample_rate"]) + + # Plot Validation Stats + epoch_stats = { + "loss_postnet": keep_avg['avg_postnet_loss'], + "loss_decoder": keep_avg['avg_decoder_loss'], + "stop_loss": keep_avg['avg_stop_loss'], + "alignment_score": keep_avg['avg_align_score'] + } + + if c.bidirectional_decoder: + epoch_stats['loss_decoder_backward'] = keep_avg['avg_decoder_b_loss'] + align_b_img = alignments_backward[idx].data.cpu().numpy() + eval_figures['alignment_backward'] = plot_alignment(align_b_img) + tb_logger.tb_eval_stats(global_step, epoch_stats) + tb_logger.tb_eval_figures(global_step, eval_figures) if args.rank == 0 and epoch > c.test_delay_epochs: if c.test_sentences_file is None: