From dc4a16d62e4a935d1a9a0556d417efeff9e468fa Mon Sep 17 00:00:00 2001 From: erogol Date: Mon, 28 Dec 2020 13:51:50 +0100 Subject: [PATCH] speedy speehc losses --- TTS/tts/layers/losses.py | 38 +++++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 4fa752cf..8f0d33c8 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -240,6 +240,23 @@ class GuidedAttentionLoss(torch.nn.Module): return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2) +class Huber(nn.Module): + def forward(self, x, y, length=None): + """ + Shapes: + x: B x T + y: B x T + length: B + """ + mask = sequence_mask(sequence_length=length, max_len=y.size(1)).float() + return torch.nn.functional.smooth_l1_loss( + x * mask, y * mask, reduction='sum') / mask.sum() + + +######################## +# MODEL LOSS LAYERS +######################## + class TacotronLoss(torch.nn.Module): """Collection of Tacotron set-up based on provided config.""" def __init__(self, c, stopnet_pos_weight=10, ga_sigma=0.4): @@ -403,8 +420,27 @@ class GlowTTSLoss(torch.nn.Module): return_dict['log_mle'] = log_mle return_dict['loss_dur'] = loss_dur - # check if any loss is NaN + # check if any loss is NaN for key, loss in return_dict.items(): if torch.isnan(loss): raise RuntimeError(f" [!] NaN loss with {key}.") return return_dict + + +class SpeedySpeechLoss(nn.Module): + def __init__(self, c): + super().__init__() + self.l1 = L1LossMasked(False) + self.ssim = SSIMLoss() + self.huber = Huber() + + self.ssim_alpha = c.ssim_alpha + self.huber_alpha = c.huber_alpha + self.l1_alpha = c.l1_alpha + + def forward(self, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens): + l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens) + ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) + huber_loss = self.huber(dur_output, dur_target, input_lens) + loss = l1_loss + ssim_loss + huber_loss + return {'loss': loss, 'loss_l1': l1_loss, 'loss_ssim': ssim_loss, 'loss_dur': huber_loss}