mirror of https://github.com/coqui-ai/TTS.git
Remove duplicate AudioProcessor code and fix ExtractTTSpectrogram.ipynb (#3230)
* chore: remove unused argument * refactor(audio.processor): remove duplicate stft+griffin_lim * chore(audio.processor): remove unused compute_stft_paddings Same function available in numpy_transforms * refactor(audio.processor): remove duplicate db_to_amp * refactor(audio.processor): remove duplicate amp_to_db * refactor(audio.processor): remove duplicate linear_to_mel * refactor(audio.processor): remove duplicate mel_to_linear * refactor(audio.processor): remove duplicate build_mel_basis * refactor(audio.processor): remove duplicate stft_parameters * refactor(audio.processor): use pre-/deemphasis from numpy_transforms * refactor(audio.processor): use rms_volume_norm from numpy_transforms * chore(audio.processor): remove duplicate assert Already checked in numpy_transforms.compute_f0 * refactor(audio.processor): use find_endpoint from numpy_transforms * refactor(audio.processor): use trim_silence from numpy_transforms * refactor(audio.processor): use volume_norm from numpy_transforms * refactor(audio.processor): use load_wav from numpy_transforms * fix(bin.extract_tts_spectrograms): set quantization bits * fix(ExtractTTSpectrogram.ipynb): adapt to current TTS code Fixes #2447, #2574 * refactor(audio.processor): remove duplicate quantization methodspull/3238/head
parent
88630c60e5
commit
3c2d5a9e03
|
@ -15,6 +15,7 @@ from TTS.tts.models import setup_model
|
|||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.audio.numpy_transforms import quantize
|
||||
from TTS.utils.generic_utils import count_parameters
|
||||
|
||||
use_cuda = torch.cuda.is_available()
|
||||
|
@ -159,7 +160,7 @@ def inference(
|
|||
|
||||
|
||||
def extract_spectrograms(
|
||||
data_loader, model, ap, output_path, quantized_wav=False, save_audio=False, debug=False, metada_name="metada.txt"
|
||||
data_loader, model, ap, output_path, quantize_bits=0, save_audio=False, debug=False, metada_name="metada.txt"
|
||||
):
|
||||
model.eval()
|
||||
export_metadata = []
|
||||
|
@ -196,8 +197,8 @@ def extract_spectrograms(
|
|||
_, wavq_path, mel_path, wav_gl_path, wav_path = set_filename(wav_file_path, output_path)
|
||||
|
||||
# quantize and save wav
|
||||
if quantized_wav:
|
||||
wavq = ap.quantize(wav)
|
||||
if quantize_bits > 0:
|
||||
wavq = quantize(wav, quantize_bits)
|
||||
np.save(wavq_path, wavq)
|
||||
|
||||
# save TTS mel
|
||||
|
@ -263,7 +264,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
model,
|
||||
ap,
|
||||
args.output_path,
|
||||
quantized_wav=args.quantized,
|
||||
quantize_bits=args.quantize_bits,
|
||||
save_audio=args.save_audio,
|
||||
debug=args.debug,
|
||||
metada_name="metada.txt",
|
||||
|
@ -277,7 +278,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--output_path", type=str, help="Path to save mel specs", required=True)
|
||||
parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug")
|
||||
parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files")
|
||||
parser.add_argument("--quantized", action="store_true", help="Save quantized audio files")
|
||||
parser.add_argument("--quantize_bits", type=int, default=0, help="Save quantized audio files if non-zero")
|
||||
parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
|
@ -201,7 +201,6 @@ def stft(
|
|||
def istft(
|
||||
*,
|
||||
y: np.ndarray = None,
|
||||
fft_size: int = None,
|
||||
hop_length: int = None,
|
||||
win_length: int = None,
|
||||
window: str = "hann",
|
||||
|
|
|
@ -5,10 +5,26 @@ import librosa
|
|||
import numpy as np
|
||||
import scipy.io.wavfile
|
||||
import scipy.signal
|
||||
import soundfile as sf
|
||||
|
||||
from TTS.tts.utils.helpers import StandardScaler
|
||||
from TTS.utils.audio.numpy_transforms import compute_f0
|
||||
from TTS.utils.audio.numpy_transforms import (
|
||||
amp_to_db,
|
||||
build_mel_basis,
|
||||
compute_f0,
|
||||
db_to_amp,
|
||||
deemphasis,
|
||||
find_endpoint,
|
||||
griffin_lim,
|
||||
load_wav,
|
||||
mel_to_spec,
|
||||
millisec_to_length,
|
||||
preemphasis,
|
||||
rms_volume_norm,
|
||||
spec_to_mel,
|
||||
stft,
|
||||
trim_silence,
|
||||
volume_norm,
|
||||
)
|
||||
|
||||
# pylint: disable=too-many-public-methods
|
||||
|
||||
|
@ -200,7 +216,9 @@ class AudioProcessor(object):
|
|||
# setup stft parameters
|
||||
if hop_length is None:
|
||||
# compute stft parameters from given time values
|
||||
self.hop_length, self.win_length = self._stft_parameters()
|
||||
self.win_length, self.hop_length = millisec_to_length(
|
||||
frame_length_ms=self.frame_length_ms, frame_shift_ms=self.frame_shift_ms, sample_rate=self.sample_rate
|
||||
)
|
||||
else:
|
||||
# use stft parameters from config file
|
||||
self.hop_length = hop_length
|
||||
|
@ -215,8 +233,13 @@ class AudioProcessor(object):
|
|||
for key, value in members.items():
|
||||
print(" | > {}:{}".format(key, value))
|
||||
# create spectrogram utils
|
||||
self.mel_basis = self._build_mel_basis()
|
||||
self.inv_mel_basis = np.linalg.pinv(self._build_mel_basis())
|
||||
self.mel_basis = build_mel_basis(
|
||||
sample_rate=self.sample_rate,
|
||||
fft_size=self.fft_size,
|
||||
num_mels=self.num_mels,
|
||||
mel_fmax=self.mel_fmax,
|
||||
mel_fmin=self.mel_fmin,
|
||||
)
|
||||
# setup scaler
|
||||
if stats_path and signal_norm:
|
||||
mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path)
|
||||
|
@ -232,35 +255,6 @@ class AudioProcessor(object):
|
|||
return AudioProcessor(verbose=verbose, **config.audio)
|
||||
return AudioProcessor(verbose=verbose, **config)
|
||||
|
||||
### setting up the parameters ###
|
||||
def _build_mel_basis(
|
||||
self,
|
||||
) -> np.ndarray:
|
||||
"""Build melspectrogram basis.
|
||||
|
||||
Returns:
|
||||
np.ndarray: melspectrogram basis.
|
||||
"""
|
||||
if self.mel_fmax is not None:
|
||||
assert self.mel_fmax <= self.sample_rate // 2
|
||||
return librosa.filters.mel(
|
||||
sr=self.sample_rate, n_fft=self.fft_size, n_mels=self.num_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
|
||||
)
|
||||
|
||||
def _stft_parameters(
|
||||
self,
|
||||
) -> Tuple[int, int]:
|
||||
"""Compute the real STFT parameters from the time values.
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: hop length and window length for STFT.
|
||||
"""
|
||||
factor = self.frame_length_ms / self.frame_shift_ms
|
||||
assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms"
|
||||
hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
|
||||
win_length = int(hop_length * factor)
|
||||
return hop_length, win_length
|
||||
|
||||
### normalization ###
|
||||
def normalize(self, S: np.ndarray) -> np.ndarray:
|
||||
"""Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]`
|
||||
|
@ -386,31 +380,6 @@ class AudioProcessor(object):
|
|||
self.linear_scaler = StandardScaler()
|
||||
self.linear_scaler.set_stats(linear_mean, linear_std)
|
||||
|
||||
### DB and AMP conversion ###
|
||||
# pylint: disable=no-self-use
|
||||
def _amp_to_db(self, x: np.ndarray) -> np.ndarray:
|
||||
"""Convert amplitude values to decibels.
|
||||
|
||||
Args:
|
||||
x (np.ndarray): Amplitude spectrogram.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Decibels spectrogram.
|
||||
"""
|
||||
return self.spec_gain * _log(np.maximum(1e-5, x), self.base)
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
def _db_to_amp(self, x: np.ndarray) -> np.ndarray:
|
||||
"""Convert decibels spectrogram to amplitude spectrogram.
|
||||
|
||||
Args:
|
||||
x (np.ndarray): Decibels spectrogram.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Amplitude spectrogram.
|
||||
"""
|
||||
return _exp(x / self.spec_gain, self.base)
|
||||
|
||||
### Preemphasis ###
|
||||
def apply_preemphasis(self, x: np.ndarray) -> np.ndarray:
|
||||
"""Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values.
|
||||
|
@ -424,32 +393,13 @@ class AudioProcessor(object):
|
|||
Returns:
|
||||
np.ndarray: Decorrelated audio signal.
|
||||
"""
|
||||
if self.preemphasis == 0:
|
||||
raise RuntimeError(" [!] Preemphasis is set 0.0.")
|
||||
return scipy.signal.lfilter([1, -self.preemphasis], [1], x)
|
||||
return preemphasis(x=x, coef=self.preemphasis)
|
||||
|
||||
def apply_inv_preemphasis(self, x: np.ndarray) -> np.ndarray:
|
||||
"""Reverse pre-emphasis."""
|
||||
if self.preemphasis == 0:
|
||||
raise RuntimeError(" [!] Preemphasis is set 0.0.")
|
||||
return scipy.signal.lfilter([1], [1, -self.preemphasis], x)
|
||||
return deemphasis(x=x, coef=self.preemphasis)
|
||||
|
||||
### SPECTROGRAMs ###
|
||||
def _linear_to_mel(self, spectrogram: np.ndarray) -> np.ndarray:
|
||||
"""Project a full scale spectrogram to a melspectrogram.
|
||||
|
||||
Args:
|
||||
spectrogram (np.ndarray): Full scale spectrogram.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Melspectrogram
|
||||
"""
|
||||
return np.dot(self.mel_basis, spectrogram)
|
||||
|
||||
def _mel_to_linear(self, mel_spec: np.ndarray) -> np.ndarray:
|
||||
"""Convert a melspectrogram to full scale spectrogram."""
|
||||
return np.maximum(1e-10, np.dot(self.inv_mel_basis, mel_spec))
|
||||
|
||||
def spectrogram(self, y: np.ndarray) -> np.ndarray:
|
||||
"""Compute a spectrogram from a waveform.
|
||||
|
||||
|
@ -460,11 +410,16 @@ class AudioProcessor(object):
|
|||
np.ndarray: Spectrogram.
|
||||
"""
|
||||
if self.preemphasis != 0:
|
||||
D = self._stft(self.apply_preemphasis(y))
|
||||
else:
|
||||
D = self._stft(y)
|
||||
y = self.apply_preemphasis(y)
|
||||
D = stft(
|
||||
y=y,
|
||||
fft_size=self.fft_size,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
pad_mode=self.stft_pad_mode,
|
||||
)
|
||||
if self.do_amp_to_db_linear:
|
||||
S = self._amp_to_db(np.abs(D))
|
||||
S = amp_to_db(x=np.abs(D), gain=self.spec_gain, base=self.base)
|
||||
else:
|
||||
S = np.abs(D)
|
||||
return self.normalize(S).astype(np.float32)
|
||||
|
@ -472,32 +427,35 @@ class AudioProcessor(object):
|
|||
def melspectrogram(self, y: np.ndarray) -> np.ndarray:
|
||||
"""Compute a melspectrogram from a waveform."""
|
||||
if self.preemphasis != 0:
|
||||
D = self._stft(self.apply_preemphasis(y))
|
||||
else:
|
||||
D = self._stft(y)
|
||||
y = self.apply_preemphasis(y)
|
||||
D = stft(
|
||||
y=y,
|
||||
fft_size=self.fft_size,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
pad_mode=self.stft_pad_mode,
|
||||
)
|
||||
S = spec_to_mel(spec=np.abs(D), mel_basis=self.mel_basis)
|
||||
if self.do_amp_to_db_mel:
|
||||
S = self._amp_to_db(self._linear_to_mel(np.abs(D)))
|
||||
else:
|
||||
S = self._linear_to_mel(np.abs(D))
|
||||
S = amp_to_db(x=S, gain=self.spec_gain, base=self.base)
|
||||
|
||||
return self.normalize(S).astype(np.float32)
|
||||
|
||||
def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray:
|
||||
"""Convert a spectrogram to a waveform using Griffi-Lim vocoder."""
|
||||
S = self.denormalize(spectrogram)
|
||||
S = self._db_to_amp(S)
|
||||
S = db_to_amp(x=S, gain=self.spec_gain, base=self.base)
|
||||
# Reconstruct phase
|
||||
if self.preemphasis != 0:
|
||||
return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
|
||||
return self._griffin_lim(S**self.power)
|
||||
W = self._griffin_lim(S**self.power)
|
||||
return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W
|
||||
|
||||
def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray:
|
||||
"""Convert a melspectrogram to a waveform using Griffi-Lim vocoder."""
|
||||
D = self.denormalize(mel_spectrogram)
|
||||
S = self._db_to_amp(D)
|
||||
S = self._mel_to_linear(S) # Convert back to linear
|
||||
if self.preemphasis != 0:
|
||||
return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
|
||||
return self._griffin_lim(S**self.power)
|
||||
S = db_to_amp(x=D, gain=self.spec_gain, base=self.base)
|
||||
S = mel_to_spec(mel=S, mel_basis=self.mel_basis) # Convert back to linear
|
||||
W = self._griffin_lim(S**self.power)
|
||||
return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W
|
||||
|
||||
def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray:
|
||||
"""Convert a full scale linear spectrogram output of a network to a melspectrogram.
|
||||
|
@ -509,60 +467,22 @@ class AudioProcessor(object):
|
|||
np.ndarray: Normalized melspectrogram.
|
||||
"""
|
||||
S = self.denormalize(linear_spec)
|
||||
S = self._db_to_amp(S)
|
||||
S = self._linear_to_mel(np.abs(S))
|
||||
S = self._amp_to_db(S)
|
||||
S = db_to_amp(x=S, gain=self.spec_gain, base=self.base)
|
||||
S = spec_to_mel(spec=np.abs(S), mel_basis=self.mel_basis)
|
||||
S = amp_to_db(x=S, gain=self.spec_gain, base=self.base)
|
||||
mel = self.normalize(S)
|
||||
return mel
|
||||
|
||||
### STFT and ISTFT ###
|
||||
def _stft(self, y: np.ndarray) -> np.ndarray:
|
||||
"""Librosa STFT wrapper.
|
||||
|
||||
Args:
|
||||
y (np.ndarray): Audio signal.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Complex number array.
|
||||
"""
|
||||
return librosa.stft(
|
||||
y=y,
|
||||
n_fft=self.fft_size,
|
||||
def _griffin_lim(self, S):
|
||||
return griffin_lim(
|
||||
spec=S,
|
||||
num_iter=self.griffin_lim_iters,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
fft_size=self.fft_size,
|
||||
pad_mode=self.stft_pad_mode,
|
||||
window="hann",
|
||||
center=True,
|
||||
)
|
||||
|
||||
def _istft(self, y: np.ndarray) -> np.ndarray:
|
||||
"""Librosa iSTFT wrapper."""
|
||||
return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length)
|
||||
|
||||
def _griffin_lim(self, S):
|
||||
angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
|
||||
try:
|
||||
S_complex = np.abs(S).astype(np.complex)
|
||||
except AttributeError: # np.complex is deprecated since numpy 1.20.0
|
||||
S_complex = np.abs(S).astype(complex)
|
||||
y = self._istft(S_complex * angles)
|
||||
if not np.isfinite(y).all():
|
||||
print(" [!] Waveform is not finite everywhere. Skipping the GL.")
|
||||
return np.array([0.0])
|
||||
for _ in range(self.griffin_lim_iters):
|
||||
angles = np.exp(1j * np.angle(self._stft(y)))
|
||||
y = self._istft(S_complex * angles)
|
||||
return y
|
||||
|
||||
def compute_stft_paddings(self, x, pad_sides=1):
|
||||
"""Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding
|
||||
(first and final frames)"""
|
||||
assert pad_sides in (1, 2)
|
||||
pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0]
|
||||
if pad_sides == 1:
|
||||
return 0, pad
|
||||
return pad // 2, pad // 2 + pad % 2
|
||||
|
||||
def compute_f0(self, x: np.ndarray) -> np.ndarray:
|
||||
"""Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.
|
||||
|
||||
|
@ -581,8 +501,6 @@ class AudioProcessor(object):
|
|||
>>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate]
|
||||
>>> pitch = ap.compute_f0(wav)
|
||||
"""
|
||||
assert self.pitch_fmax is not None, " [!] Set `pitch_fmax` before caling `compute_f0`."
|
||||
assert self.pitch_fmin is not None, " [!] Set `pitch_fmin` before caling `compute_f0`."
|
||||
# align F0 length to the spectrogram length
|
||||
if len(x) % self.hop_length == 0:
|
||||
x = np.pad(x, (0, self.hop_length // 2), mode=self.stft_pad_mode)
|
||||
|
@ -612,21 +530,24 @@ class AudioProcessor(object):
|
|||
Returns:
|
||||
int: Last point without silence.
|
||||
"""
|
||||
window_length = int(self.sample_rate * min_silence_sec)
|
||||
hop_length = int(window_length / 4)
|
||||
threshold = self._db_to_amp(-self.trim_db)
|
||||
for x in range(hop_length, len(wav) - window_length, hop_length):
|
||||
if np.max(wav[x : x + window_length]) < threshold:
|
||||
return x + hop_length
|
||||
return len(wav)
|
||||
return find_endpoint(
|
||||
wav=wav,
|
||||
trim_db=self.trim_db,
|
||||
sample_rate=self.sample_rate,
|
||||
min_silence_sec=min_silence_sec,
|
||||
gain=self.spec_gain,
|
||||
base=self.base,
|
||||
)
|
||||
|
||||
def trim_silence(self, wav):
|
||||
"""Trim silent parts with a threshold and 0.01 sec margin"""
|
||||
margin = int(self.sample_rate * 0.01)
|
||||
wav = wav[margin:-margin]
|
||||
return librosa.effects.trim(wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[
|
||||
0
|
||||
]
|
||||
return trim_silence(
|
||||
wav=wav,
|
||||
sample_rate=self.sample_rate,
|
||||
trim_db=self.trim_db,
|
||||
win_length=self.win_length,
|
||||
hop_length=self.hop_length,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def sound_norm(x: np.ndarray) -> np.ndarray:
|
||||
|
@ -638,13 +559,7 @@ class AudioProcessor(object):
|
|||
Returns:
|
||||
np.ndarray: Volume normalized waveform.
|
||||
"""
|
||||
return x / abs(x).max() * 0.95
|
||||
|
||||
@staticmethod
|
||||
def _rms_norm(wav, db_level=-27):
|
||||
r = 10 ** (db_level / 20)
|
||||
a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2))
|
||||
return wav * a
|
||||
return volume_norm(x=x)
|
||||
|
||||
def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray:
|
||||
"""Normalize the volume based on RMS of the signal.
|
||||
|
@ -657,9 +572,7 @@ class AudioProcessor(object):
|
|||
"""
|
||||
if db_level is None:
|
||||
db_level = self.db_level
|
||||
assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0"
|
||||
wav = self._rms_norm(x, db_level)
|
||||
return wav
|
||||
return rms_volume_norm(x=x, db_level=db_level)
|
||||
|
||||
### save and load ###
|
||||
def load_wav(self, filename: str, sr: int = None) -> np.ndarray:
|
||||
|
@ -674,15 +587,10 @@ class AudioProcessor(object):
|
|||
Returns:
|
||||
np.ndarray: Loaded waveform.
|
||||
"""
|
||||
if self.resample:
|
||||
# loading with resampling. It is significantly slower.
|
||||
x, sr = librosa.load(filename, sr=self.sample_rate)
|
||||
elif sr is None:
|
||||
# SF is faster than librosa for loading files
|
||||
x, sr = sf.read(filename)
|
||||
assert self.sample_rate == sr, "%s vs %s" % (self.sample_rate, sr)
|
||||
if sr is not None:
|
||||
x = load_wav(filename=filename, sample_rate=sr, resample=True)
|
||||
else:
|
||||
x, sr = librosa.load(filename, sr=sr)
|
||||
x = load_wav(filename=filename, sample_rate=self.sample_rate, resample=self.resample)
|
||||
if self.do_trim_silence:
|
||||
try:
|
||||
x = self.trim_silence(x)
|
||||
|
@ -723,55 +631,3 @@ class AudioProcessor(object):
|
|||
filename (str): Path to the wav file.
|
||||
"""
|
||||
return librosa.get_duration(filename=filename)
|
||||
|
||||
@staticmethod
|
||||
def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray:
|
||||
mu = 2**qc - 1
|
||||
# wav_abs = np.minimum(np.abs(wav), 1.0)
|
||||
signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
|
||||
# Quantize signal to the specified number of levels.
|
||||
signal = (signal + 1) / 2 * mu + 0.5
|
||||
return np.floor(
|
||||
signal,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def mulaw_decode(wav, qc):
|
||||
"""Recovers waveform from quantized values."""
|
||||
mu = 2**qc - 1
|
||||
x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def encode_16bits(x):
|
||||
return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
|
||||
|
||||
@staticmethod
|
||||
def quantize(x: np.ndarray, bits: int) -> np.ndarray:
|
||||
"""Quantize a waveform to a given number of bits.
|
||||
|
||||
Args:
|
||||
x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`.
|
||||
bits (int): Number of quantization bits.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Quantized waveform.
|
||||
"""
|
||||
return (x + 1.0) * (2**bits - 1) / 2
|
||||
|
||||
@staticmethod
|
||||
def dequantize(x, bits):
|
||||
"""Dequantize a waveform from the given number of bits."""
|
||||
return 2 * x / (2**bits - 1) - 1
|
||||
|
||||
|
||||
def _log(x, base):
|
||||
if base == 10:
|
||||
return np.log10(x)
|
||||
return np.log(x)
|
||||
|
||||
|
||||
def _exp(x, base):
|
||||
if base == 10:
|
||||
return np.power(10, x)
|
||||
return np.exp(x)
|
||||
|
|
|
@ -7,6 +7,7 @@ from coqpit import Coqpit
|
|||
from tqdm import tqdm
|
||||
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize
|
||||
|
||||
|
||||
def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
|
||||
|
@ -29,7 +30,11 @@ def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
|
|||
mel = ap.melspectrogram(y)
|
||||
np.save(mel_path, mel)
|
||||
if isinstance(config.mode, int):
|
||||
quant = ap.mulaw_encode(y, qc=config.mode) if config.model_args.mulaw else ap.quantize(y, bits=config.mode)
|
||||
quant = (
|
||||
mulaw_encode(wav=y, mulaw_qc=config.mode)
|
||||
if config.model_args.mulaw
|
||||
else quantize(x=y, quantize_bits=config.mode)
|
||||
)
|
||||
np.save(quant_path, quant)
|
||||
|
||||
|
||||
|
|
|
@ -2,6 +2,8 @@ import numpy as np
|
|||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize
|
||||
|
||||
|
||||
class WaveRNNDataset(Dataset):
|
||||
"""
|
||||
|
@ -66,7 +68,9 @@ class WaveRNNDataset(Dataset):
|
|||
x_input = audio
|
||||
elif isinstance(self.mode, int):
|
||||
x_input = (
|
||||
self.ap.mulaw_encode(audio, qc=self.mode) if self.mulaw else self.ap.quantize(audio, bits=self.mode)
|
||||
mulaw_encode(wav=audio, mulaw_qc=self.mode)
|
||||
if self.mulaw
|
||||
else quantize(x=audio, quantize_bits=self.mode)
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("Unknown dataset mode - ", self.mode)
|
||||
|
|
|
@ -13,6 +13,7 @@ from torch.utils.data.distributed import DistributedSampler
|
|||
|
||||
from TTS.tts.utils.visual import plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.audio.numpy_transforms import mulaw_decode
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
|
||||
from TTS.vocoder.layers.losses import WaveRNNLoss
|
||||
|
@ -399,7 +400,7 @@ class Wavernn(BaseVocoder):
|
|||
output = output[0]
|
||||
|
||||
if self.args.mulaw and isinstance(self.args.mode, int):
|
||||
output = AudioProcessor.mulaw_decode(output, self.args.mode)
|
||||
output = mulaw_decode(wav=output, mulaw_qc=self.args.mode)
|
||||
|
||||
# Fade-out at the end to avoid signal cutting out suddenly
|
||||
fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length)
|
||||
|
|
|
@ -13,23 +13,28 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import sys\n",
|
||||
"import torch\n",
|
||||
"import importlib\n",
|
||||
"import numpy as np\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"import soundfile as sf\n",
|
||||
"import os\n",
|
||||
"import pickle\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import soundfile as sf\n",
|
||||
"import torch\n",
|
||||
"from matplotlib import pylab as plt\n",
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"\n",
|
||||
"from TTS.config import load_config\n",
|
||||
"from TTS.tts.configs.shared_configs import BaseDatasetConfig\n",
|
||||
"from TTS.tts.datasets import load_tts_samples\n",
|
||||
"from TTS.tts.datasets.dataset import TTSDataset\n",
|
||||
"from TTS.tts.layers.losses import L1LossMasked\n",
|
||||
"from TTS.utils.audio import AudioProcessor\n",
|
||||
"from TTS.config import load_config\n",
|
||||
"from TTS.tts.utils.visual import plot_spectrogram\n",
|
||||
"from TTS.tts.utils.helpers import sequence_mask\n",
|
||||
"from TTS.tts.models import setup_model\n",
|
||||
"from TTS.tts.utils.text.symbols import make_symbols, symbols, phonemes\n",
|
||||
"from TTS.tts.utils.helpers import sequence_mask\n",
|
||||
"from TTS.tts.utils.text.tokenizer import TTSTokenizer\n",
|
||||
"from TTS.tts.utils.visual import plot_spectrogram\n",
|
||||
"from TTS.utils.audio import AudioProcessor\n",
|
||||
"from TTS.utils.audio.numpy_transforms import quantize\n",
|
||||
"\n",
|
||||
"%matplotlib inline\n",
|
||||
"\n",
|
||||
|
@ -49,11 +54,9 @@
|
|||
" file_name = wav_file.split('.')[0]\n",
|
||||
" os.makedirs(os.path.join(out_path, \"quant\"), exist_ok=True)\n",
|
||||
" os.makedirs(os.path.join(out_path, \"mel\"), exist_ok=True)\n",
|
||||
" os.makedirs(os.path.join(out_path, \"wav_gl\"), exist_ok=True)\n",
|
||||
" wavq_path = os.path.join(out_path, \"quant\", file_name)\n",
|
||||
" mel_path = os.path.join(out_path, \"mel\", file_name)\n",
|
||||
" wav_path = os.path.join(out_path, \"wav_gl\", file_name)\n",
|
||||
" return file_name, wavq_path, mel_path, wav_path"
|
||||
" return file_name, wavq_path, mel_path"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -65,14 +68,14 @@
|
|||
"# Paths and configurations\n",
|
||||
"OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
|
||||
"DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
|
||||
"PHONEME_CACHE_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/phoneme_cache\"\n",
|
||||
"DATASET = \"ljspeech\"\n",
|
||||
"METADATA_FILE = \"metadata.csv\"\n",
|
||||
"CONFIG_PATH = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/config.json\"\n",
|
||||
"MODEL_FILE = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/model_file.pth\"\n",
|
||||
"BATCH_SIZE = 32\n",
|
||||
"\n",
|
||||
"QUANTIZED_WAV = False\n",
|
||||
"QUANTIZE_BIT = None\n",
|
||||
"QUANTIZE_BITS = 0 # if non-zero, quantize wav files with the given number of bits\n",
|
||||
"DRY_RUN = False # if False, does not generate output files, only computes loss and visuals.\n",
|
||||
"\n",
|
||||
"# Check CUDA availability\n",
|
||||
|
@ -80,10 +83,10 @@
|
|||
"print(\" > CUDA enabled: \", use_cuda)\n",
|
||||
"\n",
|
||||
"# Load the configuration\n",
|
||||
"dataset_config = BaseDatasetConfig(formatter=DATASET, meta_file_train=METADATA_FILE, path=DATA_PATH)\n",
|
||||
"C = load_config(CONFIG_PATH)\n",
|
||||
"C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n",
|
||||
"ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)\n",
|
||||
"print(C['r'])"
|
||||
"ap = AudioProcessor(**C.audio)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -92,12 +95,10 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# If the vocabulary was passed, replace the default\n",
|
||||
"if 'characters' in C and C['characters']:\n",
|
||||
" symbols, phonemes = make_symbols(**C.characters)\n",
|
||||
"# Initialize the tokenizer\n",
|
||||
"tokenizer, C = TTSTokenizer.init_from_config(C)\n",
|
||||
"\n",
|
||||
"# Load the model\n",
|
||||
"num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
|
||||
"# TODO: multiple speakers\n",
|
||||
"model = setup_model(C)\n",
|
||||
"model.load_checkpoint(C, MODEL_FILE, eval=True)"
|
||||
|
@ -109,42 +110,21 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load the preprocessor based on the dataset\n",
|
||||
"preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n",
|
||||
"preprocessor = getattr(preprocessor, DATASET.lower())\n",
|
||||
"meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n",
|
||||
"# Load data instances\n",
|
||||
"meta_data_train, meta_data_eval = load_tts_samples(dataset_config)\n",
|
||||
"meta_data = meta_data_train + meta_data_eval\n",
|
||||
"\n",
|
||||
"dataset = TTSDataset(\n",
|
||||
" C,\n",
|
||||
" C.text_cleaner,\n",
|
||||
" False,\n",
|
||||
" ap,\n",
|
||||
" meta_data,\n",
|
||||
" characters=C.get('characters', None),\n",
|
||||
" use_phonemes=C.use_phonemes,\n",
|
||||
" phoneme_cache_path=C.phoneme_cache_path,\n",
|
||||
" enable_eos_bos=C.enable_eos_bos_chars,\n",
|
||||
" outputs_per_step=C[\"r\"],\n",
|
||||
" compute_linear_spec=False,\n",
|
||||
" ap=ap,\n",
|
||||
" samples=meta_data,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" phoneme_cache_path=PHONEME_CACHE_PATH,\n",
|
||||
")\n",
|
||||
"loader = DataLoader(\n",
|
||||
" dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False\n",
|
||||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Initialize lists for storing results\n",
|
||||
"file_idxs = []\n",
|
||||
"metadata = []\n",
|
||||
"losses = []\n",
|
||||
"postnet_losses = []\n",
|
||||
"criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
|
||||
"\n",
|
||||
"# Create log file\n",
|
||||
"log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
|
||||
"log_file = open(log_file_path, \"w\")"
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -160,26 +140,33 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Initialize lists for storing results\n",
|
||||
"file_idxs = []\n",
|
||||
"metadata = []\n",
|
||||
"losses = []\n",
|
||||
"postnet_losses = []\n",
|
||||
"criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
|
||||
"\n",
|
||||
"# Start processing with a progress bar\n",
|
||||
"with torch.no_grad():\n",
|
||||
"log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
|
||||
"with torch.no_grad() and open(log_file_path, \"w\") as log_file:\n",
|
||||
" for data in tqdm(loader, desc=\"Processing\"):\n",
|
||||
" try:\n",
|
||||
" # setup input data\n",
|
||||
" text_input, text_lengths, _, linear_input, mel_input, mel_lengths, stop_targets, item_idx = data\n",
|
||||
"\n",
|
||||
" # dispatch data to GPU\n",
|
||||
" if use_cuda:\n",
|
||||
" text_input = text_input.cuda()\n",
|
||||
" text_lengths = text_lengths.cuda()\n",
|
||||
" mel_input = mel_input.cuda()\n",
|
||||
" mel_lengths = mel_lengths.cuda()\n",
|
||||
" data[\"token_id\"] = data[\"token_id\"].cuda()\n",
|
||||
" data[\"token_id_lengths\"] = data[\"token_id_lengths\"].cuda()\n",
|
||||
" data[\"mel\"] = data[\"mel\"].cuda()\n",
|
||||
" data[\"mel_lengths\"] = data[\"mel_lengths\"].cuda()\n",
|
||||
"\n",
|
||||
" mask = sequence_mask(text_lengths)\n",
|
||||
" mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n",
|
||||
" mask = sequence_mask(data[\"token_id_lengths\"])\n",
|
||||
" outputs = model.forward(data[\"token_id\"], data[\"token_id_lengths\"], data[\"mel\"])\n",
|
||||
" mel_outputs = outputs[\"decoder_outputs\"]\n",
|
||||
" postnet_outputs = outputs[\"model_outputs\"]\n",
|
||||
"\n",
|
||||
" # compute loss\n",
|
||||
" loss = criterion(mel_outputs, mel_input, mel_lengths)\n",
|
||||
" loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n",
|
||||
" loss = criterion(mel_outputs, data[\"mel\"], data[\"mel_lengths\"])\n",
|
||||
" loss_postnet = criterion(postnet_outputs, data[\"mel\"], data[\"mel_lengths\"])\n",
|
||||
" losses.append(loss.item())\n",
|
||||
" postnet_losses.append(loss_postnet.item())\n",
|
||||
"\n",
|
||||
|
@ -193,28 +180,27 @@
|
|||
" postnet_outputs = torch.stack(mel_specs)\n",
|
||||
" elif C.model == \"Tacotron2\":\n",
|
||||
" postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
|
||||
" alignments = alignments.detach().cpu().numpy()\n",
|
||||
" alignments = outputs[\"alignments\"].detach().cpu().numpy()\n",
|
||||
"\n",
|
||||
" if not DRY_RUN:\n",
|
||||
" for idx in range(text_input.shape[0]):\n",
|
||||
" wav_file_path = item_idx[idx]\n",
|
||||
" for idx in range(data[\"token_id\"].shape[0]):\n",
|
||||
" wav_file_path = data[\"item_idxs\"][idx]\n",
|
||||
" wav = ap.load_wav(wav_file_path)\n",
|
||||
" file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
|
||||
" file_name, wavq_path, mel_path = set_filename(wav_file_path, OUT_PATH)\n",
|
||||
" file_idxs.append(file_name)\n",
|
||||
"\n",
|
||||
" # quantize and save wav\n",
|
||||
" if QUANTIZED_WAV:\n",
|
||||
" wavq = ap.quantize(wav)\n",
|
||||
" if QUANTIZE_BITS > 0:\n",
|
||||
" wavq = quantize(wav, QUANTIZE_BITS)\n",
|
||||
" np.save(wavq_path, wavq)\n",
|
||||
"\n",
|
||||
" # save TTS mel\n",
|
||||
" mel = postnet_outputs[idx]\n",
|
||||
" mel_length = mel_lengths[idx]\n",
|
||||
" mel_length = data[\"mel_lengths\"][idx]\n",
|
||||
" mel = mel[:mel_length, :].T\n",
|
||||
" np.save(mel_path, mel)\n",
|
||||
"\n",
|
||||
" metadata.append([wav_file_path, mel_path])\n",
|
||||
"\n",
|
||||
" except Exception as e:\n",
|
||||
" log_file.write(f\"Error processing data: {str(e)}\\n\")\n",
|
||||
"\n",
|
||||
|
@ -224,35 +210,20 @@
|
|||
" log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n",
|
||||
" log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n",
|
||||
"\n",
|
||||
"# Close the log file\n",
|
||||
"log_file.close()\n",
|
||||
"\n",
|
||||
"# For wavernn\n",
|
||||
"if not DRY_RUN:\n",
|
||||
" pickle.dump(file_idxs, open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\"))\n",
|
||||
"\n",
|
||||
"# For pwgan\n",
|
||||
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
|
||||
" for data in metadata:\n",
|
||||
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n",
|
||||
" for wav_file_path, mel_path in metadata:\n",
|
||||
" f.write(f\"{wav_file_path[0]}|{mel_path[1]+'.npy'}\\n\")\n",
|
||||
"\n",
|
||||
"# Print mean losses\n",
|
||||
"print(f\"Mean Loss: {mean_loss}\")\n",
|
||||
"print(f\"Mean Postnet Loss: {mean_postnet_loss}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# for pwgan\n",
|
||||
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
|
||||
" for data in metadata:\n",
|
||||
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
|
@ -267,7 +238,7 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"idx = 1\n",
|
||||
"ap.melspectrogram(ap.load_wav(item_idx[idx])).shape"
|
||||
"ap.melspectrogram(ap.load_wav(data[\"item_idxs\"][idx])).shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -276,10 +247,9 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import soundfile as sf\n",
|
||||
"wav, sr = sf.read(item_idx[idx])\n",
|
||||
"mel_postnet = postnet_outputs[idx][:mel_lengths[idx], :]\n",
|
||||
"mel_decoder = mel_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()\n",
|
||||
"wav, sr = sf.read(data[\"item_idxs\"][idx])\n",
|
||||
"mel_postnet = postnet_outputs[idx][:data[\"mel_lengths\"][idx], :]\n",
|
||||
"mel_decoder = mel_outputs[idx][:data[\"mel_lengths\"][idx], :].detach().cpu().numpy()\n",
|
||||
"mel_truth = ap.melspectrogram(wav)\n",
|
||||
"print(mel_truth.shape)"
|
||||
]
|
||||
|
@ -291,7 +261,7 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"# plot posnet output\n",
|
||||
"print(mel_postnet[:mel_lengths[idx], :].shape)\n",
|
||||
"print(mel_postnet[:data[\"mel_lengths\"][idx], :].shape)\n",
|
||||
"plot_spectrogram(mel_postnet, ap)"
|
||||
]
|
||||
},
|
||||
|
@ -324,10 +294,9 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"# postnet, decoder diff\n",
|
||||
"from matplotlib import pylab as plt\n",
|
||||
"mel_diff = mel_decoder - mel_postnet\n",
|
||||
"plt.figure(figsize=(16, 10))\n",
|
||||
"plt.imshow(abs(mel_diff[:mel_lengths[idx],:]).T,aspect=\"auto\", origin=\"lower\");\n",
|
||||
"plt.imshow(abs(mel_diff[:data[\"mel_lengths\"][idx],:]).T,aspect=\"auto\", origin=\"lower\")\n",
|
||||
"plt.colorbar()\n",
|
||||
"plt.tight_layout()"
|
||||
]
|
||||
|
@ -339,10 +308,9 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"# PLOT GT SPECTROGRAM diff\n",
|
||||
"from matplotlib import pylab as plt\n",
|
||||
"mel_diff2 = mel_truth.T - mel_decoder\n",
|
||||
"plt.figure(figsize=(16, 10))\n",
|
||||
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
|
||||
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\")\n",
|
||||
"plt.colorbar()\n",
|
||||
"plt.tight_layout()"
|
||||
]
|
||||
|
@ -354,21 +322,13 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"# PLOT GT SPECTROGRAM diff\n",
|
||||
"from matplotlib import pylab as plt\n",
|
||||
"mel = postnet_outputs[idx]\n",
|
||||
"mel_diff2 = mel_truth.T - mel[:mel_truth.shape[1]]\n",
|
||||
"plt.figure(figsize=(16, 10))\n",
|
||||
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
|
||||
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\")\n",
|
||||
"plt.colorbar()\n",
|
||||
"plt.tight_layout()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
|
|
@ -5,6 +5,7 @@ import torch
|
|||
from tests import get_tests_input_path, get_tests_output_path, get_tests_path
|
||||
from TTS.config import BaseAudioConfig
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.audio.numpy_transforms import stft
|
||||
from TTS.vocoder.layers.losses import MelganFeatureLoss, MultiScaleSTFTLoss, STFTLoss, TorchSTFT
|
||||
|
||||
TESTS_PATH = get_tests_path()
|
||||
|
@ -21,7 +22,7 @@ def test_torch_stft():
|
|||
torch_stft = TorchSTFT(ap.fft_size, ap.hop_length, ap.win_length)
|
||||
# librosa stft
|
||||
wav = ap.load_wav(WAV_FILE)
|
||||
M_librosa = abs(ap._stft(wav)) # pylint: disable=protected-access
|
||||
M_librosa = abs(stft(y=wav, fft_size=ap.fft_size, hop_length=ap.hop_length, win_length=ap.win_length))
|
||||
# torch stft
|
||||
wav = torch.from_numpy(wav[None, :]).float()
|
||||
M_torch = torch_stft(wav)
|
||||
|
|
Loading…
Reference in New Issue