mirror of https://github.com/coqui-ai/TTS.git
Merge branch 'dev' of https://github.com/mozilla/TTS into dev
commit
5ee73c2bae
|
@ -4,13 +4,15 @@ from torch import nn
|
|||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class TorchSTFT():
|
||||
class TorchSTFT(nn.Module):
|
||||
def __init__(self, n_fft, hop_length, win_length, window='hann_window'):
|
||||
""" Torch based STFT operation """
|
||||
super(TorchSTFT, self).__init__()
|
||||
self.n_fft = n_fft
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
self.window = getattr(torch, window)(win_length)
|
||||
self.window = nn.Parameter(getattr(torch, window)(win_length),
|
||||
requires_grad=False)
|
||||
|
||||
def __call__(self, x):
|
||||
# B x D x T x 2
|
||||
|
@ -22,7 +24,8 @@ class TorchSTFT():
|
|||
center=True,
|
||||
pad_mode="reflect", # compatible with audio.py
|
||||
normalized=False,
|
||||
onesided=True)
|
||||
onesided=True,
|
||||
return_complex=False)
|
||||
M = o[:, :, :, 0]
|
||||
P = o[:, :, :, 1]
|
||||
return torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
|
||||
|
|
Loading…
Reference in New Issue