From 421787277f9bde691e62a872789ef84b3a48390f Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Wed, 21 Nov 2018 18:01:35 +0100
Subject: [PATCH] phonem updates

---
Review notes (text in this section is ignored when the patch is applied):

  * Layout: this patch had been flattened onto a few physical lines. Line
    breaks, indentation and blank context lines below were reconstructed
    from the hunk line counts; check them against the target tree before
    applying.
  * utils/text/__init__.py, phonem_to_sequence: the "no phonemes" check now
    uses `is None` instead of `== None`.
  * utils/text/__init__.py, sequence_to_phonem: removed a leftover debug
    print(s) that wrote every decoded symbol to stdout. The hunk header,
    the following hunk offset and the diffstat were adjusted accordingly.
  * utils/text/cmudict.py, text2phone: `except:` narrowed to
    `except Exception:` so KeyboardInterrupt / SystemExit still propagate;
    the function still returns None when phonemization fails.
  * The `index` blob hashes are those of the original commit and no longer
    describe the post-image of the two files edited above.

 .compute               |  2 +-
 datasets/TTSDataset.py |  2 +-
 utils/text/__init__.py | 26 +++++++++++++++++++++++---
 utils/text/cmudict.py  | 25 +++++++++++++++++++------
 utils/text/symbols.py  |  4 +++-
 5 files changed, 47 insertions(+), 12 deletions(-)

diff --git a/.compute b/.compute
index 6a7438cb..f3c91c23 100644
--- a/.compute
+++ b/.compute
@@ -2,4 +2,4 @@
 source ../tmp/venv/bin/activate
 # python extract_features.py --data_path ${DATA_ROOT}/shared/data/keithito/LJSpeech-1.1/ --cache_path ~/tts_cache/ --config config.json --num_proc 12 --dataset ljspeech --meta_file metadata.csv --val_split 1000 --process_audio true
 # python train.py --config_path config.json --data_path ~/tts_cache/ --debug true
-python train.py --config_path config.json --data_path ${DATA_ROOT}/shared/data/Blizzard/Nancy/ --debug true
\ No newline at end of file
+python train.py --config_path config.json --data_path ${DATA_ROOT}/shared/data/Blizzard/Nancy/ --debug true
diff --git a/datasets/TTSDataset.py b/datasets/TTSDataset.py
index d6282de3..2dad08ee 100644
--- a/datasets/TTSDataset.py
+++ b/datasets/TTSDataset.py
@@ -6,7 +6,7 @@ import torch
 import random
 from torch.utils.data import Dataset
 
-from utils.text import text_to_sequence, phoneme_to_sequence
+from utils.text import text_to_sequence, phonem_to_sequence
 from utils.data import (prepare_data, pad_per_step, prepare_tensor,
                         prepare_stop_target)
 
diff --git a/utils/text/__init__.py b/utils/text/__init__.py
index ed4b6e3a..faf9018e 100644
--- a/utils/text/__init__.py
+++ b/utils/text/__init__.py
@@ -3,6 +3,7 @@
 import re
 from utils.text import cleaners
 from utils.text.symbols import symbols, phonemes
+from utils.text.cmudict import text2phone
 
 # Mappings from symbol to numeric ID and vice versa:
 _symbol_to_id = {s: i for i, s in enumerate(symbols)}
@@ -15,9 +16,18 @@ _id_to_phonemes = {i: s for i, s in enumerate(phonemes)}
 _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
 
 
-def phoneme_to_sequence(text, cleaner_names):
+def phonem_to_sequence(text, cleaner_names):
+    '''
+    TODO: This ignores punctuations
+    '''
     sequence = []
-    sequence += _phonem_to_sequence(_clean_text(text, cleaner_names))
+    clean_text = _clean_text(text, cleaner_names)
+    for word in clean_text.split():
+        phonems_text = text2phone(word)
+        if phonems_text is None:
+            continue
+        sequence += _phonem_to_sequence(phonems_text)
+        sequence.append(_phonemes_to_id[' '])
     sequence.append(_phonemes_to_id['~'])
     return sequence
 
@@ -66,6 +76,16 @@ def sequence_to_text(sequence):
     return result.replace('}{', ' ')
 
 
+def sequence_to_phonem(sequence):
+    '''Converts a sequence of IDs back to a string'''
+    result = ''
+    for symbol_id in sequence:
+        if symbol_id in _id_to_phonemes:
+            s = _id_to_phonemes[symbol_id]
+            result += s
+    return result.replace('}{', ' ')
+
+
 def _clean_text(text, cleaner_names):
     for name in cleaner_names:
         cleaner = getattr(cleaners, name)
@@ -80,7 +100,7 @@ def _symbols_to_sequence(symbols):
 
 
 def _phonem_to_sequence(phonemes):
-    return [_phonemes_to_id[s] for s in phonemes if _should_keep_phonem(s)]
+    return [_phonemes_to_id[s] for s in phonemes.split(" ") if _should_keep_phonem(s)]
 
 
 def _arpabet_to_sequence(text):
diff --git a/utils/text/cmudict.py b/utils/text/cmudict.py
index c74076cb..044387a1 100644
--- a/utils/text/cmudict.py
+++ b/utils/text/cmudict.py
@@ -64,12 +64,11 @@ _phonemes = set(_phonemes)
 
 def text2phone(text):
     seperator = phonemizer.separator.Separator('', '', ' ')
-    ph = phonemizer.phonemize(text, separator=seperator)
-    ph = ph.split(' ')
-    ph.remove('')
-
-    result = [char2code[p] for p in ph]
-    return result
+    try:
+        ph = phonemizer.phonemize(text, separator=seperator)
+    except Exception:
+        ph = None
+    return ph
 
 
 class CMUDict:
@@ -95,6 +94,20 @@ class CMUDict:
         '''Returns list of ARPAbet pronunciations of the given word.'''
         return self._entries.get(word.upper())
 
+    def get_arpabet(self, word, cmudict, punctuation_symbols):
+        first_symbol, last_symbol = '', ''
+        if len(word) > 0 and word[0] in punctuation_symbols:
+            first_symbol = word[0]
+            word = word[1:]
+        if len(word) > 0 and word[-1] in punctuation_symbols:
+            last_symbol = word[-1]
+            word = word[:-1]
+        arpabet = cmudict.lookup(word)
+        if arpabet is not None:
+            return first_symbol + '{%s}' % arpabet[0] + last_symbol
+        else:
+            return first_symbol + word + last_symbol
+
 
 
 _alt_re = re.compile(r'\([0-9]+\)')
diff --git a/utils/text/symbols.py b/utils/text/symbols.py
index 4dc8814d..d629ef3a 100644
--- a/utils/text/symbols.py
+++ b/utils/text/symbols.py
@@ -10,13 +10,15 @@ from utils.text import cmudict
 _pad = '_'
 _eos = '~'
 _characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? '
+_punctuations = '!\'(),-.:;? '
 
 # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
 _arpabet = ['@' + s for s in cmudict._phonemes]
 
 # Export all symbols:
 symbols = [_pad, _eos] + list(_characters) + _arpabet
-phonemes = [_pad, _eos] + cmudict._phonemes + list('!\'(),-.:;?')
+phonemes = [_pad, _eos] + list(cmudict._phonemes) + list(_punctuations)
 
 if __name__ == '__main__':
     print(symbols)
+    print(phonemes)