Update BaseTrainingConfig

pull/1324/head
Eren Gölge 2022-02-03 15:37:57 +01:00
parent d3a58ed07a
commit aa81454721
1 changed files with 4 additions and 109 deletions

View File

@ -2,6 +2,7 @@ from dataclasses import asdict, dataclass
from typing import List
from coqpit import Coqpit, check_argument
from trainer import TrainerConfig
@dataclass
@ -237,130 +238,24 @@ class BaseDatasetConfig(Coqpit):
@dataclass
class BaseTrainingConfig(Coqpit):
"""Base config to define the basic training parameters that are shared
among all the models.
class BaseTrainingConfig(TrainerConfig):
"""Base config to define the basic 🐸TTS training parameters that are shared
among all the models. It is based on ```Trainer.TrainingConfig```.
Args:
model (str):
Name of the model that is used in the training.
run_name (str):
Name of the experiment. This prefixes the output folder name. Defaults to `coqui_tts`.
run_description (str):
Short description of the experiment.
epochs (int):
Number training epochs. Defaults to 10000.
batch_size (int):
Training batch size.
eval_batch_size (int):
Validation batch size.
mixed_precision (bool):
Enable / Disable mixed precision training. It reduces the VRAM use and allows larger batch sizes, however
it may also cause numerical unstability in some cases.
scheduler_after_epoch (bool):
If true, run the scheduler step after each epoch else run it after each model step.
run_eval (bool):
Enable / Disable evaluation (validation) run. Defaults to True.
test_delay_epochs (int):
Number of epochs before starting to use evaluation runs. Initially, models do not generate meaningful
results, hence waiting for a couple of epochs might save some time.
print_eval (bool):
Enable / Disable console logging for evalutaion steps. If disabled then it only shows the final values at
the end of the evaluation. Default to ```False```.
print_step (int):
Number of steps required to print the next training log.
log_dashboard (str): "tensorboard" or "wandb"
Set the experiment tracking tool
plot_step (int):
Number of steps required to log training on Tensorboard.
model_param_stats (bool):
Enable / Disable logging internal model stats for model diagnostic. It might be useful for model debugging.
Defaults to ```False```.
project_name (str):
Name of the project. Defaults to config.model
wandb_entity (str):
Name of W&B entity/team. Enables collaboration across a team or org.
log_model_step (int):
Number of steps required to log a checkpoint as W&B artifact
save_step (int):
Number of steps required to save the next checkpoint.
checkpoint (bool):
Enable / Disable checkpointing.
keep_all_best (bool):
Enable / Disable keeping all the saved best models instead of overwriting the previous one. Defaults
to ```False```.
keep_after (int):
Number of steps to wait before saving all the best models. In use if ```keep_all_best == True```. Defaults
to 10000.
num_loader_workers (int):
Number of workers for training time dataloader.
num_eval_loader_workers (int):
Number of workers for evaluation time dataloader.
output_path (str):
Path for training output folder, either a local file path or other
URLs supported by both fsspec and tensorboardX, e.g. GCS (gs://) or
S3 (s3://) paths. The nonexist part of the given path is created
automatically. All training artefacts are saved there.
"""
model: str = None
run_name: str = "coqui_tts"
run_description: str = ""
# training params
epochs: int = 10000
batch_size: int = None
eval_batch_size: int = None
mixed_precision: bool = False
scheduler_after_epoch: bool = False
# eval params
run_eval: bool = True
test_delay_epochs: int = 0
print_eval: bool = False
# logging
dashboard_logger: str = "tensorboard"
print_step: int = 25
plot_step: int = 100
model_param_stats: bool = False
project_name: str = None
log_model_step: int = None
wandb_entity: str = None
# checkpointing
save_step: int = 10000
checkpoint: bool = True
keep_all_best: bool = False
keep_after: int = 10000
# dataloading
num_loader_workers: int = 0
num_eval_loader_workers: int = 0
use_noise_augment: bool = False
use_language_weighted_sampler: bool = False
# paths
output_path: str = None
# distributed
distributed_backend: str = "nccl"
distributed_url: str = "tcp://localhost:54321"