diff --git a/train.py b/train.py
index a53015c0..f637b91b 100644
--- a/train.py
+++ b/train.py
@@ -106,6 +106,8 @@ def train(model, criterion, data_loader, optimizer, epoch):
 
         # create attention mask
         # TODO: vectorize
+        N = text_input_var.shape[1]
+        T = mel_spec_var.shape[1]
         M = np.zeros([N, T])
         for t in range(T):
             for n in range(N):
@@ -126,7 +128,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
                       + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                                         linear_spec_var[:, :, :n_priority_freq],
                                         mel_lengths_var)
-        attention_loss = criterion(M, alignments, mel_lengths_var)
+        attention_loss = criterion(Mg, alignments, mel_lengths_var)
         loss = mel_loss + linear_loss + 0.2 * attention_loss
 
         # backpass and check the grad norm
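
The TODO in the first hunk flags the nested mask loops for vectorization. Below is a minimal sketch of how the mask could be built with NumPy broadcasting instead. The loop body is not shown in this hunk, so the guided-attention form M[n, t] = 1 - exp(-(n/N - t/T)^2 / (2 g^2)) (Tachibana et al.) and the width g used here are assumptions, not the repo's actual formula; the function name is hypothetical.

    import numpy as np

    def guided_attention_mask(N, T, g=0.2):
        # ASSUMED guided-attention penalty; substitute the real
        # per-element formula from the loop body the diff elides.
        n = np.arange(N, dtype=np.float32).reshape(N, 1) / N  # text positions, shape (N, 1)
        t = np.arange(T, dtype=np.float32).reshape(1, T) / T  # mel frames, shape (1, T)
        # Broadcasting (N, 1) against (1, T) evaluates every (n, t) pair
        # at once, replacing the two Python loops in the hunk above.
        return 1.0 - np.exp(-((n - t) ** 2) / (2 * g ** 2))

Note that the second hunk passes Mg rather than M to the criterion; presumably Mg is the mask converted to a torch tensor somewhere in the lines between the two hunks, which this diff does not show.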