pull/10/head
Eren Golge 2018-04-26 05:46:24 -07:00
parent d143616ada
commit 961d240534
1 changed files with 12 additions and 8 deletions

View File

@ -119,7 +119,9 @@ def train(model, criterion, data_loader, optimizer, epoch):
# loss computation
mel_loss = criterion(mel_output, mel_spec, mel_lengths)
linear_loss = 0.5 * criterion(linear_output, linear_spec, mel_lengths) \
linear_loss = criterion(linear_output, linear_spec, mel_lengths)
if c.priority_freq:
linear_loss = 0.5 * linear_loss +
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
linear_spec[:, :, :n_priority_freq],
mel_lengths)
@ -244,7 +246,9 @@ def evaluate(model, criterion, data_loader, current_step):
# loss computation
mel_loss = criterion(mel_output, mel_spec, mel_lengths)
linear_loss = 0.5 * criterion(linear_output, linear_spec, mel_lengths) \
linear_loss = criterion(linear_output, linear_spec, mel_lengths)
if c.priority_freq:
linear_loss = 0.5 * linear_loss +
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
linear_spec[:, :, :n_priority_freq],
mel_lengths)