diff --git a/train.py b/train.py index 54ec06ef..2c2fd3c6 100644 --- a/train.py +++ b/train.py @@ -119,10 +119,12 @@ def train(model, criterion, data_loader, optimizer, epoch): # loss computation mel_loss = criterion(mel_output, mel_spec, mel_lengths) - linear_loss = 0.5 * criterion(linear_output, linear_spec, mel_lengths) \ - + 0.5 * criterion(linear_output[:, :, :n_priority_freq], - linear_spec[:, :, :n_priority_freq], - mel_lengths) + linear_loss = criterion(linear_output, linear_spec, mel_lengths) + if c.priority_freq: + linear_loss = 0.5 * linear_loss + + + 0.5 * criterion(linear_output[:, :, :n_priority_freq], + linear_spec[:, :, :n_priority_freq], + mel_lengths) loss = mel_loss + linear_loss if c.mk > 0.0: attention_loss = criterion(alignments, M, mel_lengths) @@ -244,10 +246,12 @@ def evaluate(model, criterion, data_loader, current_step): # loss computation mel_loss = criterion(mel_output, mel_spec, mel_lengths) - linear_loss = 0.5 * criterion(linear_output, linear_spec, mel_lengths) \ - + 0.5 * criterion(linear_output[:, :, :n_priority_freq], - linear_spec[:, :, :n_priority_freq], - mel_lengths) + linear_loss = criterion(linear_output, linear_spec, mel_lengths) + if c.priority_freq: + linear_loss = 0.5 * linear_loss + + + 0.5 * criterion(linear_output[:, :, :n_priority_freq], + linear_spec[:, :, :n_priority_freq], + mel_lengths) loss = mel_loss + linear_loss step_time = time.time() - start_time