From 62158f5906e8151003450033e29a6c41eaaba078 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Wed, 25 Apr 2018 05:43:29 -0700
Subject: [PATCH] make attn guiding optional

---
 train.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/train.py b/train.py
index a3bb8c26..e798c778 100644
--- a/train.py
+++ b/train.py
@@ -91,9 +91,6 @@ def train(model, criterion, data_loader, optimizer, epoch):
             params_group['lr'] = current_lr
 
         optimizer.zero_grad()
-
-        # setup mk
-        mk = mk_decay(c.mk, c.epochs, epoch)
 
         # convert inputs to variables
         text_input_var = Variable(text_input)
@@ -109,9 +106,11 @@ def train(model, criterion, data_loader, optimizer, epoch):
             linear_spec_var = linear_spec_var.cuda()
 
         # create attention mask
-        N = text_input_var.shape[1]
-        T = mel_spec_var.shape[1] // c.r
-        M = create_attn_mask(N, T, 0.03)
+        if c.mk > 0.0:
+            N = text_input_var.shape[1]
+            T = mel_spec_var.shape[1] // c.r
+            M = create_attn_mask(N, T, 0.03)
+            mk = mk_decay(c.mk, c.epochs, epoch)
 
         # forward pass
         mel_output, linear_output, alignments =\
@@ -123,9 +122,10 @@ def train(model, criterion, data_loader, optimizer, epoch):
             + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                               linear_spec_var[:, :, :n_priority_freq],
                               mel_lengths_var)
-        attention_loss = criterion(alignments, M, mel_lengths_var)
-        print(mk)
-        loss = mel_loss + linear_loss + mk * attention_loss
+        loss = mel_loss + linear_loss
+        if c.mk > 0.0:
+            attention_loss = criterion(alignments, M, mel_lengths_var)
+            loss += mk * attention_loss
 
         # backpass and check the grad norm
         loss.backward()
@@ -155,7 +155,6 @@ def train(model, criterion, data_loader, optimizer, epoch):
         tb.add_scalar('TrainIterLoss/LinearLoss', linear_loss.data[0],
                       current_step)
         tb.add_scalar('TrainIterLoss/MelLoss', mel_loss.data[0], current_step)
-        tb.add_scalar('TrainIterLoss/AttnLoss', attention_loss.data[0], current_step)
         tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'],
                       current_step)
         tb.add_scalar('Params/GradNorm', grad_norm, current_step)
@@ -196,14 +195,15 @@ def train(model, criterion, data_loader, optimizer, epoch):
 
     avg_linear_loss /= (num_iter + 1)
     avg_mel_loss /= (num_iter + 1)
-    avg_attn_loss /= (num_iter + 1)
     avg_total_loss = avg_mel_loss + avg_linear_loss
 
     # Plot Training Epoch Stats
     tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
    tb.add_scalar('TrainEpochLoss/LinearLoss', avg_linear_loss, current_step)
     tb.add_scalar('TrainEpochLoss/MelLoss', avg_mel_loss, current_step)
-    tb.add_scalar('TrainEpochLoss/AttnLoss', avg_attn_loss, current_step)
+    if c.mk > 0:
+        avg_attn_loss /= (num_iter + 1)
+        tb.add_scalar('TrainEpochLoss/AttnLoss', avg_attn_loss, current_step)
     tb.add_scalar('Time/EpochTime', epoch_time, epoch)
     epoch_time = 0
 
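
Note for context: the patch only shows the call sites of create_attn_mask() and
mk_decay(); their bodies live elsewhere in the repo. Below is a minimal sketch of
what such helpers could plausibly look like -- this is an assumption for
illustration, not the repository's actual implementation. The mask follows the
guided-attention idea of Tachibana et al. (2017), the decay is a plain linear
ramp, and only the names and signatures are taken from the call sites above.

    import numpy as np

    def create_attn_mask(N, T, g=0.03):
        # Hypothetical guided-attention target (assumption, not the repo's
        # code): values are near 0 on the diagonal of the (T x N) alignment
        # matrix and approach 1 away from it, so a masked L1/L2 criterion
        # over (alignments, M) penalizes attention that drifts off the
        # monotonic character-to-frame diagonal. g controls the band width.
        n = np.arange(N) / max(N - 1, 1)   # normalized character positions
        t = np.arange(T) / max(T - 1, 1)   # normalized decoder steps
        return 1.0 - np.exp(-((t[:, None] - n[None, :]) ** 2) / (2 * g ** 2))

    def mk_decay(init_mk, max_epochs, epoch):
        # Hypothetical schedule (assumption): fade the attention-loss weight
        # linearly from init_mk at epoch 0 to 0 at the final epoch, relaxing
        # the guiding constraint as the model learns to align on its own.
        return init_mk * (max_epochs - epoch) / max_epochs

Under this reading, criterion(alignments, M, mel_lengths_var) penalizes
off-diagonal attention with weight mk, and mk shrinks toward 0 over training.
Setting c.mk to 0.0 skips the mask, the decay, the extra loss term, and the
AttnLoss logging entirely, which is exactly what the new `if c.mk > 0.0:`
guards in this patch implement.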