import numpy as np
import torch
from torch import nn
from torch.nn import functional

from TTS.utils.generic_utils import sequence_mask


class L1LossMasked(nn.Module):
    def __init__(self, seq_len_norm):
        super(L1LossMasked, self).__init__()
        self.seq_len_norm = seq_len_norm

    def forward(self, x, target, length):
        """
        Args:
            x: A Tensor of size (batch, max_len, dim) which contains the
                predicted values for each frame.
            target: A Tensor of size (batch, max_len, dim) which contains the
                ground-truth values for each corresponding frame.
            length: A LongTensor of size (batch,) which contains the length of
                each sequence in the batch.
        Returns:
            loss: An average loss value masked by the length.
        """
        # mask: (batch, max_len, 1)
        target.requires_grad = False
        mask = sequence_mask(
            sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
        if self.seq_len_norm:
            # Weight each sequence by its own length so that short and long
            # sequences contribute equally to the batch loss.
            norm_w = mask / mask.sum(dim=1, keepdim=True)
            out_weights = norm_w.div(target.shape[0] * target.shape[2])
            mask = mask.expand_as(x)
            loss = functional.l1_loss(
                x * mask, target * mask, reduction='none')
            loss = loss.mul(out_weights.to(loss.device)).sum()
        else:
            # Average the summed loss over the number of unmasked elements.
            mask = mask.expand_as(x)
            loss = functional.l1_loss(
                x * mask, target * mask, reduction='sum')
            loss = loss / mask.sum()
        return loss


class MSELossMasked(nn.Module):
    def __init__(self, seq_len_norm):
        super(MSELossMasked, self).__init__()
        self.seq_len_norm = seq_len_norm

    def forward(self, x, target, length):
        """
        Args:
            x: A Tensor of size (batch, max_len, dim) which contains the
                predicted values for each frame.
            target: A Tensor of size (batch, max_len, dim) which contains the
                ground-truth values for each corresponding frame.
            length: A LongTensor of size (batch,) which contains the length of
                each sequence in the batch.
        Returns:
            loss: An average loss value masked by the length.
        """
        # mask: (batch, max_len, 1)
        target.requires_grad = False
        mask = sequence_mask(
            sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
        if self.seq_len_norm:
            # Weight each sequence by its own length so that short and long
            # sequences contribute equally to the batch loss.
            norm_w = mask / mask.sum(dim=1, keepdim=True)
            out_weights = norm_w.div(target.shape[0] * target.shape[2])
            mask = mask.expand_as(x)
            loss = functional.mse_loss(
                x * mask, target * mask, reduction='none')
            loss = loss.mul(out_weights.to(loss.device)).sum()
        else:
            # Average the summed loss over the number of unmasked elements.
            mask = mask.expand_as(x)
            loss = functional.mse_loss(
                x * mask, target * mask, reduction='sum')
            loss = loss / mask.sum()
        return loss


class AttentionEntropyLoss(nn.Module):
    # pylint: disable=R0201
    def forward(self, align):
        """
        Forces attention to be more decisive by penalizing soft attention
        weights.

        Args:
            align: Attention alignment weights that sum to 1 over the last
                dimension.
        Returns:
            loss: Mean normalized entropy of the attention weights.
        TODO: unit_test
        """
        entropy = torch.distributions.Categorical(probs=align).entropy()
        loss = (entropy / np.log(align.shape[1])).mean()
        return loss
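

# Minimal usage sketch (not part of the original module): the shapes below are
# assumptions taken from the docstrings above — (batch, max_len, dim) predictions
# and targets with per-item lengths — and it runs only where the TTS package
# providing `sequence_mask` is importable.
if __name__ == "__main__":
    batch, max_len, dim = 4, 10, 80
    x = torch.rand(batch, max_len, dim)            # predicted frames
    target = torch.rand(batch, max_len, dim)       # ground-truth frames
    length = torch.tensor([10, 8, 6, 4])           # valid length per sequence

    l1_criterion = L1LossMasked(seq_len_norm=False)
    mse_criterion = MSELossMasked(seq_len_norm=True)
    print(l1_criterion(x, target, length).item())
    print(mse_criterion(x, target, length).item())

    # AttentionEntropyLoss expects attention weights that sum to 1 over the
    # last dimension; softmax over random logits is used here as a stand-in.
    align = torch.softmax(torch.rand(batch, max_len, 50), dim=-1)
    print(AttentionEntropyLoss()(align).item())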