Make lint

pull/1726/head
Eren G??lge 2022-07-12 14:58:26 +02:00
parent c614f21982
commit 48a4f3647f
3 changed files with 5 additions and 12 deletions

View File

@ -133,11 +133,11 @@ class SSIMLoss(torch.nn.Module):
if ssim_loss.item() > 1.0:
print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 1.0")
ssim_loss == 1.0
ssim_loss = torch.tensor([1.0])
if ssim_loss.item() < 0.0:
print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0")
ssim_loss == 0.0
ssim_loss = torch.tensor([0.0])
return ssim_loss

View File

@ -20,8 +20,7 @@ def _reduce(x: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
return x.mean(dim=0)
elif reduction == "sum":
return x.sum(dim=0)
else:
raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")
raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")
def _validate_input(
@ -140,7 +139,7 @@ def ssim(
kernel = gaussian_filter(kernel_size, kernel_sigma).repeat(x.size(1), 1, 1, 1).to(y)
_compute_ssim_per_channel = _ssim_per_channel_complex if x.dim() == 5 else _ssim_per_channel
ssim_map, cs_map = _compute_ssim_per_channel(x=x, y=y, kernel=kernel, data_range=data_range, k1=k1, k2=k2)
ssim_map, cs_map = _compute_ssim_per_channel(x=x, y=y, kernel=kernel, k1=k1, k2=k2)
ssim_val = ssim_map.mean(1)
cs = cs_map.mean(1)
@ -268,7 +267,6 @@ def _ssim_per_channel(
x: torch.Tensor,
y: torch.Tensor,
kernel: torch.Tensor,
data_range: Union[float, int] = 1.0,
k1: float = 0.01,
k2: float = 0.03,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
@ -278,7 +276,6 @@ def _ssim_per_channel(
x: An input tensor. Shape :math:`(N, C, H, W)`.
y: A target tensor. Shape :math:`(N, C, H, W)`.
kernel: 2D Gaussian kernel.
data_range: Maximum value range of images (usually 1.0 or 255).
k1: Algorithm parameter, K1 (small constant, see [1]).
k2: Algorithm parameter, K2 (small constant, see [1]).
Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
@ -321,7 +318,6 @@ def _ssim_per_channel_complex(
x: torch.Tensor,
y: torch.Tensor,
kernel: torch.Tensor,
data_range: Union[float, int] = 1.0,
k1: float = 0.01,
k2: float = 0.03,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
@ -331,7 +327,6 @@ def _ssim_per_channel_complex(
x: An input tensor. Shape :math:`(N, C, H, W, 2)`.
y: A target tensor. Shape :math:`(N, C, H, W, 2)`.
kernel: 2-D gauss kernel.
data_range: Maximum value range of images (usually 1.0 or 255).
k1: Algorithm parameter, K1 (small constant, see [1]).
k2: Algorithm parameter, K2 (small constant, see [1]).
Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.

View File

@ -1,7 +1,6 @@
import unittest
import torch as T
from torch.nn import functional
from TTS.tts.layers.losses import BCELossMasked, L1LossMasked, MSELossMasked, SSIMLoss
from TTS.tts.utils.helpers import sequence_mask
@ -208,8 +207,6 @@ class BCELossTest(unittest.TestCase):
layer = BCELossMasked(pos_weight=5.0)
length = T.tensor([95])
mask = sequence_mask(length, 100)
pos_weight = T.tensor([5.0])
target = (
1.0 - sequence_mask(length - 1, 100).float()
) # [0, 0, .... 1, 1] where the first 1 is the last mel frame
@ -236,6 +233,7 @@ class BCELossTest(unittest.TestCase):
self.assertEqual(loss.item(), 0.0)
# when pos_weight < 1 overweight the early stopping loss
loss_early = layer(early_x, target, length)
loss_late = layer(late_x, target, length)
self.assertGreater(loss_early.item(), loss_late.item())