some commeting for Generator loss and check if the argument is defines in the config file

pull/422/head
Eren Gölge 2021-04-06 11:02:16 +02:00
parent ff07c5f5e3
commit de3a04f104
1 changed files with 6 additions and 0 deletions

View File

@ -331,6 +331,12 @@ class GeneratorLoss(nn.Module):
return_dict['G_l1_spec_loss'] = l1_spec_loss
gen_loss = gen_loss + self.l1_spec_loss_weight * l1_spec_loss
# L1 Spec loss
if self.use_l1_spec_loss:
l1_spec_loss = self.l1_spec_loss(y_hat, y)
return_dict['G_l1_spec_loss'] = l1_spec_loss
gen_loss += self.l1_spec_loss_weight * l1_spec_loss
# subband STFT Loss
if self.use_subband_stft_loss:
subband_stft_loss_mg, subband_stft_loss_sc = self.subband_stft_loss(y_hat_sub, y_sub)