From a6f73a18cb14ad598bdb57581524b334784426ff Mon Sep 17 00:00:00 2001
From: Eren Gölge
Date: Tue, 12 Jul 2022 14:11:34 +0200
Subject: [PATCH] Fix BCELoss addressing #1192

---
 TTS/tts/configs/tacotron_config.py |  4 ++--
 TTS/tts/layers/losses.py           | 16 ++++++----------
 tests/tts_tests/test_losses.py     | 39 ++++++++++++++++++++++++++++++++++++++-
 3 files changed, 46 insertions(+), 13 deletions(-)

diff --git a/TTS/tts/configs/tacotron_config.py b/TTS/tts/configs/tacotron_config.py
index e25609ff..31f6ae9b 100644
--- a/TTS/tts/configs/tacotron_config.py
+++ b/TTS/tts/configs/tacotron_config.py
@@ -53,7 +53,7 @@ class TacotronConfig(BaseTTSConfig):
             enable /disable the Stopnet that predicts the end of the decoder sequence. Defaults to True.
         stopnet_pos_weight (float):
             Weight that is applied to over-weight positive instances in the Stopnet loss. Use larger values with
-            datasets with longer sentences. Defaults to 10.
+            datasets with longer sentences. Defaults to 0.2.
         max_decoder_steps (int):
             Max number of steps allowed for the decoder. Defaults to 50.
         encoder_in_features (int):
@@ -161,7 +161,7 @@ class TacotronConfig(BaseTTSConfig):
     prenet_dropout_at_inference: bool = False
     stopnet: bool = True
     separate_stopnet: bool = True
-    stopnet_pos_weight: float = 10.0
+    stopnet_pos_weight: float = 0.2
     max_decoder_steps: int = 500
     encoder_in_features: int = 256
     decoder_in_features: int = 256
diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index 325791c7..50c61d67 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -147,9 +147,6 @@ class AttentionEntropyLoss(nn.Module):
         """
         Forces attention to be more decisive by penalizing
         soft attention weights
-
-        TODO: arguments
-        TODO: unit_test
         """
         entropy = torch.distributions.Categorical(probs=align).entropy()
         loss = (entropy / np.log(align.shape[1])).mean()
@@ -157,9 +154,9 @@ class AttentionEntropyLoss(nn.Module):
 
 
 class BCELossMasked(nn.Module):
-    def __init__(self, pos_weight):
+    def __init__(self, pos_weight: float = None):
         super().__init__()
-        self.pos_weight = pos_weight
+        self.pos_weight = torch.tensor([pos_weight]) if pos_weight is not None else None
 
     def forward(self, x, target, length):
         """
@@ -179,16 +176,15 @@ class BCELossMasked(nn.Module):
         Returns:
             loss: An average loss value in range [0, 1] masked by the length.
         """
-        # mask: (batch, max_len, 1)
         target.requires_grad = False
         if length is not None:
-            mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float()
-            x = x * mask
-            target = target * mask
+            # mask: (batch, max_len)
+            mask = sequence_mask(sequence_length=length, max_len=target.size(1))
             num_items = mask.sum()
+            loss = functional.binary_cross_entropy_with_logits(x.masked_select(mask), target.masked_select(mask), pos_weight=self.pos_weight, reduction="sum")
         else:
+            loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum")
             num_items = torch.numel(x)
-        loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum")
         loss = loss / num_items
         return loss
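Note on the masked_select change: zeroing padded positions (the old code) does not remove them from the summed loss. Under BCE-with-logits, a (logit 0, target 0) pair still contributes log(2) ≈ 0.693 nats, so every padded frame inflated the numerator while num_items counted only the valid frames. A minimal sketch of the difference, using toy shapes and an inline stand-in for sequence_mask rather than anything from this patch:

    import torch
    from torch.nn import functional

    logits = torch.zeros(1, 6)  # raw stopnet outputs: batch of 1, 6 decoder frames
    target = torch.tensor([[0.0, 0.0, 0.0, 1.0, 0.0, 0.0]])  # only the first 4 frames are valid
    mask = torch.arange(6).unsqueeze(0) < torch.tensor([[4]])  # inline stand-in for sequence_mask

    # Old behaviour: padded (logit=0, target=0) pairs stay in the sum, each adding log(2).
    old = functional.binary_cross_entropy_with_logits(logits * mask.float(), target * mask.float(), reduction="sum") / mask.sum()

    # New behaviour: padded positions are dropped before the loss is computed.
    new = functional.binary_cross_entropy_with_logits(logits.masked_select(mask), target.masked_select(mask), reduction="sum") / mask.sum()

    print(old.item(), new.item())  # ≈ 1.0397 vs ≈ 0.6931; the gap is 2 * log(2) / 4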
""" - # mask: (batch, max_len, 1) target.requires_grad = False if length is not None: - mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float() - x = x * mask - target = target * mask + # mask: (batch, max_len, 1) + mask = sequence_mask(sequence_length=length, max_len=target.size(1)) num_items = mask.sum() + loss = functional.binary_cross_entropy_with_logits(x.masked_select(mask), target.masked_select(mask), pos_weight=self.pos_weight, reduction="sum") else: + loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum") num_items = torch.numel(x) - loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum") loss = loss / num_items return loss diff --git a/tests/tts_tests/test_losses.py b/tests/tts_tests/test_losses.py index 42627f0d..e4652408 100644 --- a/tests/tts_tests/test_losses.py +++ b/tests/tts_tests/test_losses.py @@ -1,8 +1,9 @@ import unittest import torch as T +from torch.nn import functional from TTS.tts.utils.helpers import sequence_mask -from TTS.tts.layers.losses import L1LossMasked, SSIMLoss, MSELossMasked +from TTS.tts.layers.losses import L1LossMasked, SSIMLoss, MSELossMasked, BCELossMasked class L1LossMaskedTests(unittest.TestCase): @@ -200,3 +201,39 @@ class SSIMLossTests(unittest.TestCase): mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) output = layer(dummy_input + mask, dummy_target, dummy_length) assert output.item() == 0, "0 vs {}".format(output.item()) + + +class BCELossTest(unittest.TestCase): + def test_in_out(self): # pylint: disable=no-self-use + layer = BCELossMasked(pos_weight=5.0) + + length = T.tensor([95]) + mask = sequence_mask(length, 100) + pos_weight = T.tensor([5.0]) + target = 1. - sequence_mask(length - 1, 100).float() # [0, 0, .... 1, 1] where the first 1 is the last mel frame + true_x = target * 200 - 100 # creates logits of [-100, -100, ... 100, 100] corresponding to target + zero_x = T.zeros(target.shape) - 100. # simulate logits if it never stops decoding + early_x = -200. * sequence_mask(length - 3, 100).float() + 100. # simulate logits on early stopping + late_x = -200. * sequence_mask(length + 1, 100).float() + 100. # simulate logits on late stopping + + loss = layer(true_x, target, length) + self.assertEqual(loss.item(), 0.0) + + loss = layer(early_x, target, length) + self.assertAlmostEqual(loss.item(), 2.1053, places=4) + + loss = layer(late_x, target, length) + self.assertAlmostEqual(loss.item(), 5.2632, places=4) + + loss = layer(zero_x, target, length) + self.assertAlmostEqual(loss.item(), 5.2632, places=4) + + # pos_weight should be < 1 to penalize early stopping + layer = BCELossMasked(pos_weight=0.2) + loss = layer(true_x, target, length) + self.assertEqual(loss.item(), 0.0) + + # when pos_weight < 1 overweight the early stopping loss + loss_early = layer(early_x, target, length) + loss_late = layer(late_x, target, length) + self.assertGreater(loss_early.item(), loss_late.item())