Remove duplicate AudioProcessor code and fix ExtractTTSpectrogram.ipynb (#3230)

* chore: remove unused argument

* refactor(audio.processor): remove duplicate stft+griffin_lim

* chore(audio.processor): remove unused compute_stft_paddings

Same function available in numpy_transforms

* refactor(audio.processor): remove duplicate db_to_amp

* refactor(audio.processor): remove duplicate amp_to_db

* refactor(audio.processor): remove duplicate linear_to_mel

* refactor(audio.processor): remove duplicate mel_to_linear

* refactor(audio.processor): remove duplicate build_mel_basis

* refactor(audio.processor): remove duplicate stft_parameters

* refactor(audio.processor): use pre-/deemphasis from numpy_transforms

* refactor(audio.processor): use rms_volume_norm from numpy_transforms

* chore(audio.processor): remove duplicate assert

Already checked in numpy_transforms.compute_f0

* refactor(audio.processor): use find_endpoint from numpy_transforms

* refactor(audio.processor): use trim_silence from numpy_transforms

* refactor(audio.processor): use volume_norm from numpy_transforms

* refactor(audio.processor): use load_wav from numpy_transforms

* fix(bin.extract_tts_spectrograms): set quantization bits

* fix(ExtractTTSpectrogram.ipynb): adapt to current TTS code

Fixes #2447, #2574

* refactor(audio.processor): remove duplicate quantization methods
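For orientation, a minimal sketch of the consolidated API (dummy data; parameter values are illustrative, keyword names taken from the hunks below): the former AudioProcessor helpers now delegate to module-level functions in TTS.utils.audio.numpy_transforms, so they can be used without constructing an AudioProcessor:

    import numpy as np

    from TTS.utils.audio.numpy_transforms import deemphasis, preemphasis, quantize, rms_volume_norm

    wav = np.random.uniform(-1.0, 1.0, size=22050).astype(np.float32)  # dummy waveform in [-1, 1]

    emphasized = preemphasis(x=wav, coef=0.97)         # was ap.apply_preemphasis(wav)
    restored = deemphasis(x=emphasized, coef=0.97)     # was ap.apply_inv_preemphasis(...)
    normalized = rms_volume_norm(x=wav, db_level=-27)  # was ap.rms_volume_norm(wav)
    wavq = quantize(x=wav, quantize_bits=10)           # was ap.quantize(wav, bits=10)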
Enno Hermann 2023-11-16 10:57:06 +01:00 committed by GitHub
parent 88630c60e5
commit 3c2d5a9e03
8 changed files with 177 additions and 350 deletions

View File

@@ -15,6 +15,7 @@ from TTS.tts.models import setup_model
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.numpy_transforms import quantize
 from TTS.utils.generic_utils import count_parameters

 use_cuda = torch.cuda.is_available()
@@ -159,7 +160,7 @@ def inference(


 def extract_spectrograms(
-    data_loader, model, ap, output_path, quantized_wav=False, save_audio=False, debug=False, metada_name="metada.txt"
+    data_loader, model, ap, output_path, quantize_bits=0, save_audio=False, debug=False, metada_name="metada.txt"
 ):
     model.eval()
     export_metadata = []
@@ -196,8 +197,8 @@ def extract_spectrograms(
             _, wavq_path, mel_path, wav_gl_path, wav_path = set_filename(wav_file_path, output_path)

             # quantize and save wav
-            if quantized_wav:
-                wavq = ap.quantize(wav)
+            if quantize_bits > 0:
+                wavq = quantize(wav, quantize_bits)
                 np.save(wavq_path, wavq)

             # save TTS mel
@@ -263,7 +264,7 @@ def main(args): # pylint: disable=redefined-outer-name
         model,
         ap,
         args.output_path,
-        quantized_wav=args.quantized,
+        quantize_bits=args.quantize_bits,
         save_audio=args.save_audio,
         debug=args.debug,
         metada_name="metada.txt",
@@ -277,7 +278,7 @@ if __name__ == "__main__":
     parser.add_argument("--output_path", type=str, help="Path to save mel specs", required=True)
     parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug")
     parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files")
-    parser.add_argument("--quantized", action="store_true", help="Save quantized audio files")
+    parser.add_argument("--quantize_bits", type=int, default=0, help="Save quantized audio files if non-zero")
    parser.add_argument("--eval", type=bool, help="compute eval.", default=True)

    args = parser.parse_args()

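Usage note: the script now takes an integer --quantize_bits argument in place of the old boolean --quantized flag; quantized wav files are only written when the value is non-zero (for example, --quantize_bits 10).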
View File

@@ -201,7 +201,6 @@ def stft(
 def istft(
     *,
     y: np.ndarray = None,
-    fft_size: int = None,
     hop_length: int = None,
     win_length: int = None,
     window: str = "hann",

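With this change istft() no longer accepts fft_size. A small round-trip sketch with the keyword-based helpers (dummy signal, illustrative STFT parameters):

    import numpy as np

    from TTS.utils.audio.numpy_transforms import istft, stft

    y = np.random.uniform(-1.0, 1.0, size=22050).astype(np.float32)  # dummy signal

    D = stft(y=y, fft_size=1024, hop_length=256, win_length=1024)  # complex spectrogram
    y_hat = istft(y=D, hop_length=256, win_length=1024)           # passing fft_size here would now be rejected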
View File

@@ -5,10 +5,26 @@ import librosa
 import numpy as np
 import scipy.io.wavfile
 import scipy.signal
-import soundfile as sf

 from TTS.tts.utils.helpers import StandardScaler
-from TTS.utils.audio.numpy_transforms import compute_f0
+from TTS.utils.audio.numpy_transforms import (
+    amp_to_db,
+    build_mel_basis,
+    compute_f0,
+    db_to_amp,
+    deemphasis,
+    find_endpoint,
+    griffin_lim,
+    load_wav,
+    mel_to_spec,
+    millisec_to_length,
+    preemphasis,
+    rms_volume_norm,
+    spec_to_mel,
+    stft,
+    trim_silence,
+    volume_norm,
+)

 # pylint: disable=too-many-public-methods
@@ -200,7 +216,9 @@ class AudioProcessor(object):
         # setup stft parameters
         if hop_length is None:
             # compute stft parameters from given time values
-            self.hop_length, self.win_length = self._stft_parameters()
+            self.win_length, self.hop_length = millisec_to_length(
+                frame_length_ms=self.frame_length_ms, frame_shift_ms=self.frame_shift_ms, sample_rate=self.sample_rate
+            )
         else:
             # use stft parameters from config file
             self.hop_length = hop_length
@@ -215,8 +233,13 @@ class AudioProcessor(object):
             for key, value in members.items():
                 print(" | > {}:{}".format(key, value))
         # create spectrogram utils
-        self.mel_basis = self._build_mel_basis()
-        self.inv_mel_basis = np.linalg.pinv(self._build_mel_basis())
+        self.mel_basis = build_mel_basis(
+            sample_rate=self.sample_rate,
+            fft_size=self.fft_size,
+            num_mels=self.num_mels,
+            mel_fmax=self.mel_fmax,
+            mel_fmin=self.mel_fmin,
+        )
         # setup scaler
         if stats_path and signal_norm:
             mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path)
@@ -232,35 +255,6 @@ class AudioProcessor(object):
             return AudioProcessor(verbose=verbose, **config.audio)
         return AudioProcessor(verbose=verbose, **config)

-    ### setting up the parameters ###
-    def _build_mel_basis(
-        self,
-    ) -> np.ndarray:
-        """Build melspectrogram basis.
-
-        Returns:
-            np.ndarray: melspectrogram basis.
-        """
-        if self.mel_fmax is not None:
-            assert self.mel_fmax <= self.sample_rate // 2
-        return librosa.filters.mel(
-            sr=self.sample_rate, n_fft=self.fft_size, n_mels=self.num_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
-        )
-
-    def _stft_parameters(
-        self,
-    ) -> Tuple[int, int]:
-        """Compute the real STFT parameters from the time values.
-
-        Returns:
-            Tuple[int, int]: hop length and window length for STFT.
-        """
-        factor = self.frame_length_ms / self.frame_shift_ms
-        assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms"
-        hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
-        win_length = int(hop_length * factor)
-        return hop_length, win_length
-
     ### normalization ###
     def normalize(self, S: np.ndarray) -> np.ndarray:
         """Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]`
@@ -386,31 +380,6 @@ class AudioProcessor(object):
         self.linear_scaler = StandardScaler()
         self.linear_scaler.set_stats(linear_mean, linear_std)

-    ### DB and AMP conversion ###
-
-    # pylint: disable=no-self-use
-    def _amp_to_db(self, x: np.ndarray) -> np.ndarray:
-        """Convert amplitude values to decibels.
-
-        Args:
-            x (np.ndarray): Amplitude spectrogram.
-
-        Returns:
-            np.ndarray: Decibels spectrogram.
-        """
-        return self.spec_gain * _log(np.maximum(1e-5, x), self.base)
-
-    # pylint: disable=no-self-use
-    def _db_to_amp(self, x: np.ndarray) -> np.ndarray:
-        """Convert decibels spectrogram to amplitude spectrogram.
-
-        Args:
-            x (np.ndarray): Decibels spectrogram.
-
-        Returns:
-            np.ndarray: Amplitude spectrogram.
-        """
-        return _exp(x / self.spec_gain, self.base)
-
     ### Preemphasis ###
     def apply_preemphasis(self, x: np.ndarray) -> np.ndarray:
         """Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values.
@@ -424,32 +393,13 @@ class AudioProcessor(object):
         Returns:
             np.ndarray: Decorrelated audio signal.
         """
-        if self.preemphasis == 0:
-            raise RuntimeError(" [!] Preemphasis is set 0.0.")
-        return scipy.signal.lfilter([1, -self.preemphasis], [1], x)
+        return preemphasis(x=x, coef=self.preemphasis)

     def apply_inv_preemphasis(self, x: np.ndarray) -> np.ndarray:
         """Reverse pre-emphasis."""
-        if self.preemphasis == 0:
-            raise RuntimeError(" [!] Preemphasis is set 0.0.")
-        return scipy.signal.lfilter([1], [1, -self.preemphasis], x)
+        return deemphasis(x=x, coef=self.preemphasis)

     ### SPECTROGRAMs ###
-    def _linear_to_mel(self, spectrogram: np.ndarray) -> np.ndarray:
-        """Project a full scale spectrogram to a melspectrogram.
-
-        Args:
-            spectrogram (np.ndarray): Full scale spectrogram.
-
-        Returns:
-            np.ndarray: Melspectrogram
-        """
-        return np.dot(self.mel_basis, spectrogram)
-
-    def _mel_to_linear(self, mel_spec: np.ndarray) -> np.ndarray:
-        """Convert a melspectrogram to full scale spectrogram."""
-        return np.maximum(1e-10, np.dot(self.inv_mel_basis, mel_spec))
-
     def spectrogram(self, y: np.ndarray) -> np.ndarray:
         """Compute a spectrogram from a waveform.
@@ -460,11 +410,16 @@ class AudioProcessor(object):
             np.ndarray: Spectrogram.
         """
         if self.preemphasis != 0:
-            D = self._stft(self.apply_preemphasis(y))
-        else:
-            D = self._stft(y)
+            y = self.apply_preemphasis(y)
+        D = stft(
+            y=y,
+            fft_size=self.fft_size,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            pad_mode=self.stft_pad_mode,
+        )
         if self.do_amp_to_db_linear:
-            S = self._amp_to_db(np.abs(D))
+            S = amp_to_db(x=np.abs(D), gain=self.spec_gain, base=self.base)
         else:
             S = np.abs(D)
         return self.normalize(S).astype(np.float32)
@ -472,32 +427,35 @@ class AudioProcessor(object):
def melspectrogram(self, y: np.ndarray) -> np.ndarray: def melspectrogram(self, y: np.ndarray) -> np.ndarray:
"""Compute a melspectrogram from a waveform.""" """Compute a melspectrogram from a waveform."""
if self.preemphasis != 0: if self.preemphasis != 0:
D = self._stft(self.apply_preemphasis(y)) y = self.apply_preemphasis(y)
else: D = stft(
D = self._stft(y) y=y,
fft_size=self.fft_size,
hop_length=self.hop_length,
win_length=self.win_length,
pad_mode=self.stft_pad_mode,
)
S = spec_to_mel(spec=np.abs(D), mel_basis=self.mel_basis)
if self.do_amp_to_db_mel: if self.do_amp_to_db_mel:
S = self._amp_to_db(self._linear_to_mel(np.abs(D))) S = amp_to_db(x=S, gain=self.spec_gain, base=self.base)
else:
S = self._linear_to_mel(np.abs(D))
return self.normalize(S).astype(np.float32) return self.normalize(S).astype(np.float32)
def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray: def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray:
"""Convert a spectrogram to a waveform using Griffi-Lim vocoder.""" """Convert a spectrogram to a waveform using Griffi-Lim vocoder."""
S = self.denormalize(spectrogram) S = self.denormalize(spectrogram)
S = self._db_to_amp(S) S = db_to_amp(x=S, gain=self.spec_gain, base=self.base)
# Reconstruct phase # Reconstruct phase
if self.preemphasis != 0: W = self._griffin_lim(S**self.power)
return self.apply_inv_preemphasis(self._griffin_lim(S**self.power)) return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W
return self._griffin_lim(S**self.power)
def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray: def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray:
"""Convert a melspectrogram to a waveform using Griffi-Lim vocoder.""" """Convert a melspectrogram to a waveform using Griffi-Lim vocoder."""
D = self.denormalize(mel_spectrogram) D = self.denormalize(mel_spectrogram)
S = self._db_to_amp(D) S = db_to_amp(x=D, gain=self.spec_gain, base=self.base)
S = self._mel_to_linear(S) # Convert back to linear S = mel_to_spec(mel=S, mel_basis=self.mel_basis) # Convert back to linear
if self.preemphasis != 0: W = self._griffin_lim(S**self.power)
return self.apply_inv_preemphasis(self._griffin_lim(S**self.power)) return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W
return self._griffin_lim(S**self.power)
def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray: def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray:
"""Convert a full scale linear spectrogram output of a network to a melspectrogram. """Convert a full scale linear spectrogram output of a network to a melspectrogram.
@@ -509,60 +467,22 @@ class AudioProcessor(object):
         Returns:
             np.ndarray: Normalized melspectrogram.
         """
         S = self.denormalize(linear_spec)
-        S = self._db_to_amp(S)
-        S = self._linear_to_mel(np.abs(S))
-        S = self._amp_to_db(S)
+        S = db_to_amp(x=S, gain=self.spec_gain, base=self.base)
+        S = spec_to_mel(spec=np.abs(S), mel_basis=self.mel_basis)
+        S = amp_to_db(x=S, gain=self.spec_gain, base=self.base)
         mel = self.normalize(S)
         return mel

-    ### STFT and ISTFT ###
-    def _stft(self, y: np.ndarray) -> np.ndarray:
-        """Librosa STFT wrapper.
-
-        Args:
-            y (np.ndarray): Audio signal.
-
-        Returns:
-            np.ndarray: Complex number array.
-        """
-        return librosa.stft(
-            y=y,
-            n_fft=self.fft_size,
-            hop_length=self.hop_length,
-            win_length=self.win_length,
-            pad_mode=self.stft_pad_mode,
-            window="hann",
-            center=True,
-        )
-
-    def _istft(self, y: np.ndarray) -> np.ndarray:
-        """Librosa iSTFT wrapper."""
-        return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length)
-
     def _griffin_lim(self, S):
-        angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
-        try:
-            S_complex = np.abs(S).astype(np.complex)
-        except AttributeError:  # np.complex is deprecated since numpy 1.20.0
-            S_complex = np.abs(S).astype(complex)
-        y = self._istft(S_complex * angles)
-        if not np.isfinite(y).all():
-            print(" [!] Waveform is not finite everywhere. Skipping the GL.")
-            return np.array([0.0])
-        for _ in range(self.griffin_lim_iters):
-            angles = np.exp(1j * np.angle(self._stft(y)))
-            y = self._istft(S_complex * angles)
-        return y
-
-    def compute_stft_paddings(self, x, pad_sides=1):
-        """Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding
-        (first and final frames)"""
-        assert pad_sides in (1, 2)
-        pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0]
-        if pad_sides == 1:
-            return 0, pad
-        return pad // 2, pad // 2 + pad % 2
+        return griffin_lim(
+            spec=S,
+            num_iter=self.griffin_lim_iters,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            fft_size=self.fft_size,
+            pad_mode=self.stft_pad_mode,
+        )

     def compute_f0(self, x: np.ndarray) -> np.ndarray:
         """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.
@@ -581,8 +501,6 @@ class AudioProcessor(object):
             >>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate]
             >>> pitch = ap.compute_f0(wav)
         """
-        assert self.pitch_fmax is not None, " [!] Set `pitch_fmax` before caling `compute_f0`."
-        assert self.pitch_fmin is not None, " [!] Set `pitch_fmin` before caling `compute_f0`."
         # align F0 length to the spectrogram length
         if len(x) % self.hop_length == 0:
             x = np.pad(x, (0, self.hop_length // 2), mode=self.stft_pad_mode)
@@ -612,21 +530,24 @@ class AudioProcessor(object):
         Returns:
             int: Last point without silence.
         """
-        window_length = int(self.sample_rate * min_silence_sec)
-        hop_length = int(window_length / 4)
-        threshold = self._db_to_amp(-self.trim_db)
-        for x in range(hop_length, len(wav) - window_length, hop_length):
-            if np.max(wav[x : x + window_length]) < threshold:
-                return x + hop_length
-        return len(wav)
+        return find_endpoint(
+            wav=wav,
+            trim_db=self.trim_db,
+            sample_rate=self.sample_rate,
+            min_silence_sec=min_silence_sec,
+            gain=self.spec_gain,
+            base=self.base,
+        )

     def trim_silence(self, wav):
         """Trim silent parts with a threshold and 0.01 sec margin"""
-        margin = int(self.sample_rate * 0.01)
-        wav = wav[margin:-margin]
-        return librosa.effects.trim(wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[
-            0
-        ]
+        return trim_silence(
+            wav=wav,
+            sample_rate=self.sample_rate,
+            trim_db=self.trim_db,
+            win_length=self.win_length,
+            hop_length=self.hop_length,
+        )

     @staticmethod
     def sound_norm(x: np.ndarray) -> np.ndarray:
@@ -638,13 +559,7 @@ class AudioProcessor(object):
         Returns:
             np.ndarray: Volume normalized waveform.
         """
-        return x / abs(x).max() * 0.95
-
-    @staticmethod
-    def _rms_norm(wav, db_level=-27):
-        r = 10 ** (db_level / 20)
-        a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2))
-        return wav * a
+        return volume_norm(x=x)

     def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray:
         """Normalize the volume based on RMS of the signal.
@@ -657,9 +572,7 @@ class AudioProcessor(object):
         """
         if db_level is None:
             db_level = self.db_level
-        assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0"
-        wav = self._rms_norm(x, db_level)
-        return wav
+        return rms_volume_norm(x=x, db_level=db_level)

     ### save and load ###
     def load_wav(self, filename: str, sr: int = None) -> np.ndarray:
@@ -674,15 +587,10 @@ class AudioProcessor(object):
         Returns:
             np.ndarray: Loaded waveform.
         """
-        if self.resample:
-            # loading with resampling. It is significantly slower.
-            x, sr = librosa.load(filename, sr=self.sample_rate)
-        elif sr is None:
-            # SF is faster than librosa for loading files
-            x, sr = sf.read(filename)
-            assert self.sample_rate == sr, "%s vs %s" % (self.sample_rate, sr)
+        if sr is not None:
+            x = load_wav(filename=filename, sample_rate=sr, resample=True)
         else:
-            x, sr = librosa.load(filename, sr=sr)
+            x = load_wav(filename=filename, sample_rate=self.sample_rate, resample=self.resample)
         if self.do_trim_silence:
             try:
                 x = self.trim_silence(x)
@@ -723,55 +631,3 @@ class AudioProcessor(object):
             filename (str): Path to the wav file.
         """
         return librosa.get_duration(filename=filename)
-
-    @staticmethod
-    def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray:
-        mu = 2**qc - 1
-        # wav_abs = np.minimum(np.abs(wav), 1.0)
-        signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
-        # Quantize signal to the specified number of levels.
-        signal = (signal + 1) / 2 * mu + 0.5
-        return np.floor(
-            signal,
-        )
-
-    @staticmethod
-    def mulaw_decode(wav, qc):
-        """Recovers waveform from quantized values."""
-        mu = 2**qc - 1
-        x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
-        return x
-
-    @staticmethod
-    def encode_16bits(x):
-        return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
-
-    @staticmethod
-    def quantize(x: np.ndarray, bits: int) -> np.ndarray:
-        """Quantize a waveform to a given number of bits.
-
-        Args:
-            x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`.
-            bits (int): Number of quantization bits.
-
-        Returns:
-            np.ndarray: Quantized waveform.
-        """
-        return (x + 1.0) * (2**bits - 1) / 2
-
-    @staticmethod
-    def dequantize(x, bits):
-        """Dequantize a waveform from the given number of bits."""
-        return 2 * x / (2**bits - 1) - 1
-
-
-def _log(x, base):
-    if base == 10:
-        return np.log10(x)
-    return np.log(x)
-
-
-def _exp(x, base):
-    if base == 10:
-        return np.power(10, x)
-    return np.exp(x)

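For reference, a sketch of the pipeline that AudioProcessor.melspectrogram() now delegates to (parameter values are illustrative, not taken from any config):

    import numpy as np

    from TTS.utils.audio.numpy_transforms import amp_to_db, build_mel_basis, spec_to_mel, stft

    sample_rate, fft_size, hop_length, win_length, num_mels = 22050, 1024, 256, 1024, 80
    mel_basis = build_mel_basis(
        sample_rate=sample_rate, fft_size=fft_size, num_mels=num_mels, mel_fmin=0, mel_fmax=8000
    )

    y = np.random.uniform(-1.0, 1.0, size=sample_rate).astype(np.float32)  # dummy waveform
    D = stft(y=y, fft_size=fft_size, hop_length=hop_length, win_length=win_length)
    S = spec_to_mel(spec=np.abs(D), mel_basis=mel_basis)
    S_db = amp_to_db(x=S, gain=20, base=10)  # gain/base mirror AudioProcessor's spec_gain/base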
View File

@@ -7,6 +7,7 @@ from coqpit import Coqpit
 from tqdm import tqdm

 from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize


 def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
@@ -29,7 +30,11 @@ def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
         mel = ap.melspectrogram(y)
         np.save(mel_path, mel)
         if isinstance(config.mode, int):
-            quant = ap.mulaw_encode(y, qc=config.mode) if config.model_args.mulaw else ap.quantize(y, bits=config.mode)
+            quant = (
+                mulaw_encode(wav=y, mulaw_qc=config.mode)
+                if config.model_args.mulaw
+                else quantize(x=y, quantize_bits=config.mode)
+            )
             np.save(quant_path, quant)

View File

@@ -2,6 +2,8 @@ import numpy as np
 import torch
 from torch.utils.data import Dataset

+from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize
+

 class WaveRNNDataset(Dataset):
     """
@@ -66,7 +68,9 @@
             x_input = audio
         elif isinstance(self.mode, int):
             x_input = (
-                self.ap.mulaw_encode(audio, qc=self.mode) if self.mulaw else self.ap.quantize(audio, bits=self.mode)
+                mulaw_encode(wav=audio, mulaw_qc=self.mode)
+                if self.mulaw
+                else quantize(x=audio, quantize_bits=self.mode)
             )
         else:
             raise RuntimeError("Unknown dataset mode - ", self.mode)

View File

@@ -13,6 +13,7 @@ from torch.utils.data.distributed import DistributedSampler

 from TTS.tts.utils.visual import plot_spectrogram
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.numpy_transforms import mulaw_decode
 from TTS.utils.io import load_fsspec
 from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
 from TTS.vocoder.layers.losses import WaveRNNLoss
@@ -399,7 +400,7 @@ class Wavernn(BaseVocoder):
         output = output[0]

         if self.args.mulaw and isinstance(self.args.mode, int):
-            output = AudioProcessor.mulaw_decode(output, self.args.mode)
+            output = mulaw_decode(wav=output, mulaw_qc=self.args.mode)

         # Fade-out at the end to avoid signal cutting out suddenly
         fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length)

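A short sketch of the shared mu-law helpers that WaveRNN now uses (the mode value and arrays are stand-ins, not taken from a real config):

    import numpy as np

    from TTS.utils.audio.numpy_transforms import mulaw_decode, mulaw_encode

    mode = 10  # stand-in for config.mode / args.mode when it is an integer
    audio = np.random.uniform(-1.0, 1.0, size=22050).astype(np.float32)

    # dataset side (WaveRNNDataset): build quantized training targets
    x_input = mulaw_encode(wav=audio, mulaw_qc=mode)

    # inference side (Wavernn): expand the sampled output back to a waveform
    sampled = np.random.uniform(-1.0, 1.0, size=22050).astype(np.float32)  # stand-in for model output
    waveform = mulaw_decode(wav=sampled, mulaw_qc=mode)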
View File

@@ -13,23 +13,28 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import os\n",
-    "import sys\n",
-    "import torch\n",
-    "import importlib\n",
-    "import numpy as np\n",
-    "from tqdm import tqdm\n",
-    "from torch.utils.data import DataLoader\n",
-    "import soundfile as sf\n",
-    "import pickle\n",
-    "from TTS.tts.datasets.dataset import TTSDataset\n",
-    "from TTS.tts.layers.losses import L1LossMasked\n",
-    "from TTS.utils.audio import AudioProcessor\n",
-    "from TTS.config import load_config\n",
-    "from TTS.tts.utils.visual import plot_spectrogram\n",
-    "from TTS.tts.utils.helpers import sequence_mask\n",
-    "from TTS.tts.models import setup_model\n",
-    "from TTS.tts.utils.text.symbols import make_symbols, symbols, phonemes\n",
+    "import importlib\n",
+    "import os\n",
+    "import pickle\n",
+    "\n",
+    "import numpy as np\n",
+    "import soundfile as sf\n",
+    "import torch\n",
+    "from matplotlib import pylab as plt\n",
+    "from torch.utils.data import DataLoader\n",
+    "from tqdm import tqdm\n",
+    "\n",
+    "from TTS.config import load_config\n",
+    "from TTS.tts.configs.shared_configs import BaseDatasetConfig\n",
+    "from TTS.tts.datasets import load_tts_samples\n",
+    "from TTS.tts.datasets.dataset import TTSDataset\n",
+    "from TTS.tts.layers.losses import L1LossMasked\n",
+    "from TTS.tts.models import setup_model\n",
+    "from TTS.tts.utils.helpers import sequence_mask\n",
+    "from TTS.tts.utils.text.tokenizer import TTSTokenizer\n",
+    "from TTS.tts.utils.visual import plot_spectrogram\n",
+    "from TTS.utils.audio import AudioProcessor\n",
+    "from TTS.utils.audio.numpy_transforms import quantize\n",
     "\n",
     "%matplotlib inline\n",
     "\n",
@@ -49,11 +54,9 @@
     "    file_name = wav_file.split('.')[0]\n",
     "    os.makedirs(os.path.join(out_path, \"quant\"), exist_ok=True)\n",
     "    os.makedirs(os.path.join(out_path, \"mel\"), exist_ok=True)\n",
-    "    os.makedirs(os.path.join(out_path, \"wav_gl\"), exist_ok=True)\n",
     "    wavq_path = os.path.join(out_path, \"quant\", file_name)\n",
     "    mel_path = os.path.join(out_path, \"mel\", file_name)\n",
-    "    wav_path = os.path.join(out_path, \"wav_gl\", file_name)\n",
-    "    return file_name, wavq_path, mel_path, wav_path"
+    "    return file_name, wavq_path, mel_path"
    ]
   },
   {
@@ -65,14 +68,14 @@
     "# Paths and configurations\n",
     "OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
     "DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
+    "PHONEME_CACHE_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/phoneme_cache\"\n",
     "DATASET = \"ljspeech\"\n",
     "METADATA_FILE = \"metadata.csv\"\n",
     "CONFIG_PATH = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/config.json\"\n",
     "MODEL_FILE = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/model_file.pth\"\n",
     "BATCH_SIZE = 32\n",
     "\n",
-    "QUANTIZED_WAV = False\n",
-    "QUANTIZE_BIT = None\n",
+    "QUANTIZE_BITS = 0 # if non-zero, quantize wav files with the given number of bits\n",
     "DRY_RUN = False # if False, does not generate output files, only computes loss and visuals.\n",
     "\n",
     "# Check CUDA availability\n",
@@ -80,10 +83,10 @@
     "print(\" > CUDA enabled: \", use_cuda)\n",
     "\n",
     "# Load the configuration\n",
+    "dataset_config = BaseDatasetConfig(formatter=DATASET, meta_file_train=METADATA_FILE, path=DATA_PATH)\n",
     "C = load_config(CONFIG_PATH)\n",
     "C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n",
-    "ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)\n",
-    "print(C['r'])"
+    "ap = AudioProcessor(**C.audio)"
    ]
   },
   {
@@ -92,12 +95,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# If the vocabulary was passed, replace the default\n",
-    "if 'characters' in C and C['characters']:\n",
-    "    symbols, phonemes = make_symbols(**C.characters)\n",
+    "# Initialize the tokenizer\n",
+    "tokenizer, C = TTSTokenizer.init_from_config(C)\n",
     "\n",
     "# Load the model\n",
-    "num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
     "# TODO: multiple speakers\n",
     "model = setup_model(C)\n",
     "model.load_checkpoint(C, MODEL_FILE, eval=True)"
@@ -109,42 +110,21 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Load the preprocessor based on the dataset\n",
-    "preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n",
-    "preprocessor = getattr(preprocessor, DATASET.lower())\n",
-    "meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n",
+    "# Load data instances\n",
+    "meta_data_train, meta_data_eval = load_tts_samples(dataset_config)\n",
+    "meta_data = meta_data_train + meta_data_eval\n",
+    "\n",
     "dataset = TTSDataset(\n",
-    "    C,\n",
-    "    C.text_cleaner,\n",
-    "    False,\n",
-    "    ap,\n",
-    "    meta_data,\n",
-    "    characters=C.get('characters', None),\n",
-    "    use_phonemes=C.use_phonemes,\n",
-    "    phoneme_cache_path=C.phoneme_cache_path,\n",
-    "    enable_eos_bos=C.enable_eos_bos_chars,\n",
+    "    outputs_per_step=C[\"r\"],\n",
+    "    compute_linear_spec=False,\n",
+    "    ap=ap,\n",
+    "    samples=meta_data,\n",
+    "    tokenizer=tokenizer,\n",
+    "    phoneme_cache_path=PHONEME_CACHE_PATH,\n",
     ")\n",
     "loader = DataLoader(\n",
     "    dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False\n",
-    ")\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Initialize lists for storing results\n",
-    "file_idxs = []\n",
-    "metadata = []\n",
-    "losses = []\n",
-    "postnet_losses = []\n",
-    "criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
-    "\n",
-    "# Create log file\n",
-    "log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
-    "log_file = open(log_file_path, \"w\")"
+    ")"
    ]
   },
   {
@@ -160,26 +140,33 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "# Initialize lists for storing results\n",
+    "file_idxs = []\n",
+    "metadata = []\n",
+    "losses = []\n",
+    "postnet_losses = []\n",
+    "criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
+    "\n",
     "# Start processing with a progress bar\n",
-    "with torch.no_grad():\n",
+    "log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
+    "with torch.no_grad() and open(log_file_path, \"w\") as log_file:\n",
     "    for data in tqdm(loader, desc=\"Processing\"):\n",
     "        try:\n",
-    "            # setup input data\n",
-    "            text_input, text_lengths, _, linear_input, mel_input, mel_lengths, stop_targets, item_idx = data\n",
-    "\n",
     "            # dispatch data to GPU\n",
     "            if use_cuda:\n",
-    "                text_input = text_input.cuda()\n",
-    "                text_lengths = text_lengths.cuda()\n",
-    "                mel_input = mel_input.cuda()\n",
-    "                mel_lengths = mel_lengths.cuda()\n",
+    "                data[\"token_id\"] = data[\"token_id\"].cuda()\n",
+    "                data[\"token_id_lengths\"] = data[\"token_id_lengths\"].cuda()\n",
+    "                data[\"mel\"] = data[\"mel\"].cuda()\n",
+    "                data[\"mel_lengths\"] = data[\"mel_lengths\"].cuda()\n",
     "\n",
-    "            mask = sequence_mask(text_lengths)\n",
-    "            mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n",
+    "            mask = sequence_mask(data[\"token_id_lengths\"])\n",
+    "            outputs = model.forward(data[\"token_id\"], data[\"token_id_lengths\"], data[\"mel\"])\n",
+    "            mel_outputs = outputs[\"decoder_outputs\"]\n",
+    "            postnet_outputs = outputs[\"model_outputs\"]\n",
     "\n",
     "            # compute loss\n",
-    "            loss = criterion(mel_outputs, mel_input, mel_lengths)\n",
-    "            loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n",
+    "            loss = criterion(mel_outputs, data[\"mel\"], data[\"mel_lengths\"])\n",
+    "            loss_postnet = criterion(postnet_outputs, data[\"mel\"], data[\"mel_lengths\"])\n",
     "            losses.append(loss.item())\n",
     "            postnet_losses.append(loss_postnet.item())\n",
     "\n",
@@ -193,28 +180,27 @@
     "                postnet_outputs = torch.stack(mel_specs)\n",
     "            elif C.model == \"Tacotron2\":\n",
     "                postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
-    "            alignments = alignments.detach().cpu().numpy()\n",
+    "            alignments = outputs[\"alignments\"].detach().cpu().numpy()\n",
     "\n",
     "            if not DRY_RUN:\n",
-    "                for idx in range(text_input.shape[0]):\n",
-    "                    wav_file_path = item_idx[idx]\n",
+    "                for idx in range(data[\"token_id\"].shape[0]):\n",
+    "                    wav_file_path = data[\"item_idxs\"][idx]\n",
     "                    wav = ap.load_wav(wav_file_path)\n",
-    "                    file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
+    "                    file_name, wavq_path, mel_path = set_filename(wav_file_path, OUT_PATH)\n",
     "                    file_idxs.append(file_name)\n",
     "\n",
     "                    # quantize and save wav\n",
-    "                    if QUANTIZED_WAV:\n",
-    "                        wavq = ap.quantize(wav)\n",
+    "                    if QUANTIZE_BITS > 0:\n",
+    "                        wavq = quantize(wav, QUANTIZE_BITS)\n",
     "                        np.save(wavq_path, wavq)\n",
     "\n",
     "                    # save TTS mel\n",
     "                    mel = postnet_outputs[idx]\n",
-    "                    mel_length = mel_lengths[idx]\n",
+    "                    mel_length = data[\"mel_lengths\"][idx]\n",
     "                    mel = mel[:mel_length, :].T\n",
     "                    np.save(mel_path, mel)\n",
     "\n",
     "                    metadata.append([wav_file_path, mel_path])\n",
-    "\n",
     "        except Exception as e:\n",
     "            log_file.write(f\"Error processing data: {str(e)}\\n\")\n",
     "\n",
@@ -224,35 +210,20 @@
     "    log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n",
     "    log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n",
     "\n",
-    "# Close the log file\n",
-    "log_file.close()\n",
-    "\n",
     "# For wavernn\n",
     "if not DRY_RUN:\n",
     "    pickle.dump(file_idxs, open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\"))\n",
     "\n",
     "# For pwgan\n",
     "with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
-    "    for data in metadata:\n",
-    "        f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n",
+    "    for wav_file_path, mel_path in metadata:\n",
+    "        f.write(f\"{wav_file_path[0]}|{mel_path[1]+'.npy'}\\n\")\n",
     "\n",
     "# Print mean losses\n",
     "print(f\"Mean Loss: {mean_loss}\")\n",
     "print(f\"Mean Postnet Loss: {mean_postnet_loss}\")"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# for pwgan\n",
-    "with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
-    "    for data in metadata:\n",
-    "        f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")"
-   ]
-  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -267,7 +238,7 @@
    "outputs": [],
    "source": [
     "idx = 1\n",
-    "ap.melspectrogram(ap.load_wav(item_idx[idx])).shape"
+    "ap.melspectrogram(ap.load_wav(data[\"item_idxs\"][idx])).shape"
    ]
   },
   {
@@ -276,10 +247,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import soundfile as sf\n",
-    "wav, sr = sf.read(item_idx[idx])\n",
-    "mel_postnet = postnet_outputs[idx][:mel_lengths[idx], :]\n",
-    "mel_decoder = mel_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()\n",
+    "wav, sr = sf.read(data[\"item_idxs\"][idx])\n",
+    "mel_postnet = postnet_outputs[idx][:data[\"mel_lengths\"][idx], :]\n",
+    "mel_decoder = mel_outputs[idx][:data[\"mel_lengths\"][idx], :].detach().cpu().numpy()\n",
     "mel_truth = ap.melspectrogram(wav)\n",
     "print(mel_truth.shape)"
    ]
@@ -291,7 +261,7 @@
    "outputs": [],
    "source": [
     "# plot posnet output\n",
-    "print(mel_postnet[:mel_lengths[idx], :].shape)\n",
+    "print(mel_postnet[:data[\"mel_lengths\"][idx], :].shape)\n",
     "plot_spectrogram(mel_postnet, ap)"
    ]
   },
@@ -324,10 +294,9 @@
    "outputs": [],
    "source": [
     "# postnet, decoder diff\n",
-    "from matplotlib import pylab as plt\n",
     "mel_diff = mel_decoder - mel_postnet\n",
     "plt.figure(figsize=(16, 10))\n",
-    "plt.imshow(abs(mel_diff[:mel_lengths[idx],:]).T,aspect=\"auto\", origin=\"lower\");\n",
+    "plt.imshow(abs(mel_diff[:data[\"mel_lengths\"][idx],:]).T,aspect=\"auto\", origin=\"lower\")\n",
     "plt.colorbar()\n",
     "plt.tight_layout()"
    ]
@@ -339,10 +308,9 @@
    "outputs": [],
    "source": [
     "# PLOT GT SPECTROGRAM diff\n",
-    "from matplotlib import pylab as plt\n",
     "mel_diff2 = mel_truth.T - mel_decoder\n",
     "plt.figure(figsize=(16, 10))\n",
-    "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
+    "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\")\n",
     "plt.colorbar()\n",
     "plt.tight_layout()"
    ]
@@ -354,21 +322,13 @@
    "outputs": [],
    "source": [
     "# PLOT GT SPECTROGRAM diff\n",
-    "from matplotlib import pylab as plt\n",
     "mel = postnet_outputs[idx]\n",
     "mel_diff2 = mel_truth.T - mel[:mel_truth.shape[1]]\n",
     "plt.figure(figsize=(16, 10))\n",
-    "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
+    "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\")\n",
     "plt.colorbar()\n",
     "plt.tight_layout()"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
   }
  ],
 "metadata": {

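Note on the notebook changes: batches from the current TTSDataset come back as dictionaries, so the cells now index data["token_id"], data["token_id_lengths"], data["mel"], data["mel_lengths"] and data["item_idxs"] instead of unpacking a tuple, and the tokenizer is created with TTSTokenizer.init_from_config(C) rather than rebuilding the symbol tables by hand.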
View File

@@ -5,6 +5,7 @@ import torch
 from tests import get_tests_input_path, get_tests_output_path, get_tests_path
 from TTS.config import BaseAudioConfig
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.numpy_transforms import stft
 from TTS.vocoder.layers.losses import MelganFeatureLoss, MultiScaleSTFTLoss, STFTLoss, TorchSTFT

 TESTS_PATH = get_tests_path()
@@ -21,7 +22,7 @@ def test_torch_stft():
     torch_stft = TorchSTFT(ap.fft_size, ap.hop_length, ap.win_length)
     # librosa stft
     wav = ap.load_wav(WAV_FILE)
-    M_librosa = abs(ap._stft(wav))  # pylint: disable=protected-access
+    M_librosa = abs(stft(y=wav, fft_size=ap.fft_size, hop_length=ap.hop_length, win_length=ap.win_length))
     # torch stft
     wav = torch.from_numpy(wav[None, :]).float()
     M_torch = torch_stft(wav)