From 3c2d5a9e03040e081732a5e917464ddd74049c43 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Thu, 16 Nov 2023 10:57:06 +0100 Subject: [PATCH] Remove duplicate AudioProcessor code and fix ExtractTTSpectrogram.ipynb (#3230) * chore: remove unused argument * refactor(audio.processor): remove duplicate stft+griffin_lim * chore(audio.processor): remove unused compute_stft_paddings Same function available in numpy_transforms * refactor(audio.processor): remove duplicate db_to_amp * refactor(audio.processor): remove duplicate amp_to_db * refactor(audio.processor): remove duplicate linear_to_mel * refactor(audio.processor): remove duplicate mel_to_linear * refactor(audio.processor): remove duplicate build_mel_basis * refactor(audio.processor): remove duplicate stft_parameters * refactor(audio.processor): use pre-/deemphasis from numpy_transforms * refactor(audio.processor): use rms_volume_norm from numpy_transforms * chore(audio.processor): remove duplicate assert Already checked in numpy_transforms.compute_f0 * refactor(audio.processor): use find_endpoint from numpy_transforms * refactor(audio.processor): use trim_silence from numpy_transforms * refactor(audio.processor): use volume_norm from numpy_transforms * refactor(audio.processor): use load_wav from numpy_transforms * fix(bin.extract_tts_spectrograms): set quantization bits * fix(ExtractTTSpectrogram.ipynb): adapt to current TTS code Fixes #2447, #2574 * refactor(audio.processor): remove duplicate quantization methods --- TTS/bin/extract_tts_spectrograms.py | 11 +- TTS/utils/audio/numpy_transforms.py | 1 - TTS/utils/audio/processor.py | 314 ++++++--------------- TTS/vocoder/datasets/preprocess.py | 7 +- TTS/vocoder/datasets/wavernn_dataset.py | 6 +- TTS/vocoder/models/wavernn.py | 3 +- notebooks/ExtractTTSpectrogram.ipynb | 182 +++++------- tests/vocoder_tests/test_vocoder_losses.py | 3 +- 8 files changed, 177 insertions(+), 350 deletions(-) diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index 9eadee07..c6048626 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -15,6 +15,7 @@ from TTS.tts.models import setup_model from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.numpy_transforms import quantize from TTS.utils.generic_utils import count_parameters use_cuda = torch.cuda.is_available() @@ -159,7 +160,7 @@ def inference( def extract_spectrograms( - data_loader, model, ap, output_path, quantized_wav=False, save_audio=False, debug=False, metada_name="metada.txt" + data_loader, model, ap, output_path, quantize_bits=0, save_audio=False, debug=False, metada_name="metada.txt" ): model.eval() export_metadata = [] @@ -196,8 +197,8 @@ def extract_spectrograms( _, wavq_path, mel_path, wav_gl_path, wav_path = set_filename(wav_file_path, output_path) # quantize and save wav - if quantized_wav: - wavq = ap.quantize(wav) + if quantize_bits > 0: + wavq = quantize(wav, quantize_bits) np.save(wavq_path, wavq) # save TTS mel @@ -263,7 +264,7 @@ def main(args): # pylint: disable=redefined-outer-name model, ap, args.output_path, - quantized_wav=args.quantized, + quantize_bits=args.quantize_bits, save_audio=args.save_audio, debug=args.debug, metada_name="metada.txt", @@ -277,7 +278,7 @@ if __name__ == "__main__": parser.add_argument("--output_path", type=str, help="Path to save mel specs", required=True) parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug") parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files") - parser.add_argument("--quantized", action="store_true", help="Save quantized audio files") + parser.add_argument("--quantize_bits", type=int, default=0, help="Save quantized audio files if non-zero") parser.add_argument("--eval", type=bool, help="compute eval.", default=True) args = parser.parse_args() diff --git a/TTS/utils/audio/numpy_transforms.py b/TTS/utils/audio/numpy_transforms.py index b701e767..af88569f 100644 --- a/TTS/utils/audio/numpy_transforms.py +++ b/TTS/utils/audio/numpy_transforms.py @@ -201,7 +201,6 @@ def stft( def istft( *, y: np.ndarray = None, - fft_size: int = None, hop_length: int = None, win_length: int = None, window: str = "hann", diff --git a/TTS/utils/audio/processor.py b/TTS/utils/audio/processor.py index 4ceb7da4..c53bad56 100644 --- a/TTS/utils/audio/processor.py +++ b/TTS/utils/audio/processor.py @@ -5,10 +5,26 @@ import librosa import numpy as np import scipy.io.wavfile import scipy.signal -import soundfile as sf from TTS.tts.utils.helpers import StandardScaler -from TTS.utils.audio.numpy_transforms import compute_f0 +from TTS.utils.audio.numpy_transforms import ( + amp_to_db, + build_mel_basis, + compute_f0, + db_to_amp, + deemphasis, + find_endpoint, + griffin_lim, + load_wav, + mel_to_spec, + millisec_to_length, + preemphasis, + rms_volume_norm, + spec_to_mel, + stft, + trim_silence, + volume_norm, +) # pylint: disable=too-many-public-methods @@ -200,7 +216,9 @@ class AudioProcessor(object): # setup stft parameters if hop_length is None: # compute stft parameters from given time values - self.hop_length, self.win_length = self._stft_parameters() + self.win_length, self.hop_length = millisec_to_length( + frame_length_ms=self.frame_length_ms, frame_shift_ms=self.frame_shift_ms, sample_rate=self.sample_rate + ) else: # use stft parameters from config file self.hop_length = hop_length @@ -215,8 +233,13 @@ class AudioProcessor(object): for key, value in members.items(): print(" | > {}:{}".format(key, value)) # create spectrogram utils - self.mel_basis = self._build_mel_basis() - self.inv_mel_basis = np.linalg.pinv(self._build_mel_basis()) + self.mel_basis = build_mel_basis( + sample_rate=self.sample_rate, + fft_size=self.fft_size, + num_mels=self.num_mels, + mel_fmax=self.mel_fmax, + mel_fmin=self.mel_fmin, + ) # setup scaler if stats_path and signal_norm: mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path) @@ -232,35 +255,6 @@ class AudioProcessor(object): return AudioProcessor(verbose=verbose, **config.audio) return AudioProcessor(verbose=verbose, **config) - ### setting up the parameters ### - def _build_mel_basis( - self, - ) -> np.ndarray: - """Build melspectrogram basis. - - Returns: - np.ndarray: melspectrogram basis. - """ - if self.mel_fmax is not None: - assert self.mel_fmax <= self.sample_rate // 2 - return librosa.filters.mel( - sr=self.sample_rate, n_fft=self.fft_size, n_mels=self.num_mels, fmin=self.mel_fmin, fmax=self.mel_fmax - ) - - def _stft_parameters( - self, - ) -> Tuple[int, int]: - """Compute the real STFT parameters from the time values. - - Returns: - Tuple[int, int]: hop length and window length for STFT. - """ - factor = self.frame_length_ms / self.frame_shift_ms - assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms" - hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate) - win_length = int(hop_length * factor) - return hop_length, win_length - ### normalization ### def normalize(self, S: np.ndarray) -> np.ndarray: """Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]` @@ -386,31 +380,6 @@ class AudioProcessor(object): self.linear_scaler = StandardScaler() self.linear_scaler.set_stats(linear_mean, linear_std) - ### DB and AMP conversion ### - # pylint: disable=no-self-use - def _amp_to_db(self, x: np.ndarray) -> np.ndarray: - """Convert amplitude values to decibels. - - Args: - x (np.ndarray): Amplitude spectrogram. - - Returns: - np.ndarray: Decibels spectrogram. - """ - return self.spec_gain * _log(np.maximum(1e-5, x), self.base) - - # pylint: disable=no-self-use - def _db_to_amp(self, x: np.ndarray) -> np.ndarray: - """Convert decibels spectrogram to amplitude spectrogram. - - Args: - x (np.ndarray): Decibels spectrogram. - - Returns: - np.ndarray: Amplitude spectrogram. - """ - return _exp(x / self.spec_gain, self.base) - ### Preemphasis ### def apply_preemphasis(self, x: np.ndarray) -> np.ndarray: """Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values. @@ -424,32 +393,13 @@ class AudioProcessor(object): Returns: np.ndarray: Decorrelated audio signal. """ - if self.preemphasis == 0: - raise RuntimeError(" [!] Preemphasis is set 0.0.") - return scipy.signal.lfilter([1, -self.preemphasis], [1], x) + return preemphasis(x=x, coef=self.preemphasis) def apply_inv_preemphasis(self, x: np.ndarray) -> np.ndarray: """Reverse pre-emphasis.""" - if self.preemphasis == 0: - raise RuntimeError(" [!] Preemphasis is set 0.0.") - return scipy.signal.lfilter([1], [1, -self.preemphasis], x) + return deemphasis(x=x, coef=self.preemphasis) ### SPECTROGRAMs ### - def _linear_to_mel(self, spectrogram: np.ndarray) -> np.ndarray: - """Project a full scale spectrogram to a melspectrogram. - - Args: - spectrogram (np.ndarray): Full scale spectrogram. - - Returns: - np.ndarray: Melspectrogram - """ - return np.dot(self.mel_basis, spectrogram) - - def _mel_to_linear(self, mel_spec: np.ndarray) -> np.ndarray: - """Convert a melspectrogram to full scale spectrogram.""" - return np.maximum(1e-10, np.dot(self.inv_mel_basis, mel_spec)) - def spectrogram(self, y: np.ndarray) -> np.ndarray: """Compute a spectrogram from a waveform. @@ -460,11 +410,16 @@ class AudioProcessor(object): np.ndarray: Spectrogram. """ if self.preemphasis != 0: - D = self._stft(self.apply_preemphasis(y)) - else: - D = self._stft(y) + y = self.apply_preemphasis(y) + D = stft( + y=y, + fft_size=self.fft_size, + hop_length=self.hop_length, + win_length=self.win_length, + pad_mode=self.stft_pad_mode, + ) if self.do_amp_to_db_linear: - S = self._amp_to_db(np.abs(D)) + S = amp_to_db(x=np.abs(D), gain=self.spec_gain, base=self.base) else: S = np.abs(D) return self.normalize(S).astype(np.float32) @@ -472,32 +427,35 @@ class AudioProcessor(object): def melspectrogram(self, y: np.ndarray) -> np.ndarray: """Compute a melspectrogram from a waveform.""" if self.preemphasis != 0: - D = self._stft(self.apply_preemphasis(y)) - else: - D = self._stft(y) + y = self.apply_preemphasis(y) + D = stft( + y=y, + fft_size=self.fft_size, + hop_length=self.hop_length, + win_length=self.win_length, + pad_mode=self.stft_pad_mode, + ) + S = spec_to_mel(spec=np.abs(D), mel_basis=self.mel_basis) if self.do_amp_to_db_mel: - S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - else: - S = self._linear_to_mel(np.abs(D)) + S = amp_to_db(x=S, gain=self.spec_gain, base=self.base) + return self.normalize(S).astype(np.float32) def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray: """Convert a spectrogram to a waveform using Griffi-Lim vocoder.""" S = self.denormalize(spectrogram) - S = self._db_to_amp(S) + S = db_to_amp(x=S, gain=self.spec_gain, base=self.base) # Reconstruct phase - if self.preemphasis != 0: - return self.apply_inv_preemphasis(self._griffin_lim(S**self.power)) - return self._griffin_lim(S**self.power) + W = self._griffin_lim(S**self.power) + return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray: """Convert a melspectrogram to a waveform using Griffi-Lim vocoder.""" D = self.denormalize(mel_spectrogram) - S = self._db_to_amp(D) - S = self._mel_to_linear(S) # Convert back to linear - if self.preemphasis != 0: - return self.apply_inv_preemphasis(self._griffin_lim(S**self.power)) - return self._griffin_lim(S**self.power) + S = db_to_amp(x=D, gain=self.spec_gain, base=self.base) + S = mel_to_spec(mel=S, mel_basis=self.mel_basis) # Convert back to linear + W = self._griffin_lim(S**self.power) + return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray: """Convert a full scale linear spectrogram output of a network to a melspectrogram. @@ -509,60 +467,22 @@ class AudioProcessor(object): np.ndarray: Normalized melspectrogram. """ S = self.denormalize(linear_spec) - S = self._db_to_amp(S) - S = self._linear_to_mel(np.abs(S)) - S = self._amp_to_db(S) + S = db_to_amp(x=S, gain=self.spec_gain, base=self.base) + S = spec_to_mel(spec=np.abs(S), mel_basis=self.mel_basis) + S = amp_to_db(x=S, gain=self.spec_gain, base=self.base) mel = self.normalize(S) return mel - ### STFT and ISTFT ### - def _stft(self, y: np.ndarray) -> np.ndarray: - """Librosa STFT wrapper. - - Args: - y (np.ndarray): Audio signal. - - Returns: - np.ndarray: Complex number array. - """ - return librosa.stft( - y=y, - n_fft=self.fft_size, + def _griffin_lim(self, S): + return griffin_lim( + spec=S, + num_iter=self.griffin_lim_iters, hop_length=self.hop_length, win_length=self.win_length, + fft_size=self.fft_size, pad_mode=self.stft_pad_mode, - window="hann", - center=True, ) - def _istft(self, y: np.ndarray) -> np.ndarray: - """Librosa iSTFT wrapper.""" - return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length) - - def _griffin_lim(self, S): - angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) - try: - S_complex = np.abs(S).astype(np.complex) - except AttributeError: # np.complex is deprecated since numpy 1.20.0 - S_complex = np.abs(S).astype(complex) - y = self._istft(S_complex * angles) - if not np.isfinite(y).all(): - print(" [!] Waveform is not finite everywhere. Skipping the GL.") - return np.array([0.0]) - for _ in range(self.griffin_lim_iters): - angles = np.exp(1j * np.angle(self._stft(y))) - y = self._istft(S_complex * angles) - return y - - def compute_stft_paddings(self, x, pad_sides=1): - """Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding - (first and final frames)""" - assert pad_sides in (1, 2) - pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0] - if pad_sides == 1: - return 0, pad - return pad // 2, pad // 2 + pad % 2 - def compute_f0(self, x: np.ndarray) -> np.ndarray: """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram. @@ -581,8 +501,6 @@ class AudioProcessor(object): >>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate] >>> pitch = ap.compute_f0(wav) """ - assert self.pitch_fmax is not None, " [!] Set `pitch_fmax` before caling `compute_f0`." - assert self.pitch_fmin is not None, " [!] Set `pitch_fmin` before caling `compute_f0`." # align F0 length to the spectrogram length if len(x) % self.hop_length == 0: x = np.pad(x, (0, self.hop_length // 2), mode=self.stft_pad_mode) @@ -612,21 +530,24 @@ class AudioProcessor(object): Returns: int: Last point without silence. """ - window_length = int(self.sample_rate * min_silence_sec) - hop_length = int(window_length / 4) - threshold = self._db_to_amp(-self.trim_db) - for x in range(hop_length, len(wav) - window_length, hop_length): - if np.max(wav[x : x + window_length]) < threshold: - return x + hop_length - return len(wav) + return find_endpoint( + wav=wav, + trim_db=self.trim_db, + sample_rate=self.sample_rate, + min_silence_sec=min_silence_sec, + gain=self.spec_gain, + base=self.base, + ) def trim_silence(self, wav): """Trim silent parts with a threshold and 0.01 sec margin""" - margin = int(self.sample_rate * 0.01) - wav = wav[margin:-margin] - return librosa.effects.trim(wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[ - 0 - ] + return trim_silence( + wav=wav, + sample_rate=self.sample_rate, + trim_db=self.trim_db, + win_length=self.win_length, + hop_length=self.hop_length, + ) @staticmethod def sound_norm(x: np.ndarray) -> np.ndarray: @@ -638,13 +559,7 @@ class AudioProcessor(object): Returns: np.ndarray: Volume normalized waveform. """ - return x / abs(x).max() * 0.95 - - @staticmethod - def _rms_norm(wav, db_level=-27): - r = 10 ** (db_level / 20) - a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2)) - return wav * a + return volume_norm(x=x) def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray: """Normalize the volume based on RMS of the signal. @@ -657,9 +572,7 @@ class AudioProcessor(object): """ if db_level is None: db_level = self.db_level - assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0" - wav = self._rms_norm(x, db_level) - return wav + return rms_volume_norm(x=x, db_level=db_level) ### save and load ### def load_wav(self, filename: str, sr: int = None) -> np.ndarray: @@ -674,15 +587,10 @@ class AudioProcessor(object): Returns: np.ndarray: Loaded waveform. """ - if self.resample: - # loading with resampling. It is significantly slower. - x, sr = librosa.load(filename, sr=self.sample_rate) - elif sr is None: - # SF is faster than librosa for loading files - x, sr = sf.read(filename) - assert self.sample_rate == sr, "%s vs %s" % (self.sample_rate, sr) + if sr is not None: + x = load_wav(filename=filename, sample_rate=sr, resample=True) else: - x, sr = librosa.load(filename, sr=sr) + x = load_wav(filename=filename, sample_rate=self.sample_rate, resample=self.resample) if self.do_trim_silence: try: x = self.trim_silence(x) @@ -723,55 +631,3 @@ class AudioProcessor(object): filename (str): Path to the wav file. """ return librosa.get_duration(filename=filename) - - @staticmethod - def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray: - mu = 2**qc - 1 - # wav_abs = np.minimum(np.abs(wav), 1.0) - signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu) - # Quantize signal to the specified number of levels. - signal = (signal + 1) / 2 * mu + 0.5 - return np.floor( - signal, - ) - - @staticmethod - def mulaw_decode(wav, qc): - """Recovers waveform from quantized values.""" - mu = 2**qc - 1 - x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1) - return x - - @staticmethod - def encode_16bits(x): - return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16) - - @staticmethod - def quantize(x: np.ndarray, bits: int) -> np.ndarray: - """Quantize a waveform to a given number of bits. - - Args: - x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`. - bits (int): Number of quantization bits. - - Returns: - np.ndarray: Quantized waveform. - """ - return (x + 1.0) * (2**bits - 1) / 2 - - @staticmethod - def dequantize(x, bits): - """Dequantize a waveform from the given number of bits.""" - return 2 * x / (2**bits - 1) - 1 - - -def _log(x, base): - if base == 10: - return np.log10(x) - return np.log(x) - - -def _exp(x, base): - if base == 10: - return np.power(10, x) - return np.exp(x) diff --git a/TTS/vocoder/datasets/preprocess.py b/TTS/vocoder/datasets/preprocess.py index 0f69b812..503bb04b 100644 --- a/TTS/vocoder/datasets/preprocess.py +++ b/TTS/vocoder/datasets/preprocess.py @@ -7,6 +7,7 @@ from coqpit import Coqpit from tqdm import tqdm from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor): @@ -29,7 +30,11 @@ def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor): mel = ap.melspectrogram(y) np.save(mel_path, mel) if isinstance(config.mode, int): - quant = ap.mulaw_encode(y, qc=config.mode) if config.model_args.mulaw else ap.quantize(y, bits=config.mode) + quant = ( + mulaw_encode(wav=y, mulaw_qc=config.mode) + if config.model_args.mulaw + else quantize(x=y, quantize_bits=config.mode) + ) np.save(quant_path, quant) diff --git a/TTS/vocoder/datasets/wavernn_dataset.py b/TTS/vocoder/datasets/wavernn_dataset.py index c3907964..a67c5b31 100644 --- a/TTS/vocoder/datasets/wavernn_dataset.py +++ b/TTS/vocoder/datasets/wavernn_dataset.py @@ -2,6 +2,8 @@ import numpy as np import torch from torch.utils.data import Dataset +from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize + class WaveRNNDataset(Dataset): """ @@ -66,7 +68,9 @@ class WaveRNNDataset(Dataset): x_input = audio elif isinstance(self.mode, int): x_input = ( - self.ap.mulaw_encode(audio, qc=self.mode) if self.mulaw else self.ap.quantize(audio, bits=self.mode) + mulaw_encode(wav=audio, mulaw_qc=self.mode) + if self.mulaw + else quantize(x=audio, quantize_bits=self.mode) ) else: raise RuntimeError("Unknown dataset mode - ", self.mode) diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index 903f4b7e..7f74ba3e 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -13,6 +13,7 @@ from torch.utils.data.distributed import DistributedSampler from TTS.tts.utils.visual import plot_spectrogram from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.numpy_transforms import mulaw_decode from TTS.utils.io import load_fsspec from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset from TTS.vocoder.layers.losses import WaveRNNLoss @@ -399,7 +400,7 @@ class Wavernn(BaseVocoder): output = output[0] if self.args.mulaw and isinstance(self.args.mode, int): - output = AudioProcessor.mulaw_decode(output, self.args.mode) + output = mulaw_decode(wav=output, mulaw_qc=self.args.mode) # Fade-out at the end to avoid signal cutting out suddenly fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length) diff --git a/notebooks/ExtractTTSpectrogram.ipynb b/notebooks/ExtractTTSpectrogram.ipynb index 9acc9929..0ec5f167 100644 --- a/notebooks/ExtractTTSpectrogram.ipynb +++ b/notebooks/ExtractTTSpectrogram.ipynb @@ -13,23 +13,28 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", - "import sys\n", - "import torch\n", "import importlib\n", - "import numpy as np\n", - "from tqdm import tqdm\n", - "from torch.utils.data import DataLoader\n", - "import soundfile as sf\n", + "import os\n", "import pickle\n", + "\n", + "import numpy as np\n", + "import soundfile as sf\n", + "import torch\n", + "from matplotlib import pylab as plt\n", + "from torch.utils.data import DataLoader\n", + "from tqdm import tqdm\n", + "\n", + "from TTS.config import load_config\n", + "from TTS.tts.configs.shared_configs import BaseDatasetConfig\n", + "from TTS.tts.datasets import load_tts_samples\n", "from TTS.tts.datasets.dataset import TTSDataset\n", "from TTS.tts.layers.losses import L1LossMasked\n", - "from TTS.utils.audio import AudioProcessor\n", - "from TTS.config import load_config\n", - "from TTS.tts.utils.visual import plot_spectrogram\n", - "from TTS.tts.utils.helpers import sequence_mask\n", "from TTS.tts.models import setup_model\n", - "from TTS.tts.utils.text.symbols import make_symbols, symbols, phonemes\n", + "from TTS.tts.utils.helpers import sequence_mask\n", + "from TTS.tts.utils.text.tokenizer import TTSTokenizer\n", + "from TTS.tts.utils.visual import plot_spectrogram\n", + "from TTS.utils.audio import AudioProcessor\n", + "from TTS.utils.audio.numpy_transforms import quantize\n", "\n", "%matplotlib inline\n", "\n", @@ -49,11 +54,9 @@ " file_name = wav_file.split('.')[0]\n", " os.makedirs(os.path.join(out_path, \"quant\"), exist_ok=True)\n", " os.makedirs(os.path.join(out_path, \"mel\"), exist_ok=True)\n", - " os.makedirs(os.path.join(out_path, \"wav_gl\"), exist_ok=True)\n", " wavq_path = os.path.join(out_path, \"quant\", file_name)\n", " mel_path = os.path.join(out_path, \"mel\", file_name)\n", - " wav_path = os.path.join(out_path, \"wav_gl\", file_name)\n", - " return file_name, wavq_path, mel_path, wav_path" + " return file_name, wavq_path, mel_path" ] }, { @@ -65,14 +68,14 @@ "# Paths and configurations\n", "OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n", "DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n", + "PHONEME_CACHE_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/phoneme_cache\"\n", "DATASET = \"ljspeech\"\n", "METADATA_FILE = \"metadata.csv\"\n", "CONFIG_PATH = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/config.json\"\n", "MODEL_FILE = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/model_file.pth\"\n", "BATCH_SIZE = 32\n", "\n", - "QUANTIZED_WAV = False\n", - "QUANTIZE_BIT = None\n", + "QUANTIZE_BITS = 0 # if non-zero, quantize wav files with the given number of bits\n", "DRY_RUN = False # if False, does not generate output files, only computes loss and visuals.\n", "\n", "# Check CUDA availability\n", @@ -80,10 +83,10 @@ "print(\" > CUDA enabled: \", use_cuda)\n", "\n", "# Load the configuration\n", + "dataset_config = BaseDatasetConfig(formatter=DATASET, meta_file_train=METADATA_FILE, path=DATA_PATH)\n", "C = load_config(CONFIG_PATH)\n", "C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n", - "ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)\n", - "print(C['r'])" + "ap = AudioProcessor(**C.audio)" ] }, { @@ -92,12 +95,10 @@ "metadata": {}, "outputs": [], "source": [ - "# If the vocabulary was passed, replace the default\n", - "if 'characters' in C and C['characters']:\n", - " symbols, phonemes = make_symbols(**C.characters)\n", + "# Initialize the tokenizer\n", + "tokenizer, C = TTSTokenizer.init_from_config(C)\n", "\n", "# Load the model\n", - "num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n", "# TODO: multiple speakers\n", "model = setup_model(C)\n", "model.load_checkpoint(C, MODEL_FILE, eval=True)" @@ -109,42 +110,21 @@ "metadata": {}, "outputs": [], "source": [ - "# Load the preprocessor based on the dataset\n", - "preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n", - "preprocessor = getattr(preprocessor, DATASET.lower())\n", - "meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n", + "# Load data instances\n", + "meta_data_train, meta_data_eval = load_tts_samples(dataset_config)\n", + "meta_data = meta_data_train + meta_data_eval\n", + "\n", "dataset = TTSDataset(\n", - " C,\n", - " C.text_cleaner,\n", - " False,\n", - " ap,\n", - " meta_data,\n", - " characters=C.get('characters', None),\n", - " use_phonemes=C.use_phonemes,\n", - " phoneme_cache_path=C.phoneme_cache_path,\n", - " enable_eos_bos=C.enable_eos_bos_chars,\n", + " outputs_per_step=C[\"r\"],\n", + " compute_linear_spec=False,\n", + " ap=ap,\n", + " samples=meta_data,\n", + " tokenizer=tokenizer,\n", + " phoneme_cache_path=PHONEME_CACHE_PATH,\n", ")\n", "loader = DataLoader(\n", " dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False\n", - ")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Initialize lists for storing results\n", - "file_idxs = []\n", - "metadata = []\n", - "losses = []\n", - "postnet_losses = []\n", - "criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n", - "\n", - "# Create log file\n", - "log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n", - "log_file = open(log_file_path, \"w\")" + ")" ] }, { @@ -160,26 +140,33 @@ "metadata": {}, "outputs": [], "source": [ + "# Initialize lists for storing results\n", + "file_idxs = []\n", + "metadata = []\n", + "losses = []\n", + "postnet_losses = []\n", + "criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n", + "\n", "# Start processing with a progress bar\n", - "with torch.no_grad():\n", + "log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n", + "with torch.no_grad() and open(log_file_path, \"w\") as log_file:\n", " for data in tqdm(loader, desc=\"Processing\"):\n", " try:\n", - " # setup input data\n", - " text_input, text_lengths, _, linear_input, mel_input, mel_lengths, stop_targets, item_idx = data\n", - "\n", " # dispatch data to GPU\n", " if use_cuda:\n", - " text_input = text_input.cuda()\n", - " text_lengths = text_lengths.cuda()\n", - " mel_input = mel_input.cuda()\n", - " mel_lengths = mel_lengths.cuda()\n", + " data[\"token_id\"] = data[\"token_id\"].cuda()\n", + " data[\"token_id_lengths\"] = data[\"token_id_lengths\"].cuda()\n", + " data[\"mel\"] = data[\"mel\"].cuda()\n", + " data[\"mel_lengths\"] = data[\"mel_lengths\"].cuda()\n", "\n", - " mask = sequence_mask(text_lengths)\n", - " mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n", + " mask = sequence_mask(data[\"token_id_lengths\"])\n", + " outputs = model.forward(data[\"token_id\"], data[\"token_id_lengths\"], data[\"mel\"])\n", + " mel_outputs = outputs[\"decoder_outputs\"]\n", + " postnet_outputs = outputs[\"model_outputs\"]\n", "\n", " # compute loss\n", - " loss = criterion(mel_outputs, mel_input, mel_lengths)\n", - " loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n", + " loss = criterion(mel_outputs, data[\"mel\"], data[\"mel_lengths\"])\n", + " loss_postnet = criterion(postnet_outputs, data[\"mel\"], data[\"mel_lengths\"])\n", " losses.append(loss.item())\n", " postnet_losses.append(loss_postnet.item())\n", "\n", @@ -193,28 +180,27 @@ " postnet_outputs = torch.stack(mel_specs)\n", " elif C.model == \"Tacotron2\":\n", " postnet_outputs = postnet_outputs.detach().cpu().numpy()\n", - " alignments = alignments.detach().cpu().numpy()\n", + " alignments = outputs[\"alignments\"].detach().cpu().numpy()\n", "\n", " if not DRY_RUN:\n", - " for idx in range(text_input.shape[0]):\n", - " wav_file_path = item_idx[idx]\n", + " for idx in range(data[\"token_id\"].shape[0]):\n", + " wav_file_path = data[\"item_idxs\"][idx]\n", " wav = ap.load_wav(wav_file_path)\n", - " file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n", + " file_name, wavq_path, mel_path = set_filename(wav_file_path, OUT_PATH)\n", " file_idxs.append(file_name)\n", "\n", " # quantize and save wav\n", - " if QUANTIZED_WAV:\n", - " wavq = ap.quantize(wav)\n", + " if QUANTIZE_BITS > 0:\n", + " wavq = quantize(wav, QUANTIZE_BITS)\n", " np.save(wavq_path, wavq)\n", "\n", " # save TTS mel\n", " mel = postnet_outputs[idx]\n", - " mel_length = mel_lengths[idx]\n", + " mel_length = data[\"mel_lengths\"][idx]\n", " mel = mel[:mel_length, :].T\n", " np.save(mel_path, mel)\n", "\n", " metadata.append([wav_file_path, mel_path])\n", - "\n", " except Exception as e:\n", " log_file.write(f\"Error processing data: {str(e)}\\n\")\n", "\n", @@ -224,35 +210,20 @@ " log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n", " log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n", "\n", - "# Close the log file\n", - "log_file.close()\n", - "\n", "# For wavernn\n", "if not DRY_RUN:\n", " pickle.dump(file_idxs, open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\"))\n", "\n", "# For pwgan\n", "with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n", - " for data in metadata:\n", - " f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n", + " for wav_file_path, mel_path in metadata:\n", + " f.write(f\"{wav_file_path[0]}|{mel_path[1]+'.npy'}\\n\")\n", "\n", "# Print mean losses\n", "print(f\"Mean Loss: {mean_loss}\")\n", "print(f\"Mean Postnet Loss: {mean_postnet_loss}\")" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# for pwgan\n", - "with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n", - " for data in metadata:\n", - " f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -267,7 +238,7 @@ "outputs": [], "source": [ "idx = 1\n", - "ap.melspectrogram(ap.load_wav(item_idx[idx])).shape" + "ap.melspectrogram(ap.load_wav(data[\"item_idxs\"][idx])).shape" ] }, { @@ -276,10 +247,9 @@ "metadata": {}, "outputs": [], "source": [ - "import soundfile as sf\n", - "wav, sr = sf.read(item_idx[idx])\n", - "mel_postnet = postnet_outputs[idx][:mel_lengths[idx], :]\n", - "mel_decoder = mel_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()\n", + "wav, sr = sf.read(data[\"item_idxs\"][idx])\n", + "mel_postnet = postnet_outputs[idx][:data[\"mel_lengths\"][idx], :]\n", + "mel_decoder = mel_outputs[idx][:data[\"mel_lengths\"][idx], :].detach().cpu().numpy()\n", "mel_truth = ap.melspectrogram(wav)\n", "print(mel_truth.shape)" ] @@ -291,7 +261,7 @@ "outputs": [], "source": [ "# plot posnet output\n", - "print(mel_postnet[:mel_lengths[idx], :].shape)\n", + "print(mel_postnet[:data[\"mel_lengths\"][idx], :].shape)\n", "plot_spectrogram(mel_postnet, ap)" ] }, @@ -324,10 +294,9 @@ "outputs": [], "source": [ "# postnet, decoder diff\n", - "from matplotlib import pylab as plt\n", "mel_diff = mel_decoder - mel_postnet\n", "plt.figure(figsize=(16, 10))\n", - "plt.imshow(abs(mel_diff[:mel_lengths[idx],:]).T,aspect=\"auto\", origin=\"lower\");\n", + "plt.imshow(abs(mel_diff[:data[\"mel_lengths\"][idx],:]).T,aspect=\"auto\", origin=\"lower\")\n", "plt.colorbar()\n", "plt.tight_layout()" ] @@ -339,10 +308,9 @@ "outputs": [], "source": [ "# PLOT GT SPECTROGRAM diff\n", - "from matplotlib import pylab as plt\n", "mel_diff2 = mel_truth.T - mel_decoder\n", "plt.figure(figsize=(16, 10))\n", - "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n", + "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\")\n", "plt.colorbar()\n", "plt.tight_layout()" ] @@ -354,21 +322,13 @@ "outputs": [], "source": [ "# PLOT GT SPECTROGRAM diff\n", - "from matplotlib import pylab as plt\n", "mel = postnet_outputs[idx]\n", "mel_diff2 = mel_truth.T - mel[:mel_truth.shape[1]]\n", "plt.figure(figsize=(16, 10))\n", - "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n", + "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\")\n", "plt.colorbar()\n", "plt.tight_layout()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/tests/vocoder_tests/test_vocoder_losses.py b/tests/vocoder_tests/test_vocoder_losses.py index 2a35aa2e..95501c2d 100644 --- a/tests/vocoder_tests/test_vocoder_losses.py +++ b/tests/vocoder_tests/test_vocoder_losses.py @@ -5,6 +5,7 @@ import torch from tests import get_tests_input_path, get_tests_output_path, get_tests_path from TTS.config import BaseAudioConfig from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.numpy_transforms import stft from TTS.vocoder.layers.losses import MelganFeatureLoss, MultiScaleSTFTLoss, STFTLoss, TorchSTFT TESTS_PATH = get_tests_path() @@ -21,7 +22,7 @@ def test_torch_stft(): torch_stft = TorchSTFT(ap.fft_size, ap.hop_length, ap.win_length) # librosa stft wav = ap.load_wav(WAV_FILE) - M_librosa = abs(ap._stft(wav)) # pylint: disable=protected-access + M_librosa = abs(stft(y=wav, fft_size=ap.fft_size, hop_length=ap.hop_length, win_length=ap.win_length)) # torch stft wav = torch.from_numpy(wav[None, :]).float() M_torch = torch_stft(wav)