Fix VITS loss bug

Fake and real features were given in the wrong args order to
the loss function
pull/1324/head
Eren Gölge 2022-02-05 20:33:36 +01:00
parent 4b96bfe925
commit 1a43e05460
1 changed files with 11 additions and 5 deletions

View File

@ -553,7 +553,6 @@ class VitsGeneratorLoss(nn.Module):
rl = rl.float().detach()
gl = gl.float()
loss += torch.mean(torch.abs(rl - gl))
return loss * 2
@staticmethod
@ -629,9 +628,16 @@ class VitsGeneratorLoss(nn.Module):
mel_hat = self.stft(waveform_hat)
# compute losses
loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha
loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
loss_kl = self.kl_loss(
z_p=z_p,
logs_q=logs_q,
m_p=m_p,
logs_p=logs_p,
z_mask=z_mask.unsqueeze(1)) * self.kl_loss_alpha
loss_feat = self.feature_loss(
feats_real=feats_disc_real,
feats_generated=feats_disc_fake) * self.feat_loss_alpha
loss_gen = self.generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha
loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
@ -675,7 +681,7 @@ class VitsDiscriminatorLoss(nn.Module):
def forward(self, scores_disc_real, scores_disc_fake):
loss = 0.0
return_dict = {}
loss_disc, _, _ = self.discriminator_loss(scores_disc_real, scores_disc_fake)
loss_disc, _, _ = self.discriminator_loss(scores_real=scores_disc_real, scores_fake=scores_disc_fake)
return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
loss = loss + return_dict["loss_disc"]
return_dict["loss"] = loss