From d700845b109cc140ce4d50645b94d4eb19919a3f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Mon, 21 Jun 2021 16:50:37 +0200
Subject: [PATCH] Move `TorchSTFT` to `utils.audio`

---
 TTS/utils/audio.py           | 77 ++++++++++++++++++++++++++++++++++
 TTS/vocoder/layers/losses.py | 80 +-----------------------------------
 2 files changed, 79 insertions(+), 78 deletions(-)

diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py
index 222b4c74..e1913e98 100644
--- a/TTS/utils/audio.py
+++ b/TTS/utils/audio.py
@@ -3,12 +3,89 @@ import numpy as np
 import scipy.io.wavfile
 import scipy.signal
 import soundfile as sf
+import torch
+from torch import nn
 
 from TTS.tts.utils.data import StandardScaler
 
 # import pyworld as pw
 
 
+class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
+    """TODO: Merge this with audio.py"""
+
+    def __init__(
+        self,
+        n_fft,
+        hop_length,
+        win_length,
+        pad_wav=False,
+        window="hann_window",
+        sample_rate=None,
+        mel_fmin=0,
+        mel_fmax=None,
+        n_mels=80,
+        use_mel=False,
+    ):
+        super().__init__()
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.pad_wav = pad_wav
+        self.sample_rate = sample_rate
+        self.mel_fmin = mel_fmin
+        self.mel_fmax = mel_fmax
+        self.n_mels = n_mels
+        self.use_mel = use_mel
+        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
+        self.mel_basis = None
+        if use_mel:
+            self._build_mel_basis()
+
+    def __call__(self, x):
+        """Compute spectrogram frames by torch based stft.
+
+        Args:
+            x (Tensor): input waveform
+
+        Returns:
+            Tensor: spectrogram frames.
+
+        Shapes:
+            x: [B x T] or [B x 1 x T]
+        """
+        if x.ndim == 2:
+            x = x.unsqueeze(1)
+        if self.pad_wav:
+            padding = int((self.n_fft - self.hop_length) / 2)
+            x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
+        # B x D x T x 2
+        o = torch.stft(
+            x.squeeze(1),
+            self.n_fft,
+            self.hop_length,
+            self.win_length,
+            self.window,
+            center=True,
+            pad_mode="reflect",  # compatible with audio.py
+            normalized=False,
+            onesided=True,
+            return_complex=False,
+        )
+        M = o[:, :, :, 0]
+        P = o[:, :, :, 1]
+        S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
+        if self.use_mel:
+            S = torch.matmul(self.mel_basis.to(x), S)
+        return S
+
+    def _build_mel_basis(self):
+        mel_basis = librosa.filters.mel(
+            self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
+        )
+        self.mel_basis = torch.from_numpy(mel_basis).float()
+
+
 # pylint: disable=too-many-public-methods
 class AudioProcessor(object):
     """Audio Processor for TTS used by all the data pipelines.
diff --git a/TTS/vocoder/layers/losses.py b/TTS/vocoder/layers/losses.py
index 9acdeea1..848e292b 100644
--- a/TTS/vocoder/layers/losses.py
+++ b/TTS/vocoder/layers/losses.py
@@ -1,88 +1,12 @@
 from typing import Dict, Union
 
-import librosa
 import torch
 from torch import nn
 from torch.nn import functional as F
 
+from TTS.utils.audio import TorchSTFT
 from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
 
-
-class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
-    """TODO: Merge this with audio.py"""
-
-    def __init__(
-        self,
-        n_fft,
-        hop_length,
-        win_length,
-        pad_wav=False,
-        window="hann_window",
-        sample_rate=None,
-        mel_fmin=0,
-        mel_fmax=None,
-        n_mels=80,
-        use_mel=False,
-    ):
-        super().__init__()
-        self.n_fft = n_fft
-        self.hop_length = hop_length
-        self.win_length = win_length
-        self.pad_wav = pad_wav
-        self.sample_rate = sample_rate
-        self.mel_fmin = mel_fmin
-        self.mel_fmax = mel_fmax
-        self.n_mels = n_mels
-        self.use_mel = use_mel
-        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
-        self.mel_basis = None
-        if use_mel:
-            self._build_mel_basis()
-
-    def __call__(self, x):
-        """Compute spectrogram frames by torch based stft.
-
-        Args:
-            x (Tensor): input waveform
-
-        Returns:
-            Tensor: spectrogram frames.
-
-        Shapes:
-            x: [B x T] or [B x 1 x T]
-        """
-        if x.ndim == 2:
-            x = x.unsqueeze(1)
-        if self.pad_wav:
-            padding = int((self.n_fft - self.hop_length) / 2)
-            x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
-        # B x D x T x 2
-        o = torch.stft(
-            x.squeeze(1),
-            self.n_fft,
-            self.hop_length,
-            self.win_length,
-            self.window,
-            center=True,
-            pad_mode="reflect",  # compatible with audio.py
-            normalized=False,
-            onesided=True,
-            return_complex=False,
-        )
-        M = o[:, :, :, 0]
-        P = o[:, :, :, 1]
-        S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
-        if self.use_mel:
-            S = torch.matmul(self.mel_basis.to(x), S)
-        return S
-
-    def _build_mel_basis(self):
-        mel_basis = librosa.filters.mel(
-            self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
-        )
-        self.mel_basis = torch.from_numpy(mel_basis).float()
-
-
 #################################
 # GENERATOR LOSSES
 #################################
@@ -275,7 +199,7 @@ def _apply_D_loss(scores_fake, scores_real, loss_func):
             loss += total_loss
             real_loss += real_loss
             fake_loss += fake_loss
-        # normalize loss values with number of scales
+        # normalize loss values with number of scales (discriminators)
         loss /= len(scores_fake)
         real_loss /= len(scores_real)
         fake_loss /= len(scores_fake)