mirror of https://github.com/coqui-ai/TTS.git
Move `TorchSTFT` to `utils.audio`
parent 5b89cb4fec
commit d700845b10
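This commit moves the `TorchSTFT` helper out of the vocoder losses module and into `TTS.utils.audio`, so downstream code imports one shared implementation instead of carrying a duplicated class.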
@@ -3,12 +3,89 @@ import numpy as np
 import scipy.io.wavfile
 import scipy.signal
 import soundfile as sf
+import torch
+from torch import nn

 from TTS.tts.utils.data import StandardScaler

 # import pyworld as pw


+class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
+    """TODO: Merge this with audio.py"""
+
+    def __init__(
+        self,
+        n_fft,
+        hop_length,
+        win_length,
+        pad_wav=False,
+        window="hann_window",
+        sample_rate=None,
+        mel_fmin=0,
+        mel_fmax=None,
+        n_mels=80,
+        use_mel=False,
+    ):
+        super().__init__()
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.pad_wav = pad_wav
+        self.sample_rate = sample_rate
+        self.mel_fmin = mel_fmin
+        self.mel_fmax = mel_fmax
+        self.n_mels = n_mels
+        self.use_mel = use_mel
+        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
+        self.mel_basis = None
+        if use_mel:
+            self._build_mel_basis()
+
+    def __call__(self, x):
+        """Compute spectrogram frames by torch based stft.
+
+        Args:
+            x (Tensor): input waveform
+
+        Returns:
+            Tensor: spectrogram frames.
+
+        Shapes:
+            x: [B x T] or [B x 1 x T]
+        """
+        if x.ndim == 2:
+            x = x.unsqueeze(1)
+        if self.pad_wav:
+            padding = int((self.n_fft - self.hop_length) / 2)
+            x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
+        # B x D x T x 2
+        o = torch.stft(
+            x.squeeze(1),
+            self.n_fft,
+            self.hop_length,
+            self.win_length,
+            self.window,
+            center=True,
+            pad_mode="reflect",  # compatible with audio.py
+            normalized=False,
+            onesided=True,
+            return_complex=False,
+        )
+        M = o[:, :, :, 0]
+        P = o[:, :, :, 1]
+        S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
+        if self.use_mel:
+            S = torch.matmul(self.mel_basis.to(x), S)
+        return S
+
+    def _build_mel_basis(self):
+        mel_basis = librosa.filters.mel(
+            self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
+        )
+        self.mel_basis = torch.from_numpy(mel_basis).float()
+
+
 # pylint: disable=too-many-public-methods
 class AudioProcessor(object):
     """Audio Processor for TTS used by all the data pipelines.
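For reference, a minimal usage sketch of the relocated class (not part of the commit; parameter values and shapes are illustrative, and it assumes a PyTorch version that still accepts `return_complex=False`):

import torch
from TTS.utils.audio import TorchSTFT

# Linear-magnitude STFT; set use_mel=True (plus sample_rate) to get mel frames instead.
stft = TorchSTFT(n_fft=1024, hop_length=256, win_length=1024)
wav = torch.randn(2, 22050)  # [B x T], e.g. two one-second clips at 22.05 kHz
spec = stft(wav)             # [B x (n_fft // 2 + 1) x frames]
print(spec.shape)            # torch.Size([2, 513, 87])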
@@ -1,88 +1,12 @@
 from typing import Dict, Union

-import librosa
 import torch
 from torch import nn
 from torch.nn import functional as F

+from TTS.utils.audio import TorchSTFT
 from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss

-
-class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
-    """TODO: Merge this with audio.py"""
-
-    def __init__(
-        self,
-        n_fft,
-        hop_length,
-        win_length,
-        pad_wav=False,
-        window="hann_window",
-        sample_rate=None,
-        mel_fmin=0,
-        mel_fmax=None,
-        n_mels=80,
-        use_mel=False,
-    ):
-        super().__init__()
-        self.n_fft = n_fft
-        self.hop_length = hop_length
-        self.win_length = win_length
-        self.pad_wav = pad_wav
-        self.sample_rate = sample_rate
-        self.mel_fmin = mel_fmin
-        self.mel_fmax = mel_fmax
-        self.n_mels = n_mels
-        self.use_mel = use_mel
-        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
-        self.mel_basis = None
-        if use_mel:
-            self._build_mel_basis()
-
-    def __call__(self, x):
-        """Compute spectrogram frames by torch based stft.
-
-        Args:
-            x (Tensor): input waveform
-
-        Returns:
-            Tensor: spectrogram frames.
-
-        Shapes:
-            x: [B x T] or [B x 1 x T]
-        """
-        if x.ndim == 2:
-            x = x.unsqueeze(1)
-        if self.pad_wav:
-            padding = int((self.n_fft - self.hop_length) / 2)
-            x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
-        # B x D x T x 2
-        o = torch.stft(
-            x.squeeze(1),
-            self.n_fft,
-            self.hop_length,
-            self.win_length,
-            self.window,
-            center=True,
-            pad_mode="reflect",  # compatible with audio.py
-            normalized=False,
-            onesided=True,
-            return_complex=False,
-        )
-        M = o[:, :, :, 0]
-        P = o[:, :, :, 1]
-        S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
-        if self.use_mel:
-            S = torch.matmul(self.mel_basis.to(x), S)
-        return S
-
-    def _build_mel_basis(self):
-        mel_basis = librosa.filters.mel(
-            self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
-        )
-        self.mel_basis = torch.from_numpy(mel_basis).float()
-
-
 #################################
 # GENERATOR LOSSES
 #################################
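With the class now in `TTS.utils.audio`, loss modules import the shared implementation rather than keeping a local copy. A hypothetical call site (the class name and loss choice below are illustrative, not the repository's):

import torch
from torch import nn
from torch.nn import functional as F

from TTS.utils.audio import TorchSTFT  # shared implementation after this commit

class SpectralL1Loss(nn.Module):
    # Compare magnitude spectrograms of a predicted and a reference waveform.
    def __init__(self, n_fft=1024, hop_length=256, win_length=1024):
        super().__init__()
        self.stft = TorchSTFT(n_fft, hop_length, win_length)

    def forward(self, y_hat, y):
        return F.l1_loss(self.stft(y_hat), self.stft(y))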
@@ -275,7 +199,7 @@ def _apply_D_loss(scores_fake, scores_real, loss_func):
             loss += total_loss
             real_loss += real_loss
             fake_loss += fake_loss
-        # normalize loss values with number of scales
+        # normalize loss values with number of scales (discriminators)
         loss /= len(scores_fake)
         real_loss /= len(scores_real)
         fake_loss /= len(scores_fake)
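The comment change spells out that the normalization runs over the number of discriminator scales. A standalone sketch of the same multi-scale averaging pattern (assuming a loss_func that returns a (total, real, fake) tuple per scale; the per-scale results are accumulated in separate local variables):

def apply_d_loss(scores_fake, scores_real, loss_func):
    # Accumulate per-scale losses, then average over the number of scales (discriminators).
    loss, real_loss, fake_loss = 0.0, 0.0, 0.0
    for score_fake, score_real in zip(scores_fake, scores_real):
        total_l, real_l, fake_l = loss_func(score_fake=score_fake, score_real=score_real)
        loss += total_l
        real_loss += real_l
        fake_loss += fake_l
    n_scales = len(scores_fake)
    return loss / n_scales, real_loss / n_scales, fake_loss / n_scales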