mirror of https://github.com/coqui-ai/TTS.git
Fix VITS loss bug
Fake and real features were given in the wrong args order to the loss functionpull/1324/head
parent
4b96bfe925
commit
1a43e05460
|
@ -553,7 +553,6 @@ class VitsGeneratorLoss(nn.Module):
|
|||
rl = rl.float().detach()
|
||||
gl = gl.float()
|
||||
loss += torch.mean(torch.abs(rl - gl))
|
||||
|
||||
return loss * 2
|
||||
|
||||
@staticmethod
|
||||
|
@ -629,9 +628,16 @@ class VitsGeneratorLoss(nn.Module):
|
|||
mel_hat = self.stft(waveform_hat)
|
||||
|
||||
# compute losses
|
||||
loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
|
||||
loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha
|
||||
loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
|
||||
loss_kl = self.kl_loss(
|
||||
z_p=z_p,
|
||||
logs_q=logs_q,
|
||||
m_p=m_p,
|
||||
logs_p=logs_p,
|
||||
z_mask=z_mask.unsqueeze(1)) * self.kl_loss_alpha
|
||||
loss_feat = self.feature_loss(
|
||||
feats_real=feats_disc_real,
|
||||
feats_generated=feats_disc_fake) * self.feat_loss_alpha
|
||||
loss_gen = self.generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha
|
||||
loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
|
||||
loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
|
||||
loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
|
||||
|
@ -675,7 +681,7 @@ class VitsDiscriminatorLoss(nn.Module):
|
|||
def forward(self, scores_disc_real, scores_disc_fake):
|
||||
loss = 0.0
|
||||
return_dict = {}
|
||||
loss_disc, _, _ = self.discriminator_loss(scores_disc_real, scores_disc_fake)
|
||||
loss_disc, _, _ = self.discriminator_loss(scores_real=scores_disc_real, scores_fake=scores_disc_fake)
|
||||
return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
|
||||
loss = loss + return_dict["loss_disc"]
|
||||
return_dict["loss"] = loss
|
||||
|
|
Loading…
Reference in New Issue