diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 827da751..0c94f91f 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -553,7 +553,6 @@ class VitsGeneratorLoss(nn.Module): rl = rl.float().detach() gl = gl.float() loss += torch.mean(torch.abs(rl - gl)) - return loss * 2 @staticmethod @@ -629,9 +628,16 @@ class VitsGeneratorLoss(nn.Module): mel_hat = self.stft(waveform_hat) # compute losses - loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha - loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha - loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha + loss_kl = self.kl_loss( + z_p=z_p, + logs_q=logs_q, + m_p=m_p, + logs_p=logs_p, + z_mask=z_mask.unsqueeze(1)) * self.kl_loss_alpha + loss_feat = self.feature_loss( + feats_real=feats_disc_real, + feats_generated=feats_disc_fake) * self.feat_loss_alpha + loss_gen = self.generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration @@ -675,7 +681,7 @@ class VitsDiscriminatorLoss(nn.Module): def forward(self, scores_disc_real, scores_disc_fake): loss = 0.0 return_dict = {} - loss_disc, _, _ = self.discriminator_loss(scores_disc_real, scores_disc_fake) + loss_disc, _, _ = self.discriminator_loss(scores_real=scores_disc_real, scores_fake=scores_disc_fake) return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha loss = loss + return_dict["loss_disc"] return_dict["loss"] = loss