From 0cc3650ef61e8d304305f6a39e33a2876be9baa1 Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Sat, 14 Nov 2020 00:13:53 -0800 Subject: [PATCH] support loading config in yaml --- TTS/bin/train_encoder.py | 1 + TTS/utils/io.py | 24 +++++++++++++++++------- requirements.txt | 1 + 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index 8d1f14fa..078f7b84 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -212,6 +212,7 @@ if __name__ == '__main__': parser.add_argument( '--config_path', type=str, + required=True, help='Path to config file for training.', ) parser.add_argument('--debug', diff --git a/TTS/utils/io.py b/TTS/utils/io.py index 074a730e..3cc36e95 100644 --- a/TTS/utils/io.py +++ b/TTS/utils/io.py @@ -1,5 +1,7 @@ +import os import re import json +import yaml import pickle as pickle_tts @@ -17,19 +19,27 @@ class AttrDict(dict): self.__dict__ = self -def load_config(config_path): +def load_config(config_path: str) -> AttrDict: """Load config files and discard comments Args: config_path (str): path to config file. """ config = AttrDict() - with open(config_path, "r") as f: - input_str = f.read() - # handle comments - input_str = re.sub(r'\\\n', '', input_str) - input_str = re.sub(r'//.*\n', '\n', input_str) - data = json.loads(input_str) + + ext = os.path.splitext(config_path)[1] + if ext in (".yml", ".yaml"): + with open(config_path, "r") as f: + data = yaml.safe_load(f) + else: + # fallback to json + with open(config_path, "r") as f: + input_str = f.read() + # handle comments + input_str = re.sub(r'\\\n', '', input_str) + input_str = re.sub(r'//.*\n', '\n', input_str) + data = json.loads(input_str) + config.update(data) return config diff --git a/requirements.txt b/requirements.txt index dda9dcb5..e8064926 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,3 +23,4 @@ pylint==2.5.3 gdown umap cython +pyyaml