mirror of https://github.com/coqui-ai/TTS.git
Implement `ForwardTTSLoss`
parent 3abc3a1d32
commit 570d5971be
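All five hunks below are in `TTS/tts/layers/losses.py`, where these loss classes live.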
@@ -236,10 +236,40 @@ class Huber(nn.Module):
             y: B x T
             length: B
         """
-        mask = sequence_mask(sequence_length=length, max_len=y.size(1)).float()
+        mask = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2).float()
         return torch.nn.functional.smooth_l1_loss(x * mask, y * mask, reduction="sum") / mask.sum()
 
 
+class ForwardSumLoss(nn.Module):
+    def __init__(self, blank_logprob=-1):
+        super().__init__()
+        self.log_softmax = torch.nn.LogSoftmax(dim=3)
+        self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True)
+        self.blank_logprob = blank_logprob
+
+    def forward(self, attn_logprob, in_lens, out_lens):
+        key_lens = in_lens
+        query_lens = out_lens
+        attn_logprob_padded = torch.nn.functional.pad(input=attn_logprob, pad=(1, 0), value=self.blank_logprob)
+
+        total_loss = 0.0
+        for bid in range(attn_logprob.shape[0]):
+            target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0)
+            curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[: query_lens[bid], :, : key_lens[bid] + 1]
+
+            curr_logprob = self.log_softmax(curr_logprob[None])[0]
+            loss = self.ctc_loss(
+                curr_logprob,
+                target_seq,
+                input_lengths=query_lens[bid : bid + 1],
+                target_lengths=key_lens[bid : bid + 1],
+            )
+            total_loss = total_loss + loss
+
+        total_loss = total_loss / attn_logprob.shape[0]
+        return total_loss
+
+
 ########################
 # MODEL LOSS LAYERS
 ########################
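Two changes here: `Huber` gains `unsqueeze(2)` so the `[B, T]` mask broadcasts correctly when `x` and `y` carry a trailing channel dimension, and `ForwardSumLoss` moves up beside the other loss primitives. `ForwardSumLoss` scores an attention map with CTC: the key (text) positions `1..key_len` act as target labels, a blank column is padded in at index 0 with `blank_logprob`, and `torch.nn.CTCLoss` sums the probability of all monotonic alignments. A minimal smoke test, assuming `attn_logprob` is shaped `[batch, 1, max_decoder_steps, max_encoder_steps]`; all sizes below are made up:

import torch

B, T_de, T_en = 2, 50, 12                     # batch, mel frames, text tokens
attn_logprob = torch.randn(B, 1, T_de, T_en)  # unnormalized attention scores
in_lens = torch.tensor([12, 9])               # text (key) lengths
out_lens = torch.tensor([50, 37])             # mel (query) lengths

criterion = ForwardSumLoss()
loss = criterion(attn_logprob, in_lens, out_lens)
print(loss.item())                            # scalar, averaged over the batch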
@@ -413,25 +443,6 @@ class GlowTTSLoss(torch.nn.Module):
         return return_dict
 
 
-class SpeedySpeechLoss(nn.Module):
-    def __init__(self, c):
-        super().__init__()
-        self.l1 = L1LossMasked(False)
-        self.ssim = SSIMLoss()
-        self.huber = Huber()
-
-        self.ssim_alpha = c.ssim_alpha
-        self.huber_alpha = c.huber_alpha
-        self.l1_alpha = c.l1_alpha
-
-    def forward(self, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens):
-        l1_loss = self.l1(decoder_output, decoder_target, decoder_output_lens)
-        ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
-        huber_loss = self.huber(dur_output, dur_target, input_lens)
-        loss = self.l1_alpha * l1_loss + self.ssim_alpha * ssim_loss + self.huber_alpha * huber_loss
-        return {"loss": loss, "loss_l1": l1_loss, "loss_ssim": ssim_loss, "loss_dur": huber_loss}
-
-
 def mse_loss_custom(x, y):
     """MSE loss using the torch back-end without reduction.
     It uses less VRAM than the raw code"""
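`SpeedySpeechLoss` (masked L1 on the spectrogram, SSIM, and Huber on durations) is deleted outright; the same recipe can be expressed through the generic `ForwardTTSLoss` introduced in the next hunk via `spec_loss_type="l1"` and `duration_loss_type="huber"` (see the config sketch after that hunk).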
@@ -660,51 +671,41 @@ class VitsDiscriminatorLoss(nn.Module):
         return return_dict
 
 
-class ForwardSumLoss(nn.Module):
-    def __init__(self, blank_logprob=-1):
-        super().__init__()
-        self.log_softmax = torch.nn.LogSoftmax(dim=3)
-        self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True)
-        self.blank_logprob = blank_logprob
-
-    def forward(self, attn_logprob, in_lens, out_lens):
-        key_lens = in_lens
-        query_lens = out_lens
-        attn_logprob_padded = torch.nn.functional.pad(input=attn_logprob, pad=(1, 0), value=self.blank_logprob)
-
-        total_loss = 0.0
-        for bid in range(attn_logprob.shape[0]):
-            target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0)
-            curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[: query_lens[bid], :, : key_lens[bid] + 1]
-
-            curr_logprob = self.log_softmax(curr_logprob[None])[0]
-            loss = self.ctc_loss(
-                curr_logprob,
-                target_seq,
-                input_lengths=query_lens[bid : bid + 1],
-                target_lengths=key_lens[bid : bid + 1],
-            )
-            total_loss = total_loss + loss
-
-        total_loss = total_loss / attn_logprob.shape[0]
-        return total_loss
-
-
-class FastPitchLoss(nn.Module):
+class ForwardTTSLoss(nn.Module):
+    """Generic configurable ForwardTTS loss."""
+
     def __init__(self, c):
         super().__init__()
-        self.spec_loss = MSELossMasked(False)
-        self.ssim = SSIMLoss()
-        self.dur_loss = MSELossMasked(False)
-        self.pitch_loss = MSELossMasked(False)
+        if c.spec_loss_type == "mse":
+            self.spec_loss = MSELossMasked(False)
+        elif c.spec_loss_type == "l1":
+            self.spec_loss = L1LossMasked(False)
+        else:
+            raise ValueError(" [!] Unknown spec_loss_type {}".format(c.spec_loss_type))
+
+        if c.duration_loss_type == "mse":
+            self.dur_loss = MSELossMasked(False)
+        elif c.duration_loss_type == "l1":
+            self.dur_loss = L1LossMasked(False)
+        elif c.duration_loss_type == "huber":
+            self.dur_loss = Huber()
+        else:
+            raise ValueError(" [!] Unknown duration_loss_type {}".format(c.duration_loss_type))
+
         if c.model_args.use_aligner:
             self.aligner_loss = ForwardSumLoss()
+            self.aligner_loss_alpha = c.aligner_loss_alpha
+
+        if c.model_args.use_pitch:
+            self.pitch_loss = MSELossMasked(False)
+            self.pitch_loss_alpha = c.pitch_loss_alpha
+
+        if c.use_ssim_loss:
+            self.ssim = SSIMLoss() if c.use_ssim_loss else None
+            self.ssim_loss_alpha = c.ssim_loss_alpha
 
         self.spec_loss_alpha = c.spec_loss_alpha
-        self.ssim_loss_alpha = c.ssim_loss_alpha
         self.dur_loss_alpha = c.dur_loss_alpha
-        self.pitch_loss_alpha = c.pitch_loss_alpha
-        self.aligner_loss_alpha = c.aligner_loss_alpha
         self.binary_alignment_loss_alpha = c.binary_align_loss_alpha
 
     @staticmethod
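`ForwardTTSLoss` replaces `FastPitchLoss` and makes every term configurable: the spectrogram loss (`mse` or `l1`), the duration loss (`mse`, `l1`, or `huber`), and optional SSIM, pitch, and aligner terms whose submodules and `*_alpha` weights are only created when the config enables them (the old `ForwardSumLoss` copy lower in the file is dropped in favor of the one moved up in the first hunk). A hypothetical config stand-in showing how the removed `SpeedySpeechLoss` recipe maps onto the new class; a real setup would pass the model's Coqpit config, and the `SimpleNamespace` fields here are just the attributes `__init__` reads:

from types import SimpleNamespace

# illustrative stand-in for the model config consumed by ForwardTTSLoss
c = SimpleNamespace(
    spec_loss_type="l1",         # L1LossMasked, as SpeedySpeechLoss used
    duration_loss_type="huber",  # Huber duration loss, as SpeedySpeechLoss used
    use_ssim_loss=True,
    spec_loss_alpha=1.0,
    ssim_loss_alpha=1.0,
    dur_loss_alpha=1.0,
    binary_align_loss_alpha=0.0,
    model_args=SimpleNamespace(use_aligner=False, use_pitch=False),
)
criterion = ForwardTTSLoss(c)  # no pitch/aligner submodules are created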
@@ -731,7 +732,7 @@ class FastPitchLoss(nn.Module):
     ):
         loss = 0
         return_dict = {}
-        if self.ssim_loss_alpha > 0:
+        if hasattr(self, "ssim_loss") and self.ssim_loss_alpha > 0:
             ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
             loss = loss + self.ssim_loss_alpha * ssim_loss
             return_dict["loss_ssim"] = self.ssim_loss_alpha * ssim_loss
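Note an apparent inconsistency in this hunk: `__init__` above assigns `self.ssim` (never `self.ssim_loss`), so `hasattr(self, "ssim_loss")` is always false and the SSIM term is silently skipped even when `use_ssim_loss` is enabled.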
@@ -747,12 +748,12 @@ class FastPitchLoss(nn.Module):
             loss = loss + self.dur_loss_alpha * dur_loss
             return_dict["loss_dur"] = self.dur_loss_alpha * dur_loss
 
-        if self.pitch_loss_alpha > 0:
+        if hasattr(self, "pitch_loss") and self.pitch_loss_alpha > 0:
             pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens)
             loss = loss + self.pitch_loss_alpha * pitch_loss
             return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss
 
-        if self.aligner_loss_alpha > 0:
+        if hasattr(self, "aligner_loss") and self.aligner_loss_alpha > 0:
             aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens)
             loss = loss + self.aligner_loss_alpha * aligner_loss
             return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss
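The guards change from pure weight checks to existence checks because `ForwardTTSLoss.__init__` now creates `pitch_loss`, `aligner_loss`, and their `*_alpha` weights only when the config asks for them. A minimal illustration of the pattern (toy class, not TTS code):

import torch

class ToyLoss(torch.nn.Module):
    def __init__(self, use_pitch):
        super().__init__()
        if use_pitch:  # submodule and weight exist only when enabled
            self.pitch_loss = torch.nn.MSELoss()
            self.pitch_loss_alpha = 0.1

criterion = ToyLoss(use_pitch=False)
# `if criterion.pitch_loss_alpha > 0:` would raise AttributeError here;
# hasattr() short-circuits first, so the disabled term is skipped safely.
if hasattr(criterion, "pitch_loss") and criterion.pitch_loss_alpha > 0:
    print("pitch term enabled")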