phonem updates

pull/10/head
Eren Golge 2018-11-21 18:01:35 +01:00
parent da30c3c9b3
commit 421787277f
5 changed files with 48 additions and 12 deletions

View File

@ -2,4 +2,4 @@
source ../tmp/venv/bin/activate
# python extract_features.py --data_path ${DATA_ROOT}/shared/data/keithito/LJSpeech-1.1/ --cache_path ~/tts_cache/ --config config.json --num_proc 12 --dataset ljspeech --meta_file metadata.csv --val_split 1000 --process_audio true
# python train.py --config_path config.json --data_path ~/tts_cache/ --debug true
python train.py --config_path config.json --data_path ${DATA_ROOT}/shared/data/Blizzard/Nancy/ --debug true
python train.py --config_path config.json --data_path ${DATA_ROOT}/shared/data/Blizzard/Nancy/ --debug true

View File

@ -6,7 +6,7 @@ import torch
import random
from torch.utils.data import Dataset
from utils.text import text_to_sequence, phoneme_to_sequence
from utils.text import text_to_sequence, phonem_to_sequence
from utils.data import (prepare_data, pad_per_step, prepare_tensor,
prepare_stop_target)

View File

@ -3,6 +3,7 @@
import re
from utils.text import cleaners
from utils.text.symbols import symbols, phonemes
from utils.text.cmudict import text2phone
# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
@ -15,9 +16,18 @@ _id_to_phonemes = {i: s for i, s in enumerate(phonemes)}
_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
def phonem_to_sequence(text, cleaner_names):
    '''Converts text to a sequence of phoneme IDs, one word at a time.

    The text is cleaned with the named cleaners, split on whitespace, and
    each word is passed to ``text2phone``. Words for which ``text2phone``
    returns None are skipped. The end-of-sequence ID (``'~'``) is appended
    last.

    TODO: This ignores punctuations

    Args:
        text: string to convert.
        cleaner_names: names of the cleaner functions to run the text through.

    Returns:
        List of integer phoneme IDs.
    '''
    sequence = []
    clean_text = _clean_text(text, cleaner_names)
    for word in clean_text.split():
        phonems_text = text2phone(word)
        # Identity check: None is a singleton, and `==` could be overridden.
        if phonems_text is None:
            continue
        sequence += _phonem_to_sequence(phonems_text)
        # NOTE(review): indentation was lost in the reviewed copy; the word
        # separator is assumed to be appended once per word, inside the loop
        # -- confirm against the repository.
        sequence.append(_phonemes_to_id[' '])
    sequence.append(_phonemes_to_id['~'])
    return sequence
@ -66,6 +76,17 @@ def sequence_to_text(sequence):
return result.replace('}{', ' ')
def sequence_to_phonem(sequence):
    '''Converts a sequence of phoneme IDs back to a string.

    IDs that are not present in ``_id_to_phonemes`` are silently skipped.

    Args:
        sequence: iterable of phoneme IDs.

    Returns:
        The concatenated phoneme symbols, with every ``'}{'`` replaced by a
        single space.
    '''
    # No per-symbol print here: this runs for every decoded sequence, so
    # debug output would flood stdout.
    result = ''.join(_id_to_phonemes[symbol_id]
                     for symbol_id in sequence
                     if symbol_id in _id_to_phonemes)
    return result.replace('}{', ' ')
def _clean_text(text, cleaner_names):
for name in cleaner_names:
cleaner = getattr(cleaners, name)
@ -80,7 +101,7 @@ def _symbols_to_sequence(symbols):
def _phonem_to_sequence(phonemes):
return [_phonemes_to_id[s] for s in phonemes if _should_keep_phonem(s)]
return [_phonemes_to_id[s] for s in phonemes.split(" ") if _should_keep_phonem(s)]
def _arpabet_to_sequence(text):

View File

@ -64,12 +64,11 @@ _phonemes = set(_phonemes)
def text2phone(text):
    '''Phonemizes ``text`` with the ``phonemizer`` package.

    Args:
        text: string to phonemize.

    Returns:
        Whatever ``phonemizer.phonemize`` returns for ``text``, or None if
        phonemization raised an exception.
    '''
    # NOTE(review): the meaning of the three positional Separator arguments
    # depends on the installed phonemizer version -- confirm which of them
    # receives the ' ' delimiter.
    separator = phonemizer.separator.Separator('', '', ' ')
    try:
        ph = phonemizer.phonemize(text, separator=separator)
    except Exception:
        # Best effort: callers treat None as "skip this input". Catching
        # Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        ph = None
    return ph
class CMUDict:
@ -95,6 +94,20 @@ class CMUDict:
'''Returns list of ARPAbet pronunciations of the given word.'''
return self._entries.get(word.upper())
def get_arpabet(self, word, cmudict, punctuation_symbols):
    '''Returns ``word`` with its ARPAbet pronunciation in curly braces.

    At most one leading and one trailing punctuation character are peeled
    off before the lookup and re-attached around the result.

    Args:
        word: the token to convert.
        cmudict: object whose ``lookup(word)`` returns a sequence of
            pronunciations, or None when the word is unknown.
        punctuation_symbols: container of characters treated as punctuation.

    Returns:
        ``'{<first pronunciation>}'`` surrounded by the stripped punctuation
        when the lookup succeeds; otherwise the original word with its
        punctuation restored.
    '''
    first_symbol, last_symbol = '', ''
    if word and word[0] in punctuation_symbols:
        first_symbol = word[0]
        word = word[1:]
    if word and word[-1] in punctuation_symbols:
        last_symbol = word[-1]
        word = word[:-1]
    arpabet = cmudict.lookup(word)
    # Truthiness covers both None and an empty pronunciation list, so
    # arpabet[0] can never raise IndexError.
    if arpabet:
        return first_symbol + '{%s}' % arpabet[0] + last_symbol
    return first_symbol + word + last_symbol
_alt_re = re.compile(r'\([0-9]+\)')

View File

@ -10,13 +10,15 @@ from utils.text import cmudict
_pad = '_'
_eos = '~'
_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? '
_punctuations = '!\'(),-.:;? '
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
_arpabet = ['@' + s for s in cmudict._phonemes]
# Export all symbols:
symbols = [_pad, _eos] + list(_characters) + _arpabet
phonemes = [_pad, _eos] + cmudict._phonemes + list('!\'(),-.:;?')
phonemes = [_pad, _eos] + list(cmudict._phonemes) + list(_punctuations)
if __name__ == '__main__':
print(symbols)
print(phonemes)