diff --git a/TTS/vocoder/layers/losses.py b/TTS/vocoder/layers/losses.py index e705b1e0..1107b3c5 100644 --- a/TTS/vocoder/layers/losses.py +++ b/TTS/vocoder/layers/losses.py @@ -4,13 +4,15 @@ from torch import nn from torch.nn import functional as F -class TorchSTFT(): +class TorchSTFT(nn.Module): def __init__(self, n_fft, hop_length, win_length, window='hann_window'): """ Torch based STFT operation """ + super(TorchSTFT, self).__init__() self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length - self.window = getattr(torch, window)(win_length) + self.window = nn.Parameter(getattr(torch, window)(win_length), + requires_grad=False) def __call__(self, x): # B x D x T x 2 @@ -22,7 +24,8 @@ class TorchSTFT(): center=True, pad_mode="reflect", # compatible with audio.py normalized=False, - onesided=True) + onesided=True, + return_complex=False) M = o[:, :, :, 0] P = o[:, :, :, 1] return torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))