Fix find unique phonemes script (#1928)

* Fix find unique phonemes script

* Fix unit tests
pull/1946/head
Edresson Casanova 2022-09-08 05:17:35 -03:00 committed by GitHub
parent 3b7dff568a
commit 159eeeef64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 14 deletions

View File

@ -7,30 +7,25 @@ from tqdm.contrib.concurrent import process_map
from TTS.config import load_config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut
phonemizer = Gruut(language="en-us")
from TTS.tts.utils.text.phonemizers import Gruut
def compute_phonemes(item):
try:
text = item[0]
ph = phonemizer.phonemize(text).split("|")
except:
return []
return list(set(ph))
text = item["text"]
ph = phonemizer.phonemize(text).replace("|", "")
return set(list(ph))
def main():
# pylint: disable=W0601
global c
global c, phonemizer
# pylint: disable=bad-option-value
parser = argparse.ArgumentParser(
description="""Find all the unique characters or phonemes in a dataset.\n\n"""
"""
Example runs:
python TTS/bin/find_unique_chars.py --config_path config.json
python TTS/bin/find_unique_phonemes.py --config_path config.json
""",
formatter_class=RawTextHelpFormatter,
)
@ -46,15 +41,24 @@ def main():
items = train_items + eval_items
print("Num items:", len(items))
is_lang_def = all(item["language"] for item in items)
language_list = [item["language"] for item in items]
is_lang_def = all(language_list)
if not c.phoneme_language or not is_lang_def:
raise ValueError("Phoneme language must be defined in config.")
if not language_list.count(language_list[0]) == len(language_list):
raise ValueError(
"Currently, just one phoneme language per config file is supported !! Please split the dataset config into different configs and run it individually for each language !!"
)
phonemizer = Gruut(language=language_list[0], keep_puncs=True)
phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15)
phones = []
for ph in phonemes:
phones.extend(ph)
phones = set(phones)
lower_phones = filter(lambda c: c.islower(), phones)
phones_force_lower = [c.lower() for c in phones]

View File

@ -19,6 +19,7 @@ dataset_config_en = BaseDatasetConfig(
language="en",
)
"""
dataset_config_pt = BaseDatasetConfig(
name="ljspeech",
meta_file_train="metadata.csv",
@ -26,6 +27,7 @@ dataset_config_pt = BaseDatasetConfig(
path="tests/data/ljspeech",
language="pt-br",
)
"""
# pylint: disable=protected-access
class TestFindUniquePhonemes(unittest.TestCase):
@ -46,7 +48,7 @@ class TestFindUniquePhonemes(unittest.TestCase):
epochs=1,
print_step=1,
print_eval=True,
datasets=[dataset_config_en, dataset_config_pt],
datasets=[dataset_config_en],
)
config.save_json(config_path)
@ -70,7 +72,7 @@ class TestFindUniquePhonemes(unittest.TestCase):
epochs=1,
print_step=1,
print_eval=True,
datasets=[dataset_config_en, dataset_config_pt],
datasets=[dataset_config_en],
)
config.save_json(config_path)