diff --git a/datasets/TTSDataset.py b/datasets/TTSDataset.py
index cccd65a2..d649bf23 100644
--- a/datasets/TTSDataset.py
+++ b/datasets/TTSDataset.py
@@ -77,13 +77,12 @@ class MyDataset(Dataset):
     def _generate_and_cache_phoneme_sequence(self, text, cache_path):
         """generate a phoneme sequence from text.
-
         since the usage is for subsequent caching, we never add bos and
         eos chars here. Instead we add those dynamically later; based on the
         config option."""
         phonemes = phoneme_to_sequence(text, [self.cleaners],
                                        language=self.phoneme_language,
-                                       enable_eos_bos=False)
+                                       enable_eos_bos=False, tp=self.tp)
         phonemes = np.asarray(phonemes, dtype=np.int32)
         np.save(cache_path, phonemes)
         return phonemes

diff --git a/train.py b/train.py
index 96c268f0..616d54ac 100644
--- a/train.py
+++ b/train.py
@@ -519,9 +519,8 @@ def main(args):  # pylint: disable=redefined-outer-name
     global meta_data_train, meta_data_eval, symbols, phonemes
     # Audio processor
     ap = AudioProcessor(**c.audio)
-    if 'text' in c.keys():
-        symbols, phonemes = make_symbols(**c.text)
+    symbols, phonemes = make_symbols(**c.text)

     # DISTRUBUTED
     if num_gpus > 1:
diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index 6aecdc7d..7c2f033a 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -426,13 +426,13 @@ def check_config(c):
     _check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000)

     # vocabulary parameters
-    _check_argument('text', c, restricted=False, val_type=dict)  # parameter not mandatory
-    _check_argument('pad', c['text'] if 'text' in c.keys() else {}, restricted=True if 'text' in c.keys() else False, val_type=str)  # mandatory if "text parameters" else no mandatory
-    _check_argument('eos', c['text'] if 'text' in c.keys() else {}, restricted=True if 'text' in c.keys() else False, val_type=str)
-    _check_argument('bos', c['text'] if 'text' in c.keys() else {}, restricted=True if 'text' in c.keys() else False, val_type=str)
-    _check_argument('characters', c['text'] if 'text' in c.keys() else {}, restricted=True if 'text' in c.keys() else False, val_type=str)
-    _check_argument('phonemes', c['text'] if 'text' in c.keys() else {}, restricted=True if 'text' in c.keys() else False, val_type=str)
-    _check_argument('punctuations', c['text'] if 'text' in c.keys() else {}, restricted=True if 'text' in c.keys() else False, val_type=str)
+    _check_argument('text', c, restricted=False, val_type=dict)
+    _check_argument('pad', c['text'] if 'text' in c.keys() else {}, restricted='text' in c.keys(), val_type=str)
+    _check_argument('eos', c['text'] if 'text' in c.keys() else {}, restricted='text' in c.keys(), val_type=str)
+    _check_argument('bos', c['text'] if 'text' in c.keys() else {}, restricted='text' in c.keys(), val_type=str)
+    _check_argument('characters', c['text'] if 'text' in c.keys() else {}, restricted='text' in c.keys(), val_type=str)
+    _check_argument('phonemes', c['text'] if 'text' in c.keys() else {}, restricted='text' in c.keys(), val_type=str)
+    _check_argument('punctuations', c['text'] if 'text' in c.keys() else {}, restricted='text' in c.keys(), val_type=str)

     # normalization parameters
     _check_argument('signal_norm', c['audio'], restricted=True, val_type=bool)
diff --git a/utils/text/__init__.py b/utils/text/__init__.py
index fcb239b2..ff21ffe0 100644
--- a/utils/text/__init__.py
+++ b/utils/text/__init__.py
@@ -61,8 +61,8 @@ def pad_with_eos_bos(phoneme_sequence, tp=None):
     if tp:
         _bos = tp['bos']
         _eos = tp['eos']
-        _, phonemes = make_symbols(**tp)
-        _PHONEMES_TO_ID = {s: i for i, s in enumerate(phonemes)}
+        _, _phonemes = make_symbols(**tp)
+        _PHONEMES_TO_ID = {s: i for i, s in enumerate(_phonemes)}

     return [_PHONEMES_TO_ID[_bos]] + list(phoneme_sequence) + [_PHONEMES_TO_ID[_eos]]

@@ -70,8 +70,8 @@ def pad_with_eos_bos(phoneme_sequence, tp=None):
 def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp=None):
     global _PHONEMES_TO_ID
     if tp:
-        _, phonemes = make_symbols(**tp)
-        _PHONEMES_TO_ID = {s: i for i, s in enumerate(phonemes)}
+        _, _phonemes = make_symbols(**tp)
+        _PHONEMES_TO_ID = {s: i for i, s in enumerate(_phonemes)}

     sequence = []
     text = text.replace(":", "")
@@ -93,8 +93,8 @@ def sequence_to_phoneme(sequence, tp=None):
     global _ID_TO_PHONEMES
     result = ''
     if tp:
-        _, phonemes = make_symbols(**tp)
-        _ID_TO_PHONEMES = {i: s for i, s in enumerate(phonemes)}
+        _, _phonemes = make_symbols(**tp)
+        _ID_TO_PHONEMES = {i: s for i, s in enumerate(_phonemes)}

     for symbol_id in sequence:
         if symbol_id in _ID_TO_PHONEMES:
@@ -118,8 +118,8 @@ def text_to_sequence(text, cleaner_names, tp=None):
     '''
     global _SYMBOL_TO_ID
     if tp:
-        symbols, _ = make_symbols(**tp)
-        _SYMBOL_TO_ID = {s: i for i, s in enumerate(symbols)}
+        _symbols, _ = make_symbols(**tp)
+        _SYMBOL_TO_ID = {s: i for i, s in enumerate(_symbols)}

     sequence = []
     # Check for curly braces and treat their contents as ARPAbet:
@@ -139,8 +139,8 @@ def sequence_to_text(sequence, tp=None):
     '''Converts a sequence of IDs back to a string'''
     global _ID_TO_SYMBOL
     if tp:
-        symbols, _ = make_symbols(**tp)
-        _ID_TO_SYMBOL = {i: s for i, s in enumerate(symbols)}
+        _symbols, _ = make_symbols(**tp)
+        _ID_TO_SYMBOL = {i: s for i, s in enumerate(_symbols)}

     result = ''
     for symbol_id in sequence:
diff --git a/utils/text/symbols.py b/utils/text/symbols.py
index e4a4b103..db83cb29 100644
--- a/utils/text/symbols.py
+++ b/utils/text/symbols.py
@@ -5,16 +5,16 @@ Defines the set of symbols used in text input to the model.
 The default is a set of ASCII characters that works well for English or text
 that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
 '''
 def make_symbols(characters, phonemes, punctuations='!\'(),-.:;? ', pad='_', eos='~', bos='^'):
     ''' Function to create symbols and phonemes '''
-    _phonemes = sorted(list(phonemes))
+    _phonemes_sorted = sorted(list(phonemes))

     # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
-    _arpabet = ['@' + s for s in _phonemes]
+    _arpabet = ['@' + s for s in _phonemes_sorted]

     # Export all symbols:
-    symbols = [pad, eos, bos] + list(characters) + _arpabet
-    phonemes = [pad, eos, bos] + list(_phonemes) + list(punctuations)
+    _symbols = [pad, eos, bos] + list(characters) + _arpabet
+    _phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations)

-    return symbols, phonemes
+    return _symbols, _phonemes
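
For context, a minimal sketch of the "text" configuration block these changes consume: check_config validates its keys, and train.py and the tp=... call sites simply unpack that dict into make_symbols. The key names (pad, eos, bos, characters, phonemes, punctuations) come from the patch itself; the concrete string values below are hypothetical placeholders, not values taken from any shipped config.

# Hypothetical "text" section of a config; keys mirror the make_symbols
# signature validated in check_config. Values are illustrative only.
from utils.text.symbols import make_symbols

text_config = {
    "pad": "_",
    "eos": "~",
    "bos": "^",
    "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ",
    "phonemes": "iyuoeaptkbdgfsvzhmnlrjw",  # shortened, made-up phoneme set
    "punctuations": "!'(),-.:;? ",
}

# Equivalent to the calls in train.py and utils/text/__init__.py:
# each symbol list starts with pad/eos/bos so their IDs stay stable.
symbols, phonemes = make_symbols(**text_config)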