update align_tts_loss for trainer

pull/506/head
Eren Gölge 2021-05-27 10:25:40 +02:00
parent fc9a0fb8ce
commit 9203b863d9
1 changed files with 5 additions and 37 deletions

View File

@ -462,13 +462,12 @@ class MDNLoss(nn.Module):
class AlignTTSLoss(nn.Module):
"""Modified AlignTTS Loss.
Computes following losses
Computes
- L1 and SSIM losses from output spectrograms.
- Huber loss for duration predictor.
- MDNLoss for the Mixture Density Network.
All the losses are aggregated by a weighted sum with the loss alphas.
Alphas can be scheduled based on number of steps.
All loss values are aggregated by a weighted sum of the alpha values.
Args:
c (dict): TTS model configuration.
@ -487,9 +486,9 @@ class AlignTTSLoss(nn.Module):
self.mdn_alpha = c.mdn_alpha
def forward(
self, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens, step, phase
self, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens, phase
):
ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(step)
# ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(step)
spec_loss, ssim_loss, dur_loss, mdn_loss = 0, 0, 0, 0
if phase == 0:
mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens)
@ -507,36 +506,5 @@ class AlignTTSLoss(nn.Module):
spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens)
loss = spec_loss_alpha * spec_loss + ssim_alpha * ssim_loss + dur_loss_alpha * dur_loss + mdn_alpha * mdn_loss
loss = self.spec_loss_alpha * spec_loss + self.ssim_alpha * ssim_loss + self.dur_loss_alpha * dur_loss + self.mdn_alpha * mdn_loss
return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss}
@staticmethod
def _set_alpha(step, alpha_settings):
    """Resolve the loss alpha (weight) to use at a given training step.

    ``alpha_settings`` is either a constant or a schedule given as a list of
    ``[start_step, alpha]`` pairs. For a schedule, the alpha of the last pair
    whose ``start_step`` is less than or equal to ``step`` is returned, so the
    pairs are expected to be listed in ascending ``start_step`` order.

    Example:
        Setting an alpha schedule.
        If ```alpha_settings``` is ```[[0, 1], [10000, 0.1]]``` the returned
        alpha is 1 for steps 0..9999 and 0.1 from step 10000 onwards.
        If ```alpha_settings``` is a constant value, that constant is
        returned for every step.

    Args:
        step (int): number of training steps.
        alpha_settings (int, float or list): constant alpha value or a list
            defining the schedule as explained above.

    Returns:
        The alpha for ``step``, or ``None`` when no schedule entry has started
        yet or ``alpha_settings`` is neither a number nor a list.
    """
    # A constant setting applies to every step.
    if isinstance(alpha_settings, (float, int)):
        return alpha_settings

    return_alpha = None
    if isinstance(alpha_settings, list):
        for key, alpha in alpha_settings:
            # Inclusive comparison: an entry is active from its own start
            # step. A strict ``<`` would leave step 0 of a ``[0, x]`` schedule
            # without any alpha (None), which breaks the weighted loss sum.
            if key <= step:
                return_alpha = alpha
    return return_alpha
def set_alphas(self, step):
    """Resolve every loss alpha for the given training step.

    Each stored alpha setting is passed through ``self._set_alpha`` together
    with ``step``.

    Args:
        step (int): number of training steps.

    Returns:
        tuple: ``(ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha)``
        as resolved for ``step``.
    """
    resolved = []
    # Order matters: callers unpack the result positionally.
    for attr_name in ("ssim_alpha", "dur_loss_alpha", "spec_loss_alpha", "mdn_alpha"):
        setting = getattr(self, attr_name)
        resolved.append(self._set_alpha(step, setting))
    return tuple(resolved)