config update WIP

pull/476/head
Eren Gölge 2021-03-20 00:50:15 +01:00
parent 06f80a4806
commit e092ae40dc
2 changed files with 50 additions and 36 deletions

View File

@ -242,32 +242,32 @@ def check_config_tts(c):
check_argument("trim_db", c["audio"], restricted=True, val_type=int)
# training parameters
check_argument("batch_size", c, restricted=True, val_type=int, min_val=1)
check_argument("eval_batch_size", c, restricted=True, val_type=int, min_val=1)
check_argument("r", c, restricted=True, val_type=int, min_val=1)
check_argument("gradual_training", c, restricted=False, val_type=list)
check_argument("mixed_precision", c, restricted=False, val_type=bool)
# check_argument('batch_size', c, restricted=True, val_type=int, min_val=1)
# check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1)
check_argument('r', c, restricted=True, val_type=int, min_val=1)
check_argument('gradual_training', c, restricted=False, val_type=list)
# check_argument('mixed_precision', c, restricted=False, val_type=bool)
# check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100)
# loss parameters
check_argument("loss_masking", c, restricted=True, val_type=bool)
if c["model"].lower() in ["tacotron", "tacotron2"]:
check_argument("decoder_loss_alpha", c, restricted=True, val_type=float, min_val=0)
check_argument("postnet_loss_alpha", c, restricted=True, val_type=float, min_val=0)
check_argument("postnet_diff_spec_alpha", c, restricted=True, val_type=float, min_val=0)
check_argument("decoder_diff_spec_alpha", c, restricted=True, val_type=float, min_val=0)
check_argument("decoder_ssim_alpha", c, restricted=True, val_type=float, min_val=0)
check_argument("postnet_ssim_alpha", c, restricted=True, val_type=float, min_val=0)
check_argument("ga_alpha", c, restricted=True, val_type=float, min_val=0)
if c["model"].lower in ["speedy_speech", "align_tts"]:
check_argument("ssim_alpha", c, restricted=True, val_type=float, min_val=0)
check_argument("l1_alpha", c, restricted=True, val_type=float, min_val=0)
check_argument("huber_alpha", c, restricted=True, val_type=float, min_val=0)
# check_argument('loss_masking', c, restricted=True, val_type=bool)
if c['model'].lower() in ['tacotron', 'tacotron2']:
check_argument('decoder_loss_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('postnet_loss_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('postnet_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('decoder_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('postnet_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0)
if c['model'].lower in ["speedy_speech", "align_tts"]:
check_argument('ssim_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('l1_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('huber_alpha', c, restricted=True, val_type=float, min_val=0)
# validation parameters
check_argument("run_eval", c, restricted=True, val_type=bool)
check_argument("test_delay_epochs", c, restricted=True, val_type=int, min_val=0)
check_argument("test_sentences_file", c, restricted=False, val_type=str)
# check_argument('run_eval', c, restricted=True, val_type=bool)
# check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0)
# check_argument('test_sentences_file', c, restricted=False, val_type=str)
# optimizer
check_argument("noam_schedule", c, restricted=False, val_type=bool)
@ -319,24 +319,23 @@ def check_config_tts(c):
check_argument("encoder_type", c, restricted=not is_tacotron(c), val_type=str)
# tensorboard
check_argument("print_step", c, restricted=True, val_type=int, min_val=1)
check_argument("tb_plot_step", c, restricted=True, val_type=int, min_val=1)
check_argument("save_step", c, restricted=True, val_type=int, min_val=1)
check_argument("checkpoint", c, restricted=True, val_type=bool)
check_argument("tb_model_param_stats", c, restricted=True, val_type=bool)
# check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
# check_argument('tb_plot_step', c, restricted=True, val_type=int, min_val=1)
# check_argument('save_step', c, restricted=True, val_type=int, min_val=1)
# check_argument('checkpoint', c, restricted=True, val_type=bool)
# check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
# dataloading
# pylint: disable=import-outside-toplevel
from TTS.tts.utils.text import cleaners
check_argument("text_cleaner", c, restricted=True, val_type=str, enum_list=dir(cleaners))
check_argument("enable_eos_bos_chars", c, restricted=True, val_type=bool)
check_argument("num_loader_workers", c, restricted=True, val_type=int, min_val=0)
check_argument("num_val_loader_workers", c, restricted=True, val_type=int, min_val=0)
check_argument("batch_group_size", c, restricted=True, val_type=int, min_val=0)
check_argument("min_seq_len", c, restricted=True, val_type=int, min_val=0)
check_argument("max_seq_len", c, restricted=True, val_type=int, min_val=10)
check_argument("compute_input_seq_cache", c, restricted=True, val_type=bool)
# check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners))
# check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool)
# check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0)
# check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0)
# check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0)
# check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0)
# check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10)
# check_argument('compute_input_seq_cache', c, restricted=True, val_type=bool)
# paths
check_argument("output_path", c, restricted=True, val_type=str)

View File

@ -5,6 +5,7 @@ import shutil
import subprocess
import sys
from pathlib import Path
from typing import List
def get_git_branch():
@ -139,7 +140,21 @@ class KeepAverage:
self.update_value(key, value)
def check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, alternative=None, allow_none=False):
def check_argument(name,
c,
prerequest=None,
enum_list=None,
max_val=None,
min_val=None,
restricted=False,
alternative=None,
allow_none=False):
if isinstance(prerequest, List()):
if any([f not in c.keys() for f in prerequest]):
return
else:
if prerequest not in c.keys():
return
if alternative in c.keys() and c[alternative] is not None:
return
if allow_none and c[name] is None: