mirror of https://github.com/coqui-ai/TTS.git
speedy speehc losses
parent
d62cac7252
commit
dc4a16d62e
|
@ -240,6 +240,23 @@ class GuidedAttentionLoss(torch.nn.Module):
|
|||
return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2)
|
||||
|
||||
|
||||
class Huber(nn.Module):
|
||||
def forward(self, x, y, length=None):
|
||||
"""
|
||||
Shapes:
|
||||
x: B x T
|
||||
y: B x T
|
||||
length: B
|
||||
"""
|
||||
mask = sequence_mask(sequence_length=length, max_len=y.size(1)).float()
|
||||
return torch.nn.functional.smooth_l1_loss(
|
||||
x * mask, y * mask, reduction='sum') / mask.sum()
|
||||
|
||||
|
||||
########################
|
||||
# MODEL LOSS LAYERS
|
||||
########################
|
||||
|
||||
class TacotronLoss(torch.nn.Module):
|
||||
"""Collection of Tacotron set-up based on provided config."""
|
||||
def __init__(self, c, stopnet_pos_weight=10, ga_sigma=0.4):
|
||||
|
@ -403,8 +420,27 @@ class GlowTTSLoss(torch.nn.Module):
|
|||
return_dict['log_mle'] = log_mle
|
||||
return_dict['loss_dur'] = loss_dur
|
||||
|
||||
# check if any loss is NaN
|
||||
# check if any loss is NaN
|
||||
for key, loss in return_dict.items():
|
||||
if torch.isnan(loss):
|
||||
raise RuntimeError(f" [!] NaN loss with {key}.")
|
||||
return return_dict
|
||||
|
||||
|
||||
class SpeedySpeechLoss(nn.Module):
|
||||
def __init__(self, c):
|
||||
super().__init__()
|
||||
self.l1 = L1LossMasked(False)
|
||||
self.ssim = SSIMLoss()
|
||||
self.huber = Huber()
|
||||
|
||||
self.ssim_alpha = c.ssim_alpha
|
||||
self.huber_alpha = c.huber_alpha
|
||||
self.l1_alpha = c.l1_alpha
|
||||
|
||||
def forward(self, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens):
|
||||
l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
|
||||
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
|
||||
huber_loss = self.huber(dur_output, dur_target, input_lens)
|
||||
loss = l1_loss + ssim_loss + huber_loss
|
||||
return {'loss': loss, 'loss_l1': l1_loss, 'loss_ssim': ssim_loss, 'loss_dur': huber_loss}
|
||||
|
|
Loading…
Reference in New Issue