mirror of https://github.com/coqui-ai/TTS.git
add multiband melgan config and rename loss weights for vocoder
parent
1a49ce8956
commit
604453bcb7
|
@ -53,10 +53,10 @@
|
|||
"use_hinge_gan_loss": false,
|
||||
"use_feat_match_loss": false, // use only with melgan discriminators
|
||||
|
||||
"stft_loss_alpha": 1,
|
||||
"mse_gan_loss_alpha": 1,
|
||||
"hinge_gan_loss_alpha": 1,
|
||||
"feat_match_loss_alpha": 10.0,
|
||||
"stft_loss_weight": 1,
|
||||
"mse_gan_loss_weight": 2.5,
|
||||
"hinge_gan_loss_weight": 2.5,
|
||||
"feat_match_loss_weight": 10.0,
|
||||
|
||||
"stft_loss_params": {
|
||||
"n_ffts": [1024, 2048, 512],
|
||||
|
|
|
@ -149,10 +149,10 @@ class GeneratorLoss(nn.Module):
|
|||
self.use_hinge_gan_loss = C.use_hinge_gan_loss
|
||||
self.use_feat_match_loss = C.use_feat_match_loss
|
||||
|
||||
self.stft_loss_alpha = C.stft_loss_alpha
|
||||
self.mse_gan_loss_alpha = C.mse_gan_loss_alpha
|
||||
self.hinge_gan_loss_alpha = C.hinge_gan_loss_alpha
|
||||
self.feat_match_loss_alpha = C.feat_match_loss_alpha
|
||||
self.stft_loss_weight = C.stft_loss_weight
|
||||
self.mse_gan_loss_weight = C.mse_gan_loss_weight
|
||||
self.hinge_gan_loss_weight = C.hinge_gan_loss_weight
|
||||
self.feat_match_loss_weight = C.feat_match_loss_weight
|
||||
|
||||
if C.use_stft_loss:
|
||||
self.stft_loss = MultiScaleSTFTLoss(**C.stft_loss_params)
|
||||
|
@ -172,7 +172,7 @@ class GeneratorLoss(nn.Module):
|
|||
stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat.squeeze(1), y.squeeze(1))
|
||||
return_dict['G_stft_loss_mg'] = stft_loss_mg
|
||||
return_dict['G_stft_loss_sc'] = stft_loss_sc
|
||||
loss += self.stft_loss_alpha * (stft_loss_mg + stft_loss_sc)
|
||||
loss += self.stft_loss_weight * (stft_loss_mg + stft_loss_sc)
|
||||
|
||||
# Fake Losses
|
||||
if self.use_mse_gan_loss and scores_fake is not None:
|
||||
|
@ -185,7 +185,7 @@ class GeneratorLoss(nn.Module):
|
|||
fake_loss = self.mse_loss(scores_fake)
|
||||
mse_fake_loss = fake_loss
|
||||
return_dict['G_mse_fake_loss'] = mse_fake_loss
|
||||
loss += self.mse_gan_loss_alpha * mse_fake_loss
|
||||
loss += self.mse_gan_loss_weight * mse_fake_loss
|
||||
|
||||
if self.use_hinge_gan_loss and not scores_fake is not None:
|
||||
hinge_fake_loss = 0
|
||||
|
@ -197,13 +197,13 @@ class GeneratorLoss(nn.Module):
|
|||
fake_loss = self.hinge_loss(scores_fake)
|
||||
hinge_fake_loss = fake_loss
|
||||
return_dict['G_hinge_fake_loss'] = hinge_fake_loss
|
||||
loss += self.hinge_gan_loss_alpha * hinge_fake_loss
|
||||
loss += self.hinge_gan_loss_weight * hinge_fake_loss
|
||||
|
||||
# Feature Matching Loss
|
||||
if self.use_feat_match_loss and not feats_fake:
|
||||
feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
|
||||
return_dict['G_feat_match_loss'] = feat_match_loss
|
||||
loss += self.feat_match_loss_alpha * feat_match_loss
|
||||
loss += self.feat_match_loss_weight * feat_match_loss
|
||||
return_dict['G_loss'] = loss
|
||||
return return_dict
|
||||
|
||||
|
@ -217,8 +217,8 @@ class DiscriminatorLoss(nn.Module):
|
|||
self.use_mse_gan_loss = C.use_mse_gan_loss
|
||||
self.use_hinge_gan_loss = C.use_hinge_gan_loss
|
||||
|
||||
self.mse_gan_loss_alpha = C.mse_gan_loss_alpha
|
||||
self.hinge_gan_loss_alpha = C.hinge_gan_loss_alpha
|
||||
self.mse_gan_loss_weight = C.mse_gan_loss_weight
|
||||
self.hinge_gan_loss_weight = C.hinge_gan_loss_weight
|
||||
|
||||
if C.use_mse_gan_loss:
|
||||
self.mse_loss = MSEDLoss()
|
||||
|
@ -247,7 +247,7 @@ class DiscriminatorLoss(nn.Module):
|
|||
return_dict['D_mse_gan_loss'] = mse_gan_loss
|
||||
return_dict['D_mse_gan_real_loss'] = mse_gan_real_loss
|
||||
return_dict['D_mse_gan_fake_loss'] = mse_gan_fake_loss
|
||||
loss += self.mse_gan_loss_alpha * mse_gan_loss
|
||||
loss += self.mse_gan_loss_weight * mse_gan_loss
|
||||
|
||||
if self.use_hinge_gan_loss:
|
||||
hinge_gan_loss = 0
|
||||
|
@ -267,7 +267,7 @@ class DiscriminatorLoss(nn.Module):
|
|||
return_dict['D_hinge_gan_loss'] = hinge_gan_loss
|
||||
return_dict['D_hinge_gan_real_loss'] = hinge_gan_real_loss
|
||||
return_dict['D_hinge_gan_fake_loss'] = hinge_gan_fake_loss
|
||||
loss += self.hinge_gan_loss_alpha * hinge_gan_loss
|
||||
loss += self.hinge_gan_loss_weight * hinge_gan_loss
|
||||
|
||||
return_dict['D_loss'] = loss
|
||||
return return_dict
|
Loading…
Reference in New Issue