mirror of https://github.com/coqui-ai/TTS.git
parent
4546b4cbd8
commit
371772c355
|
@ -62,7 +62,7 @@ class BaseAudioConfig(Coqpit):
|
||||||
Maximum frequency of the F0 frames. Defaults to ```640```.
|
Maximum frequency of the F0 frames. Defaults to ```640```.
|
||||||
|
|
||||||
pitch_fmin (float, optional):
|
pitch_fmin (float, optional):
|
||||||
Minimum frequency of the F0 frames. Defaults to ```0```.
|
Minimum frequency of the F0 frames. Defaults to ```1```.
|
||||||
|
|
||||||
trim_db (int):
|
trim_db (int):
|
||||||
Silence threshold used for silence trimming. Defaults to 45.
|
Silence threshold used for silence trimming. Defaults to 45.
|
||||||
|
@ -144,7 +144,7 @@ class BaseAudioConfig(Coqpit):
|
||||||
do_amp_to_db_mel: bool = True
|
do_amp_to_db_mel: bool = True
|
||||||
# f0 params
|
# f0 params
|
||||||
pitch_fmax: float = 640.0
|
pitch_fmax: float = 640.0
|
||||||
pitch_fmin: float = 0.0
|
pitch_fmin: float = 1.0
|
||||||
# normalization params
|
# normalization params
|
||||||
signal_norm: bool = True
|
signal_norm: bool = True
|
||||||
min_level_db: int = -100
|
min_level_db: int = -100
|
||||||
|
|
|
@ -2,9 +2,9 @@ from typing import Tuple
|
||||||
|
|
||||||
import librosa
|
import librosa
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pyworld as pw
|
|
||||||
import scipy
|
import scipy
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
from librosa import pyin
|
||||||
|
|
||||||
# For using kwargs
|
# For using kwargs
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
|
@ -242,12 +242,28 @@ def compute_stft_paddings(
|
||||||
|
|
||||||
|
|
||||||
def compute_f0(
|
def compute_f0(
|
||||||
*, x: np.ndarray = None, pitch_fmax: float = None, hop_length: int = None, sample_rate: int = None, **kwargs
|
*,
|
||||||
|
x: np.ndarray = None,
|
||||||
|
pitch_fmax: float = None,
|
||||||
|
pitch_fmin: float = None,
|
||||||
|
hop_length: int = None,
|
||||||
|
win_length: int = None,
|
||||||
|
sample_rate: int = None,
|
||||||
|
stft_pad_mode: str = "reflect",
|
||||||
|
center: bool = True,
|
||||||
|
**kwargs,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.
|
"""Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (np.ndarray): Waveform. Shape :math:`[T_wav,]`
|
x (np.ndarray): Waveform. Shape :math:`[T_wav,]`
|
||||||
|
pitch_fmax (float): Pitch max value.
|
||||||
|
pitch_fmin (float): Pitch min value.
|
||||||
|
hop_length (int): Number of frames between STFT columns.
|
||||||
|
win_length (int): STFT window length.
|
||||||
|
sample_rate (int): Audio sampling rate.
|
||||||
|
stft_pad_mode (str): Padding mode for STFT.
|
||||||
|
center (bool): Centered padding.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
np.ndarray: Pitch. Shape :math:`[T_pitch,]`. :math:`T_pitch == T_wav / hop_length`
|
np.ndarray: Pitch. Shape :math:`[T_pitch,]`. :math:`T_pitch == T_wav / hop_length`
|
||||||
|
@ -255,20 +271,35 @@ def compute_f0(
|
||||||
Examples:
|
Examples:
|
||||||
>>> WAV_FILE = filename = librosa.util.example_audio_file()
|
>>> WAV_FILE = filename = librosa.util.example_audio_file()
|
||||||
>>> from TTS.config import BaseAudioConfig
|
>>> from TTS.config import BaseAudioConfig
|
||||||
>>> from TTS.utils.audio.processor import AudioProcessor >>> conf = BaseAudioConfig(pitch_fmax=8000)
|
>>> from TTS.utils.audio import AudioProcessor
|
||||||
|
>>> conf = BaseAudioConfig(pitch_fmax=640, pitch_fmin=1)
|
||||||
>>> ap = AudioProcessor(**conf)
|
>>> ap = AudioProcessor(**conf)
|
||||||
>>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050]
|
>>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate]
|
||||||
>>> pitch = ap.compute_f0(wav)
|
>>> pitch = ap.compute_f0(wav)
|
||||||
"""
|
"""
|
||||||
assert pitch_fmax is not None, " [!] Set `pitch_fmax` before caling `compute_f0`."
|
assert pitch_fmax is not None, " [!] Set `pitch_fmax` before caling `compute_f0`."
|
||||||
|
assert pitch_fmin is not None, " [!] Set `pitch_fmin` before caling `compute_f0`."
|
||||||
|
|
||||||
f0, t = pw.dio(
|
f0, voiced_mask, _ = pyin(
|
||||||
x.astype(np.double),
|
y=x.astype(np.double),
|
||||||
fs=sample_rate,
|
fmin=pitch_fmin,
|
||||||
f0_ceil=pitch_fmax,
|
fmax=pitch_fmax,
|
||||||
frame_period=1000 * hop_length / sample_rate,
|
sr=sample_rate,
|
||||||
|
frame_length=win_length,
|
||||||
|
win_length=win_length // 2,
|
||||||
|
hop_length=hop_length,
|
||||||
|
pad_mode=stft_pad_mode,
|
||||||
|
center=center,
|
||||||
|
n_thresholds=100,
|
||||||
|
beta_parameters=(2, 18),
|
||||||
|
boltzmann_parameter=2,
|
||||||
|
resolution=0.1,
|
||||||
|
max_transition_rate=35.92,
|
||||||
|
switch_prob=0.01,
|
||||||
|
no_trough_prob=0.01,
|
||||||
)
|
)
|
||||||
f0 = pw.stonemask(x.astype(np.double), f0, t, sample_rate)
|
f0[~voiced_mask] = 0.0
|
||||||
|
|
||||||
return f0
|
return f0
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2,12 +2,12 @@ from typing import Dict, Tuple
|
||||||
|
|
||||||
import librosa
|
import librosa
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pyworld as pw
|
|
||||||
import scipy.io.wavfile
|
import scipy.io.wavfile
|
||||||
import scipy.signal
|
import scipy.signal
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
|
||||||
from TTS.tts.utils.helpers import StandardScaler
|
from TTS.tts.utils.helpers import StandardScaler
|
||||||
|
from TTS.utils.audio.numpy_transforms import compute_f0
|
||||||
|
|
||||||
# pylint: disable=too-many-public-methods
|
# pylint: disable=too-many-public-methods
|
||||||
|
|
||||||
|
@ -573,23 +573,28 @@ class AudioProcessor(object):
|
||||||
>>> WAV_FILE = filename = librosa.util.example_audio_file()
|
>>> WAV_FILE = filename = librosa.util.example_audio_file()
|
||||||
>>> from TTS.config import BaseAudioConfig
|
>>> from TTS.config import BaseAudioConfig
|
||||||
>>> from TTS.utils.audio import AudioProcessor
|
>>> from TTS.utils.audio import AudioProcessor
|
||||||
>>> conf = BaseAudioConfig(pitch_fmax=8000)
|
>>> conf = BaseAudioConfig(pitch_fmax=640, pitch_fmin=1)
|
||||||
>>> ap = AudioProcessor(**conf)
|
>>> ap = AudioProcessor(**conf)
|
||||||
>>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050]
|
>>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate]
|
||||||
>>> pitch = ap.compute_f0(wav)
|
>>> pitch = ap.compute_f0(wav)
|
||||||
"""
|
"""
|
||||||
assert self.pitch_fmax is not None, " [!] Set `pitch_fmax` before caling `compute_f0`."
|
assert self.pitch_fmax is not None, " [!] Set `pitch_fmax` before caling `compute_f0`."
|
||||||
|
assert self.pitch_fmin is not None, " [!] Set `pitch_fmin` before caling `compute_f0`."
|
||||||
# align F0 length to the spectrogram length
|
# align F0 length to the spectrogram length
|
||||||
if len(x) % self.hop_length == 0:
|
if len(x) % self.hop_length == 0:
|
||||||
x = np.pad(x, (0, self.hop_length // 2), mode="reflect")
|
x = np.pad(x, (0, self.hop_length // 2), mode=self.stft_pad_mode)
|
||||||
|
|
||||||
f0, t = pw.dio(
|
f0 = compute_f0(
|
||||||
x.astype(np.double),
|
x=x,
|
||||||
fs=self.sample_rate,
|
pitch_fmax=self.pitch_fmax,
|
||||||
f0_ceil=self.pitch_fmax,
|
pitch_fmin=self.pitch_fmin,
|
||||||
frame_period=1000 * self.hop_length / self.sample_rate,
|
hop_length=self.hop_length,
|
||||||
|
win_length=self.win_length,
|
||||||
|
sample_rate=self.sample_rate,
|
||||||
|
stft_pad_mode=self.stft_pad_mode,
|
||||||
|
center=True,
|
||||||
)
|
)
|
||||||
f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate)
|
|
||||||
return f0
|
return f0
|
||||||
|
|
||||||
### Audio Processing ###
|
### Audio Processing ###
|
||||||
|
|
|
@ -23,7 +23,6 @@ umap-learn==0.5.1
|
||||||
pandas
|
pandas
|
||||||
# deps for training
|
# deps for training
|
||||||
matplotlib
|
matplotlib
|
||||||
pyworld==0.2.10 # > 0.2.10 is not p3.10.x compatible
|
|
||||||
# coqui stack
|
# coqui stack
|
||||||
trainer
|
trainer
|
||||||
# config management
|
# config management
|
||||||
|
|
|
@ -10,7 +10,7 @@ OUT_PATH = os.path.join(get_tests_output_path(), "audio_tests")
|
||||||
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
|
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
|
||||||
|
|
||||||
os.makedirs(OUT_PATH, exist_ok=True)
|
os.makedirs(OUT_PATH, exist_ok=True)
|
||||||
conf = BaseAudioConfig(mel_fmax=8000)
|
conf = BaseAudioConfig(mel_fmax=8000, pitch_fmax=640, pitch_fmin=1)
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
|
|
|
@ -31,7 +31,8 @@ class TestNumpyTransforms(unittest.TestCase):
|
||||||
mel_fmin: int = 0
|
mel_fmin: int = 0
|
||||||
hop_length: int = 256
|
hop_length: int = 256
|
||||||
win_length: int = 1024
|
win_length: int = 1024
|
||||||
pitch_fmax: int = 450
|
pitch_fmax: int = 640
|
||||||
|
pitch_fmin: int = 1
|
||||||
trim_db: int = -1
|
trim_db: int = -1
|
||||||
min_silence_sec: float = 0.01
|
min_silence_sec: float = 0.01
|
||||||
gain: float = 1.0
|
gain: float = 1.0
|
||||||
|
|
Loading…
Reference in New Issue