mirror of https://github.com/coqui-ai/TTS.git

config update WIP

parent 06f80a4806
commit e092ae40dc
@@ -242,32 +242,32 @@ def check_config_tts(c):
     check_argument("trim_db", c["audio"], restricted=True, val_type=int)

     # training parameters
-    check_argument("batch_size", c, restricted=True, val_type=int, min_val=1)
-    check_argument("eval_batch_size", c, restricted=True, val_type=int, min_val=1)
-    check_argument("r", c, restricted=True, val_type=int, min_val=1)
-    check_argument("gradual_training", c, restricted=False, val_type=list)
-    check_argument("mixed_precision", c, restricted=False, val_type=bool)
+    # check_argument('batch_size', c, restricted=True, val_type=int, min_val=1)
+    # check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1)
+    check_argument('r', c, restricted=True, val_type=int, min_val=1)
+    check_argument('gradual_training', c, restricted=False, val_type=list)
+    # check_argument('mixed_precision', c, restricted=False, val_type=bool)
+    # check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100)

     # loss parameters
-    check_argument("loss_masking", c, restricted=True, val_type=bool)
-    if c["model"].lower() in ["tacotron", "tacotron2"]:
-        check_argument("decoder_loss_alpha", c, restricted=True, val_type=float, min_val=0)
-        check_argument("postnet_loss_alpha", c, restricted=True, val_type=float, min_val=0)
-        check_argument("postnet_diff_spec_alpha", c, restricted=True, val_type=float, min_val=0)
-        check_argument("decoder_diff_spec_alpha", c, restricted=True, val_type=float, min_val=0)
-        check_argument("decoder_ssim_alpha", c, restricted=True, val_type=float, min_val=0)
-        check_argument("postnet_ssim_alpha", c, restricted=True, val_type=float, min_val=0)
-        check_argument("ga_alpha", c, restricted=True, val_type=float, min_val=0)
-    if c["model"].lower in ["speedy_speech", "align_tts"]:
-        check_argument("ssim_alpha", c, restricted=True, val_type=float, min_val=0)
-        check_argument("l1_alpha", c, restricted=True, val_type=float, min_val=0)
-        check_argument("huber_alpha", c, restricted=True, val_type=float, min_val=0)
+    # check_argument('loss_masking', c, restricted=True, val_type=bool)
+    if c['model'].lower() in ['tacotron', 'tacotron2']:
+        check_argument('decoder_loss_alpha', c, restricted=True, val_type=float, min_val=0)
+        check_argument('postnet_loss_alpha', c, restricted=True, val_type=float, min_val=0)
+        check_argument('postnet_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
+        check_argument('decoder_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
+        check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
+        check_argument('postnet_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
+        check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0)
+    if c['model'].lower in ["speedy_speech", "align_tts"]:
+        check_argument('ssim_alpha', c, restricted=True, val_type=float, min_val=0)
+        check_argument('l1_alpha', c, restricted=True, val_type=float, min_val=0)
+        check_argument('huber_alpha', c, restricted=True, val_type=float, min_val=0)

     # validation parameters
-    check_argument("run_eval", c, restricted=True, val_type=bool)
-    check_argument("test_delay_epochs", c, restricted=True, val_type=int, min_val=0)
-    check_argument("test_sentences_file", c, restricted=False, val_type=str)
+    # check_argument('run_eval', c, restricted=True, val_type=bool)
+    # check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0)
+    # check_argument('test_sentences_file', c, restricted=False, val_type=str)

     # optimizer
     check_argument("noam_schedule", c, restricted=False, val_type=bool)
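One bug carries over from the removed lines into the new ones: in the speedy_speech / align_tts branch, `.lower` is never called, so the condition compares a bound method object against a list of strings and is always False, silently skipping the ssim/l1/huber checks. A minimal demonstration in plain Python, reusing the names from the hunk:

    # `.lower` without parentheses is a bound method object, so membership
    # in a list of strings can never be True and the branch never runs.
    c = {"model": "Align_TTS"}

    print(c["model"].lower in ["speedy_speech", "align_tts"])    # False (the bug)
    print(c["model"].lower() in ["speedy_speech", "align_tts"])  # True (the intent)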
@@ -319,24 +319,23 @@ def check_config_tts(c):
     check_argument("encoder_type", c, restricted=not is_tacotron(c), val_type=str)

     # tensorboard
-    check_argument("print_step", c, restricted=True, val_type=int, min_val=1)
-    check_argument("tb_plot_step", c, restricted=True, val_type=int, min_val=1)
-    check_argument("save_step", c, restricted=True, val_type=int, min_val=1)
-    check_argument("checkpoint", c, restricted=True, val_type=bool)
-    check_argument("tb_model_param_stats", c, restricted=True, val_type=bool)
+    # check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
+    # check_argument('tb_plot_step', c, restricted=True, val_type=int, min_val=1)
+    # check_argument('save_step', c, restricted=True, val_type=int, min_val=1)
+    # check_argument('checkpoint', c, restricted=True, val_type=bool)
+    # check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)

     # dataloading
     # pylint: disable=import-outside-toplevel
     from TTS.tts.utils.text import cleaners

-    check_argument("text_cleaner", c, restricted=True, val_type=str, enum_list=dir(cleaners))
-    check_argument("enable_eos_bos_chars", c, restricted=True, val_type=bool)
-    check_argument("num_loader_workers", c, restricted=True, val_type=int, min_val=0)
-    check_argument("num_val_loader_workers", c, restricted=True, val_type=int, min_val=0)
-    check_argument("batch_group_size", c, restricted=True, val_type=int, min_val=0)
-    check_argument("min_seq_len", c, restricted=True, val_type=int, min_val=0)
-    check_argument("max_seq_len", c, restricted=True, val_type=int, min_val=10)
-    check_argument("compute_input_seq_cache", c, restricted=True, val_type=bool)
+    # check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners))
+    # check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool)
+    # check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0)
+    # check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0)
+    # check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0)
+    # check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0)
+    # check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10)
+    # check_argument('compute_input_seq_cache', c, restricted=True, val_type=bool)

     # paths
-    check_argument("output_path", c, restricted=True, val_type=str)
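Side note on the retired text_cleaner check: its enum_list is built with dir(cleaners), so the set of accepted values is simply the attribute names of the cleaners module, imported helpers included. A quick way to inspect the effective whitelist, assuming the upstream package is importable:

    # dir() lists every attribute of the module: cleaner functions and
    # anything else defined or imported at module level alike.
    from TTS.tts.utils.text import cleaners

    print([name for name in dir(cleaners) if not name.startswith("_")])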
@@ -5,6 +5,7 @@ import shutil
 import subprocess
 import sys
 from pathlib import Path
+from typing import List


 def get_git_branch():
@@ -139,7 +140,21 @@ class KeepAverage:
             self.update_value(key, value)


-def check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, alternative=None, allow_none=False):
+def check_argument(name,
+                   c,
+                   prerequest=None,
+                   enum_list=None,
+                   max_val=None,
+                   min_val=None,
+                   restricted=False,
+                   alternative=None,
+                   allow_none=False):
+    if isinstance(prerequest, List()):
+        if any([f not in c.keys() for f in prerequest]):
+            return
+    else:
+        if prerequest not in c.keys():
+            return
     if alternative in c.keys() and c[alternative] is not None:
         return
     if allow_none and c[name] is None:
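Two problems in the new prerequest guard are worth flagging before this WIP lands. First, `isinstance(prerequest, List())` instantiates the typing generic, which raises TypeError at runtime; an isinstance check needs the built-in list. Second, with the default prerequest=None, `None not in c.keys()` is True, so every argument checked without a prerequest would return before any validation runs. A corrected sketch of the guard as a hypothetical helper, keeping the patch's names:

    def prerequest_missing(prerequest, c):
        # True when the prerequisite key(s) for this argument are absent.
        if isinstance(prerequest, list):  # built-in list, not typing.List()
            return any(f not in c for f in prerequest)
        # Only treat the prerequest as missing when one was actually given.
        return prerequest is not None and prerequest not in c

check_argument() would then start with `if prerequest_missing(prerequest, c): return`, leaving the alternative/allow_none handling below unchanged.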