mirror of https://github.com/coqui-ai/TTS.git
Make lint
parent
c614f21982
commit
48a4f3647f
|
@ -133,11 +133,11 @@ class SSIMLoss(torch.nn.Module):
|
|||
|
||||
if ssim_loss.item() > 1.0:
|
||||
print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 1.0")
|
||||
ssim_loss == 1.0
|
||||
ssim_loss = torch.tensor([1.0])
|
||||
|
||||
if ssim_loss.item() < 0.0:
|
||||
print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0")
|
||||
ssim_loss == 0.0
|
||||
ssim_loss = torch.tensor([0.0])
|
||||
|
||||
return ssim_loss
|
||||
|
||||
|
|
|
@ -20,8 +20,7 @@ def _reduce(x: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
|
|||
return x.mean(dim=0)
|
||||
elif reduction == "sum":
|
||||
return x.sum(dim=0)
|
||||
else:
|
||||
raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")
|
||||
raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")
|
||||
|
||||
|
||||
def _validate_input(
|
||||
|
@ -140,7 +139,7 @@ def ssim(
|
|||
|
||||
kernel = gaussian_filter(kernel_size, kernel_sigma).repeat(x.size(1), 1, 1, 1).to(y)
|
||||
_compute_ssim_per_channel = _ssim_per_channel_complex if x.dim() == 5 else _ssim_per_channel
|
||||
ssim_map, cs_map = _compute_ssim_per_channel(x=x, y=y, kernel=kernel, data_range=data_range, k1=k1, k2=k2)
|
||||
ssim_map, cs_map = _compute_ssim_per_channel(x=x, y=y, kernel=kernel, k1=k1, k2=k2)
|
||||
ssim_val = ssim_map.mean(1)
|
||||
cs = cs_map.mean(1)
|
||||
|
||||
|
@ -268,7 +267,6 @@ def _ssim_per_channel(
|
|||
x: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
kernel: torch.Tensor,
|
||||
data_range: Union[float, int] = 1.0,
|
||||
k1: float = 0.01,
|
||||
k2: float = 0.03,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
|
@ -278,7 +276,6 @@ def _ssim_per_channel(
|
|||
x: An input tensor. Shape :math:`(N, C, H, W)`.
|
||||
y: A target tensor. Shape :math:`(N, C, H, W)`.
|
||||
kernel: 2D Gaussian kernel.
|
||||
data_range: Maximum value range of images (usually 1.0 or 255).
|
||||
k1: Algorithm parameter, K1 (small constant, see [1]).
|
||||
k2: Algorithm parameter, K2 (small constant, see [1]).
|
||||
Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
|
||||
|
@ -321,7 +318,6 @@ def _ssim_per_channel_complex(
|
|||
x: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
kernel: torch.Tensor,
|
||||
data_range: Union[float, int] = 1.0,
|
||||
k1: float = 0.01,
|
||||
k2: float = 0.03,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
|
@ -331,7 +327,6 @@ def _ssim_per_channel_complex(
|
|||
x: An input tensor. Shape :math:`(N, C, H, W, 2)`.
|
||||
y: A target tensor. Shape :math:`(N, C, H, W, 2)`.
|
||||
kernel: 2-D Gaussian kernel.
|
||||
data_range: Maximum value range of images (usually 1.0 or 255).
|
||||
k1: Algorithm parameter, K1 (small constant, see [1]).
|
||||
k2: Algorithm parameter, K2 (small constant, see [1]).
|
||||
Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import unittest
|
||||
|
||||
import torch as T
|
||||
from torch.nn import functional
|
||||
|
||||
from TTS.tts.layers.losses import BCELossMasked, L1LossMasked, MSELossMasked, SSIMLoss
|
||||
from TTS.tts.utils.helpers import sequence_mask
|
||||
|
@ -208,8 +207,6 @@ class BCELossTest(unittest.TestCase):
|
|||
layer = BCELossMasked(pos_weight=5.0)
|
||||
|
||||
length = T.tensor([95])
|
||||
mask = sequence_mask(length, 100)
|
||||
pos_weight = T.tensor([5.0])
|
||||
target = (
|
||||
1.0 - sequence_mask(length - 1, 100).float()
|
||||
) # [0, 0, .... 1, 1] where the first 1 is the last mel frame
|
||||
|
@ -236,6 +233,7 @@ class BCELossTest(unittest.TestCase):
|
|||
self.assertEqual(loss.item(), 0.0)
|
||||
|
||||
# when pos_weight < 1 overweight the early stopping loss
|
||||
|
||||
loss_early = layer(early_x, target, length)
|
||||
loss_late = layer(late_x, target, length)
|
||||
self.assertGreater(loss_early.item(), loss_late.item())
|
||||
|
|
Loading…
Reference in New Issue