add multiband melgan config and rename loss weights for vocoder

pull/1/head
erogol 2020-05-31 18:56:57 +02:00
parent 1a49ce8956
commit 604453bcb7
2 changed files with 16 additions and 16 deletions

View File

@ -53,10 +53,10 @@
"use_hinge_gan_loss": false,
"use_feat_match_loss": false, // use only with melgan discriminators
"stft_loss_alpha": 1,
"mse_gan_loss_alpha": 1,
"hinge_gan_loss_alpha": 1,
"feat_match_loss_alpha": 10.0,
"stft_loss_weight": 1,
"mse_gan_loss_weight": 2.5,
"hinge_gan_loss_weight": 2.5,
"feat_match_loss_weight": 10.0,
"stft_loss_params": {
"n_ffts": [1024, 2048, 512],

View File

@ -149,10 +149,10 @@ class GeneratorLoss(nn.Module):
self.use_hinge_gan_loss = C.use_hinge_gan_loss
self.use_feat_match_loss = C.use_feat_match_loss
self.stft_loss_alpha = C.stft_loss_alpha
self.mse_gan_loss_alpha = C.mse_gan_loss_alpha
self.hinge_gan_loss_alpha = C.hinge_gan_loss_alpha
self.feat_match_loss_alpha = C.feat_match_loss_alpha
self.stft_loss_weight = C.stft_loss_weight
self.mse_gan_loss_weight = C.mse_gan_loss_weight
self.hinge_gan_loss_weight = C.hinge_gan_loss_weight
self.feat_match_loss_weight = C.feat_match_loss_weight
if C.use_stft_loss:
self.stft_loss = MultiScaleSTFTLoss(**C.stft_loss_params)
@ -172,7 +172,7 @@ class GeneratorLoss(nn.Module):
stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat.squeeze(1), y.squeeze(1))
return_dict['G_stft_loss_mg'] = stft_loss_mg
return_dict['G_stft_loss_sc'] = stft_loss_sc
loss += self.stft_loss_alpha * (stft_loss_mg + stft_loss_sc)
loss += self.stft_loss_weight * (stft_loss_mg + stft_loss_sc)
# Fake Losses
if self.use_mse_gan_loss and scores_fake is not None:
@ -185,7 +185,7 @@ class GeneratorLoss(nn.Module):
fake_loss = self.mse_loss(scores_fake)
mse_fake_loss = fake_loss
return_dict['G_mse_fake_loss'] = mse_fake_loss
loss += self.mse_gan_loss_alpha * mse_fake_loss
loss += self.mse_gan_loss_weight * mse_fake_loss
if self.use_hinge_gan_loss and not scores_fake is not None:
hinge_fake_loss = 0
@ -197,13 +197,13 @@ class GeneratorLoss(nn.Module):
fake_loss = self.hinge_loss(scores_fake)
hinge_fake_loss = fake_loss
return_dict['G_hinge_fake_loss'] = hinge_fake_loss
loss += self.hinge_gan_loss_alpha * hinge_fake_loss
loss += self.hinge_gan_loss_weight * hinge_fake_loss
# Feature Matching Loss
if self.use_feat_match_loss and not feats_fake:
feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
return_dict['G_feat_match_loss'] = feat_match_loss
loss += self.feat_match_loss_alpha * feat_match_loss
loss += self.feat_match_loss_weight * feat_match_loss
return_dict['G_loss'] = loss
return return_dict
@ -217,8 +217,8 @@ class DiscriminatorLoss(nn.Module):
self.use_mse_gan_loss = C.use_mse_gan_loss
self.use_hinge_gan_loss = C.use_hinge_gan_loss
self.mse_gan_loss_alpha = C.mse_gan_loss_alpha
self.hinge_gan_loss_alpha = C.hinge_gan_loss_alpha
self.mse_gan_loss_weight = C.mse_gan_loss_weight
self.hinge_gan_loss_weight = C.hinge_gan_loss_weight
if C.use_mse_gan_loss:
self.mse_loss = MSEDLoss()
@ -247,7 +247,7 @@ class DiscriminatorLoss(nn.Module):
return_dict['D_mse_gan_loss'] = mse_gan_loss
return_dict['D_mse_gan_real_loss'] = mse_gan_real_loss
return_dict['D_mse_gan_fake_loss'] = mse_gan_fake_loss
loss += self.mse_gan_loss_alpha * mse_gan_loss
loss += self.mse_gan_loss_weight * mse_gan_loss
if self.use_hinge_gan_loss:
hinge_gan_loss = 0
@ -267,7 +267,7 @@ class DiscriminatorLoss(nn.Module):
return_dict['D_hinge_gan_loss'] = hinge_gan_loss
return_dict['D_hinge_gan_real_loss'] = hinge_gan_real_loss
return_dict['D_hinge_gan_fake_loss'] = hinge_gan_fake_loss
loss += self.hinge_gan_loss_alpha * hinge_gan_loss
loss += self.hinge_gan_loss_weight * hinge_gan_loss
return_dict['D_loss'] = loss
return return_dict