mirror of https://github.com/coqui-ai/TTS.git
commit
fca2388e6a
|
@ -212,6 +212,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument(
|
||||
'--config_path',
|
||||
type=str,
|
||||
required=True,
|
||||
help='Path to config file for training.',
|
||||
)
|
||||
parser.add_argument('--debug',
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
import os
|
||||
import re
|
||||
import json
|
||||
import yaml
|
||||
import pickle as pickle_tts
|
||||
|
||||
|
||||
|
@ -17,19 +19,27 @@ class AttrDict(dict):
|
|||
self.__dict__ = self
|
||||
|
||||
|
||||
def load_config(config_path):
|
||||
def load_config(config_path: str) -> AttrDict:
|
||||
"""Load config files and discard comments
|
||||
|
||||
Args:
|
||||
config_path (str): path to config file.
|
||||
"""
|
||||
config = AttrDict()
|
||||
with open(config_path, "r") as f:
|
||||
input_str = f.read()
|
||||
# handle comments
|
||||
input_str = re.sub(r'\\\n', '', input_str)
|
||||
input_str = re.sub(r'//.*\n', '\n', input_str)
|
||||
data = json.loads(input_str)
|
||||
|
||||
ext = os.path.splitext(config_path)[1]
|
||||
if ext in (".yml", ".yaml"):
|
||||
with open(config_path, "r") as f:
|
||||
data = yaml.safe_load(f)
|
||||
else:
|
||||
# fallback to json
|
||||
with open(config_path, "r") as f:
|
||||
input_str = f.read()
|
||||
# handle comments
|
||||
input_str = re.sub(r'\\\n', '', input_str)
|
||||
input_str = re.sub(r'//.*\n', '\n', input_str)
|
||||
data = json.loads(input_str)
|
||||
|
||||
config.update(data)
|
||||
return config
|
||||
|
||||
|
|
|
@ -23,3 +23,4 @@ pylint==2.5.3
|
|||
gdown
|
||||
umap
|
||||
cython
|
||||
pyyaml
|
||||
|
|
Loading…
Reference in New Issue