Mirror of https://github.com/coqui-ai/TTS.git
Commit aa81454721 — "Update BaseTrainingConfig" (parent: d3a58ed07a)
|
@@ -2,6 +2,7 @@ from dataclasses import asdict, dataclass
|
|||
from typing import List
|
||||
|
||||
from coqpit import Coqpit, check_argument
|
||||
from trainer import TrainerConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@@ -237,130 +238,24 @@ class BaseDatasetConfig(Coqpit):
|
|||
|
||||
|
||||
@dataclass
class BaseTrainingConfig(TrainerConfig):
    """Base config to define the basic 🐸TTS training parameters that are shared
    among all the models. It is based on ```Trainer.TrainingConfig```.

    Generic trainer parameters (epochs, batch sizes, logging, checkpointing,
    distributed settings, ...) now live on the ``TrainerConfig`` base class and
    are intentionally NOT redeclared here.

    Args:
        model (str):
            Name of the model that is used in the training.

        num_loader_workers (int):
            Number of workers for training time dataloader.

        num_eval_loader_workers (int):
            Number of workers for evaluation time dataloader.

        use_noise_augment (bool):
            Enable / Disable random noise augmentation of the input audio.
            Defaults to ```False```.

        use_language_weighted_sampler (bool):
            Enable / Disable a language-balanced batch sampler. Defaults to
            ```False```.
    """

    # NOTE(review): reconstructed from a diff view whose +/- markers were lost;
    # field membership follows the post-commit (TrainerConfig-based) version —
    # confirm against the upstream file.
    model: str = None
    # dataloading
    num_loader_workers: int = 0
    num_eval_loader_workers: int = 0
    use_noise_augment: bool = False
    use_language_weighted_sampler: bool = False
|
Reference in New Issue