mirror of https://github.com/coqui-ai/TTS.git
update align_tts_loss for trainer
parent fc9a0fb8ce
commit 9203b863d9
@@ -462,13 +462,12 @@ class MDNLoss(nn.Module):
 class AlignTTSLoss(nn.Module):
     """Modified AlignTTS Loss.
-    Computes following losses
+    Computes
         - L1 and SSIM losses from output spectrograms.
         - Huber loss for duration predictor.
         - MDNLoss for Mixture of Density Network.
 
-    All the losses are aggregated by a weighted sum with the loss alphas.
-    Alphas can be scheduled based on number of steps.
+    All loss values are aggregated by a weighted sum of the alpha values.
 
     Args:
         c (dict): TTS model configuration.
@@ -487,9 +486,9 @@ class AlignTTSLoss(nn.Module):
         self.mdn_alpha = c.mdn_alpha
 
     def forward(
-        self, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens, step, phase
+        self, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens, phase
     ):
-        ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(step)
+        # ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(step)
         spec_loss, ssim_loss, dur_loss, mdn_loss = 0, 0, 0, 0
         if phase == 0:
             mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens)
@@ -507,36 +506,5 @@ class AlignTTSLoss(nn.Module):
             spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
             ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
             dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens)
-        loss = spec_loss_alpha * spec_loss + ssim_alpha * ssim_loss + dur_loss_alpha * dur_loss + mdn_alpha * mdn_loss
+        loss = self.spec_loss_alpha * spec_loss + self.ssim_alpha * ssim_loss + self.dur_loss_alpha * dur_loss + self.mdn_alpha * mdn_loss
         return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss}
-
-    @staticmethod
-    def _set_alpha(step, alpha_settings):
-        """Set the loss alpha wrt number of steps.
-        Return the corresponding value if no schedule is set.
-
-        Example:
-            Setting a alpha schedule.
-            if ```alpha_settings``` is ```[[0, 1], [10000, 0.1]]``` then ```return_alpha == 1``` until 10k steps, then set to 0.1.
-            if ```alpha_settings``` is a constant value then ```return_alpha``` is set to that constant.
-
-        Args:
-            step (int): number of training steps.
-            alpha_settings (int or list): constant alpha value or a list defining the schedule as explained above.
-        """
-        return_alpha = None
-        if isinstance(alpha_settings, list):
-            for key, alpha in alpha_settings:
-                if key < step:
-                    return_alpha = alpha
-        elif isinstance(alpha_settings, (float, int)):
-            return_alpha = alpha_settings
-        return return_alpha
-
-    def set_alphas(self, step):
-        """Set the alpha values for all the loss functions"""
-        ssim_alpha = self._set_alpha(step, self.ssim_alpha)
-        dur_loss_alpha = self._set_alpha(step, self.dur_loss_alpha)
-        spec_loss_alpha = self._set_alpha(step, self.spec_loss_alpha)
-        mdn_alpha = self._set_alpha(step, self.mdn_alpha)
-        return ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha
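For context on what this change does in practice: the alphas are no longer resolved from a step-based schedule inside `forward`; each one is used directly as the constant value the loss reads from the model config. Below is a minimal standalone sketch (not part of the commit) of the resulting aggregation, with hypothetical placeholder numbers; the real values come from `c.ssim_alpha`, `c.dur_loss_alpha`, `c.spec_loss_alpha`, and `c.mdn_alpha`, and the partial losses from the phase branches in `forward`.

```python
# Standalone sketch of the new aggregation. All numbers are hypothetical
# placeholders, not values from the repository.
config = {"ssim_alpha": 1.0, "spec_loss_alpha": 1.0, "dur_loss_alpha": 1.0, "mdn_alpha": 1.0}

# Partial losses as they would come out of the phase branches in `forward`;
# terms the current phase does not compute stay at 0.
spec_loss, ssim_loss, dur_loss, mdn_loss = 0.21, 0.08, 0.03, 1.7  # dummy numbers

# Weighted sum with constant alphas, mirroring the new
# `loss = self.spec_loss_alpha * spec_loss + ...` line above.
loss = (
    config["spec_loss_alpha"] * spec_loss
    + config["ssim_alpha"] * ssim_loss
    + config["dur_loss_alpha"] * dur_loss
    + config["mdn_alpha"] * mdn_loss
)

print({"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss,
       "loss_dur": dur_loss, "mdn_loss": mdn_loss})
```

With the removed `_set_alpha`/`set_alphas` helpers, a schedule entry such as `[[0, 1], [10000, 0.1]]` would have switched an alpha from 1 to 0.1 after 10k steps; after this commit each alpha is expected to be a single constant, which is why `step` is dropped from the `forward` signature.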