From 7ab527d17e6ea91f6c7f9919a59bd4b98148eaa5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 12 Feb 2021 12:06:46 +0000 Subject: [PATCH] save default model chars to the training config file --- TTS/utils/arguments.py | 71 ++++++++++++++++++++---------------------- TTS/utils/io.py | 2 +- 2 files changed, 34 insertions(+), 39 deletions(-) diff --git a/TTS/utils/arguments.py b/TTS/utils/arguments.py index 948c90d3..031a3140 100644 --- a/TTS/utils/arguments.py +++ b/TTS/utils/arguments.py @@ -3,17 +3,18 @@ """Argument parser for training scripts.""" import argparse -import re import glob import os - -from TTS.utils.generic_utils import ( - create_experiment_folder, get_git_branch) -from TTS.utils.console_logger import ConsoleLogger -from TTS.utils.io import copy_model_files, load_config -from TTS.utils.tensorboard_logger import TensorboardLogger +import re +import json from TTS.tts.utils.generic_utils import check_config_tts +from TTS.utils.console_logger import ConsoleLogger +from TTS.utils.generic_utils import create_experiment_folder, get_git_branch +from TTS.utils.io import (copy_model_files, load_config, + save_characters_to_config) +from TTS.utils.tensorboard_logger import TensorboardLogger +from TTS.tts.utils.text.symbols import parse_symbols def parse_arguments(argv): @@ -110,38 +111,27 @@ def get_last_checkpoint(path): def process_args(args, model_type): """Process parsed comand line arguments. - Parameters - ---------- - args : argparse.Namespace or dict like - Parsed input arguments. - model_type : str - Model type used to check config parameters and setup the TensorBoard - logger. One of: - - tacotron - - glow_tts - - speedy_speech - - gan - - wavegrad - - wavernn + Args: + args (argparse.Namespace or dict like): Parsed input arguments. + model_type (str): Model type used to check config parameters and setup the TensorBoard + logger. One of: + - tacotron + - glow_tts + - speedy_speech + - gan + - wavegrad + - wavernn - Raises - ------ - ValueError - If `model_type` is not one of implemented choices. - - Returns - ------- - c : TTS.utils.io.AttrDict - Config paramaters. - out_path : str - Path to save models and logging. - audio_path : str - Path to save generated test audios. - c_logger : TTS.utils.console_logger.ConsoleLogger - Class that does logging to the console. - tb_logger : TTS.utils.tensorboard.TensorboardLogger - Class that does the TensorBoard loggind. + Raises: + ValueError + If `model_type` is not one of implemented choices. + Returns: + c (TTS.utils.io.AttrDict): Config paramaters. + out_path (str): Path to save models and logging. + audio_path (str): Path to save generated test audios. + c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does logging to the console. + tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does the TensorBoard loggind. """ if args.continue_path != "": args.output_path = args.continue_path @@ -156,7 +146,6 @@ def process_args(args, model_type): # setup output paths and read configs c = load_config(args.config_path) - if model_type in "tacotron glow_tts speedy_speech": model_class = "TTS" elif model_type in "gan wavegrad wavernn": @@ -192,6 +181,12 @@ def process_args(args, model_type): if args.restore_path: new_fields["restore_path"] = args.restore_path new_fields["github_branch"] = get_git_branch() + # if model characters are not set in the config file + # save the default set to the config file for future + # compatibility. + if model_class == 'TTS' and not 'characters' in c.keys(): + used_characters = parse_symbols() + new_fields['characters'] = used_characters copy_model_files(c, args.config_path, out_path, new_fields) os.chmod(audio_path, 0o775) diff --git a/TTS/utils/io.py b/TTS/utils/io.py index 46abf1c8..1148e0fe 100644 --- a/TTS/utils/io.py +++ b/TTS/utils/io.py @@ -67,7 +67,7 @@ def copy_model_files(c, config_file, out_path, new_fields): if isinstance(value, str): new_line = '"{}":"{}",\n'.format(key, value) else: - new_line = '"{}":{},\n'.format(key, value) + new_line = '"{}":{},\n'.format(key, json.dumps(value, ensure_ascii=False)) config_lines.insert(1, new_line) config_out_file = open(copy_config_path, "w") config_out_file.writelines(config_lines)