mirror of https://github.com/coqui-ai/TTS.git
global shared Coqpit configs
parent 70fc7a7e71
commit 10db2baa06
from dataclasses import asdict, dataclass
from typing import List, Union

from coqpit import MISSING, Coqpit, check_argument


@dataclass
class BaseAudioConfig(Coqpit):
    """Base config to define audio processing parameters. It is used to initialize
    ```TTS.utils.audio.AudioProcessor.```

    Args:
        fft_size (int):
            Number of STFT frequency levels, aka. the size of the linear spectrogram frame. Defaults to 1024.
        win_length (int):
            Each frame of audio is windowed by a window of length ```win_length``` and then padded with zeros to match
            ```fft_size```. Defaults to 1024.
        hop_length (int):
            Number of audio samples between adjacent STFT columns. Defaults to 256.
        frame_shift_ms (int):
            Set ```hop_length``` based on milliseconds and sampling rate.
        frame_length_ms (int):
            Set ```win_length``` based on milliseconds and sampling rate.
        stft_pad_mode (str):
            Padding method used in STFT. 'reflect' or 'center'. Defaults to 'reflect'.
        sample_rate (int):
            Audio sampling rate. Defaults to 22050.
        resample (bool):
            Enable / Disable resampling audio to ```sample_rate```. Defaults to ```False```.
        preemphasis (float):
            Preemphasis coefficient. Defaults to 0.0.
        ref_level_db (int):
            Reference dB level to rebase the audio signal and ignore the level below. 20 dB is assumed the sound of air.
            Defaults to 20.
        do_sound_norm (bool):
            Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False.
        do_trim_silence (bool):
            Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```.
        trim_db (int):
            Silence threshold used for silence trimming. Defaults to 45.
        power (float):
            Exponent used for expanding spectrogram levels before running Griffin-Lim. It helps to reduce the
            artifacts in the synthesized voice. Defaults to 1.5.
        griffin_lim_iters (int):
            Number of Griffin-Lim iterations. Defaults to 60.
        num_mels (int):
            Number of mel-basis filters, i.e. the number of channels in each mel-spectrogram frame. Defaults to 80.
        mel_fmin (float):
            Min frequency level used for the mel-basis filters. ~50 for male and ~95 for female voices.
            It needs to be adjusted for a dataset. Defaults to 0.
        mel_fmax (float):
            Max frequency level used for the mel-basis filters. It needs to be adjusted for a dataset.
        spec_gain (int):
            Gain applied when converting amplitude to dB. Defaults to 20.
        signal_norm (bool):
            Enable / Disable signal normalization. Defaults to True.
        min_level_db (int):
            Minimum dB threshold for the computed mel-spectrograms. Defaults to -100.
        symmetric_norm (bool):
            Enable / Disable symmetric normalization. If set True, normalization is performed in the range [-k, k],
            else [0, k]. Defaults to True.
        max_norm (float):
            ```k``` defining the normalization range. Defaults to 4.0.
        clip_norm (bool):
            Enable / Disable clipping the out-of-range values in the normalized audio signal. Defaults to True.
        stats_path (str):
            Path to the computed stats file. Defaults to None.
    """

    # stft parameters
    fft_size: int = 1024
    win_length: int = 1024
    hop_length: int = 256
    frame_shift_ms: int = None
    frame_length_ms: int = None
    stft_pad_mode: str = "reflect"
    # audio processing parameters
    sample_rate: int = 22050
    resample: bool = False
    preemphasis: float = 0.0
    ref_level_db: int = 20
    do_sound_norm: bool = False
    log_func: str = "np.log10"
    # silence trimming
    do_trim_silence: bool = True
    trim_db: int = 45
    # griffin-lim params
    power: float = 1.5
    griffin_lim_iters: int = 60
    # mel-spec params
    num_mels: int = 80
    mel_fmin: float = 0.0
    mel_fmax: float = None
    spec_gain: int = 20
    # normalization params
    signal_norm: bool = True
    min_level_db: int = -100
    symmetric_norm: bool = True
    max_norm: float = 4.0
    clip_norm: bool = True
    stats_path: str = None
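
    # frame_shift_ms / frame_length_ms are the millisecond counterparts of hop_length / win_length.
    # The conversion presumably applied by AudioProcessor is the usual one:
    #     hop_length = int(frame_shift_ms / 1000 * sample_rate)
    #     win_length = int(frame_length_ms / 1000 * sample_rate)
    # e.g. the defaults above (hop_length=256, win_length=1024 at sample_rate=22050) correspond
    # to roughly 256 / 22050 * 1000 ≈ 11.6 ms and 1024 / 22050 * 1000 ≈ 46.4 ms.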

    def check_values(
        self,
    ):
        """Check config fields"""
        c = asdict(self)
        check_argument("num_mels", c, restricted=True, min_val=10, max_val=2056)
        check_argument("fft_size", c, restricted=True, min_val=128, max_val=4058)
        check_argument("sample_rate", c, restricted=True, min_val=512, max_val=100000)
        check_argument(
            "frame_length_ms",
            c,
            restricted=True,
            min_val=10,
            max_val=1000,
            alternative="win_length",
        )
        check_argument("frame_shift_ms", c, restricted=True, min_val=1, max_val=1000, alternative="hop_length")
        check_argument("preemphasis", c, restricted=True, min_val=0, max_val=1)
        check_argument("min_level_db", c, restricted=True, min_val=-1000, max_val=10)
        check_argument("ref_level_db", c, restricted=True, min_val=0, max_val=1000)
        check_argument("power", c, restricted=True, min_val=1, max_val=5)
        check_argument("griffin_lim_iters", c, restricted=True, min_val=10, max_val=1000)

        # normalization parameters
        check_argument("signal_norm", c, restricted=True)
        check_argument("symmetric_norm", c, restricted=True)
        check_argument("max_norm", c, restricted=True, min_val=0.1, max_val=1000)
        check_argument("clip_norm", c, restricted=True)
        check_argument("mel_fmin", c, restricted=True, min_val=0.0, max_val=1000)
        check_argument("mel_fmax", c, restricted=True, min_val=500.0, allow_none=True)
        check_argument("spec_gain", c, restricted=True, min_val=1, max_val=100)
        check_argument("do_trim_silence", c, restricted=True)
        check_argument("trim_db", c, restricted=True)
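
# A minimal usage sketch (values here are illustrative, not part of the original file). Since
# BaseAudioConfig is a Coqpit dataclass, fields are overridden as keyword arguments and the
# ranges above can be validated explicitly:
#
#     audio_config = BaseAudioConfig(sample_rate=16000, resample=True, trim_db=60)
#     audio_config.check_values()     # complains if a value falls outside the checked ranges
#     print(audio_config.hop_length)  # 256, the unchanged default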


@dataclass
class BaseDatasetConfig(Coqpit):
    name: str = None
    path: str = None
    meta_file_train: str = None
    meta_file_val: str = None
    meta_file_attn_mask: str = None

    def check_values(
        self,
    ):
        """Check config fields"""
        c = asdict(self)
        check_argument("name", c, restricted=True)
        check_argument("path", c, restricted=True)
        check_argument("meta_file_train", c, restricted=True)
        check_argument("meta_file_val", c, restricted=False)
        check_argument("meta_file_attn_mask", c, restricted=False)
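
# A hypothetical BaseDatasetConfig example (formatter name and paths are illustrative): `name`
# selects the dataset formatter, `path` is the dataset root, and the meta files list the
# transcripts used for training / validation:
#
#     dataset_config = BaseDatasetConfig(
#         name="ljspeech",
#         path="/data/LJSpeech-1.1",
#         meta_file_train="metadata.csv",
#     )
#     dataset_config.check_values()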


@dataclass
class BaseTrainingConfig(Coqpit):
    """Base config to define the basic training parameters that are shared
    among all the models.

    Args:
        batch_size (int):
            Training batch size.
        batch_group_size (int):
            Number of batches to shuffle after bucketing.
        eval_batch_size (int):
            Validation batch size.
        loss_masking (bool):
            Enable / Disable masking the padded segments of sequences.
        mixed_precision (bool):
            Enable / Disable mixed precision training. It reduces the VRAM use and allows larger batch sizes, however
            it may also cause numerical instability in some cases.
        run_eval (bool):
            Enable / Disable evaluation (validation) run. Defaults to True.
        test_delay_epochs (int):
            Number of epochs before starting to use evaluation runs. Initially, models do not generate meaningful
            results, hence waiting for a couple of epochs might save some time.
        print_eval (bool):
            Enable / Disable console logging for evaluation steps. If disabled, it only shows the final values at
            the end of the evaluation. Defaults to ```False```.
        print_step (int):
            Number of steps required to print the next training log.
        tb_plot_step (int):
            Number of steps required to log training on Tensorboard.
        tb_model_param_stats (bool):
            Enable / Disable logging internal model stats for model diagnostics. It might be useful for model debugging.
            Defaults to ```False```.
        save_step (int):
            Number of steps required to save the next checkpoint.
        checkpoint (bool):
            Enable / Disable checkpointing.
        keep_all_best (bool):
            Enable / Disable keeping all the saved best models instead of overwriting the previous one. Defaults
            to ```False```.
        keep_after (int):
            Number of steps to wait before saving all the best models. In use if ```keep_all_best == True```. Defaults
            to 10000.
        text_cleaner (str):
            Text cleaner to be used at model training. It is set to be one of the cleaners in
            ```TTS.tts.utils.text.cleaners```.
        enable_eos_bos_chars (bool):
            Enable / Disable using special characters indicating end-of-sentence and beginning-of-sentence.
        num_loader_workers (int):
            Number of workers for the training time dataloader.
        num_val_loader_workers (int):
            Number of workers for the evaluation time dataloader.
        min_seq_len (int):
            Minimum sequence length to be used at training.
        max_seq_len (int):
            Maximum sequence length to be used at training. VRAM use at training depends on this parameter. Consider
            decreasing it if you get OOM errors.
        compute_f0 (bool):
            Return F0 frames from the dataloader. Defaults to ```False```.
        compute_input_seq_cache (bool):
            Enable / Disable computing and caching phoneme sequences from character sequences at the beginning of the
            training. It allows faster data loading times and more precise max-min sequence pruning. Defaults
            to ```False```.
        output_path (str):
            Path for the training output folder. The nonexistent part of the given path is created automatically.
            All training outputs are saved there.
        phoneme_cache_path (str):
            Path to a folder to save the computed phoneme sequences.
        datasets (List[BaseDatasetConfig]):
            List of DatasetConfig.
    """

    model: str = None
    run_name: str = ""
    run_description: str = ""
    # training params
    epochs: int = 10000
    batch_size: int = MISSING
    eval_batch_size: int = None
    mixed_precision: bool = False
    # eval params
    run_eval: bool = True
    test_delay_epochs: int = 0
    print_eval: bool = False
    # logging
    print_step: int = 25
    tb_plot_step: int = 100
    tb_model_param_stats: bool = False
    # checkpointing
    save_step: int = 10000
    checkpoint: bool = True
    keep_all_best: bool = False
    keep_after: int = 10000
    # dataloading
    num_loader_workers: int = None
    num_val_loader_workers: int = None
    use_noise_augment: bool = False
    # paths
    output_path: str = None
    # distributed
    distributed_backend: str = "nccl"
    distributed_url: str = "tcp://localhost:54321"
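
# A hypothetical end-to-end sketch (model name, batch sizes and output path are illustrative).
# BaseTrainingConfig is normally extended by per-model configs, but the shared fields can be
# exercised directly; note that batch_size defaults to MISSING and must be set explicitly:
#
#     training_config = BaseTrainingConfig(
#         model="tacotron2",
#         batch_size=32,
#         eval_batch_size=16,
#         num_loader_workers=4,
#         num_val_loader_workers=2,
#         output_path="/tmp/tts_runs",
#     )
#     print(asdict(training_config))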