diff --git a/datasets/datafeeder.py b/datasets/datafeeder.py
index 3ff4f7b..4d8303f 100644
--- a/datasets/datafeeder.py
+++ b/datasets/datafeeder.py
@@ -5,7 +5,7 @@ import tensorflow as tf
 import threading
 import time
 import traceback
-from util import cmudict, textinput
+from text import cmudict, text_to_sequence
 from util.infolog import log
@@ -21,6 +21,7 @@ class DataFeeder(threading.Thread):
     super(DataFeeder, self).__init__()
     self._coord = coordinator
     self._hparams = hparams
+    self._cleaner_names = [x.strip() for x in hparams.cleaners.split(',')]
     self._offset = 0

     # Load metadata:
@@ -107,15 +108,15 @@ class DataFeeder(threading.Thread):
     if self._cmudict and random.random() < _p_cmudict:
       text = ' '.join([self._maybe_get_arpabet(word) for word in text.split(' ')])

-    input_data = np.asarray(textinput.to_sequence(text), dtype=np.int32)
+    input_data = np.asarray(text_to_sequence(text, self._cleaner_names), dtype=np.int32)
     linear_target = np.load(os.path.join(self._datadir, meta[0]))
     mel_target = np.load(os.path.join(self._datadir, meta[1]))
     return (input_data, mel_target, linear_target, len(linear_target))

   def _maybe_get_arpabet(self, word):
-    pron = self._cmudict.lookup(word)
-    return '{%s}' % pron[0] if pron is not None and random.random() < 0.5 else word
+    arpabet = self._cmudict.lookup(word)
+    return '{%s}' % arpabet[0] if arpabet is not None and random.random() < 0.5 else word


 def _prepare_batch(batch, outputs_per_step):
diff --git a/hparams.py b/hparams.py
index 3f47119..3c4da5d 100644
--- a/hparams.py
+++ b/hparams.py
@@ -3,10 +3,9 @@ import tensorflow as tf

 # Default hyperparameters:
 hparams = tf.contrib.training.HParams(
-  # Text:
-  force_lowercase=True,
-  expand_abbreviations=True,
-  use_cmudict=False,
+  # Comma-separated list of cleaners to run on text prior to training and eval. For non-English
+  # text, you may want to use "basic_pipeline" or "transliteration_pipeline". See text/cleaners.py.
+  cleaners='english_pipeline',

   # Audio:
   num_mels=80,
@@ -28,6 +27,7 @@ hparams = tf.contrib.training.HParams(
   adam_beta2=0.999,
   initial_learning_rate=0.002,
   decay_learning_rate=True,
+  use_cmudict=False,  # Use CMUDict during training to learn pronunciation of ARPAbet phonemes

   # Eval:
   max_iters=200,
diff --git a/models/tacotron.py b/models/tacotron.py
index a2d7c21..5fce5f7 100644
--- a/models/tacotron.py
+++ b/models/tacotron.py
@@ -1,7 +1,7 @@
 import tensorflow as tf
 from tensorflow.contrib.rnn import GRUCell, MultiRNNCell, OutputProjectionWrapper, ResidualWrapper
 from tensorflow.contrib.seq2seq import BasicDecoder, BahdanauAttention, AttentionWrapper
-from util import textinput
+from text.symbols import symbols
 from util.infolog import log
 from .helpers import TacoTestHelper, TacoTrainingHelper
 from .modules import encoder_cbhg, post_cbhg, prenet
@@ -38,7 +38,7 @@ class Tacotron():

       # Embeddings
       embedding_table = tf.get_variable(
-        'embedding', [textinput.num_symbols(), 256], dtype=tf.float32,
+        'embedding', [len(symbols), 256], dtype=tf.float32,
         initializer=tf.truncated_normal_initializer(stddev=0.5))
       embedded_inputs = tf.nn.embedding_lookup(embedding_table, inputs)           # [N, T_in, 256]
diff --git a/requirements.txt b/requirements.txt
index 87562e6..20401a0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,3 +7,4 @@ scipy==0.19.0
 tensorflow==1.2.0
 tensorflow-gpu==1.2.0
 tqdm==4.11.2
+Unidecode==0.4.20
diff --git a/synthesizer.py b/synthesizer.py
index 96f2b5f..6904807 100644
--- a/synthesizer.py
+++ b/synthesizer.py
@@ -3,7 +3,8 @@ import numpy as np
 import tensorflow as tf
 from hparams import hparams
 from models import create_model
-from util import audio, textinput
+from text import text_to_sequence
+from util import audio


 class Synthesizer:
@@ -23,9 +24,8 @@ class Synthesizer:

   def synthesize(self, text):
-    seq = textinput.to_sequence(text,
-      force_lowercase=hparams.force_lowercase,
-      expand_abbreviations=hparams.expand_abbreviations)
+    cleaner_names = [x.strip() for x in hparams.cleaners.split(',')]
+    seq = text_to_sequence(text, cleaner_names)
     feed_dict = {
       self.model.inputs: [np.asarray(seq, dtype=np.int32)],
       self.model.input_lengths: np.asarray([len(seq)], dtype=np.int32)
diff --git a/tests/cmudict_test.py b/tests/cmudict_test.py
index e2c62b4..203d6e8 100644
--- a/tests/cmudict_test.py
+++ b/tests/cmudict_test.py
@@ -1,5 +1,5 @@
 import io
-from util import cmudict
+from text import cmudict


 test_data = '''
@@ -37,5 +37,5 @@ def test_cmudict():
 def test_cmudict_no_keep_ambiguous():
   c = cmudict.CMUDict(io.StringIO(test_data), keep_ambiguous=False)
   assert len(c) == 5
-  assert c.lookup('ADVERSITY') == ['AE0 D V ER1 S IH0 T IY2']
+  assert c.lookup('adversity') == ['AE0 D V ER1 S IH0 T IY2']
   assert c.lookup('adverse') == None
diff --git a/tests/numbers_test.py b/tests/numbers_test.py
index 61f631d..7fa6b60 100644
--- a/tests/numbers_test.py
+++ b/tests/numbers_test.py
@@ -1,51 +1,51 @@
-from util.numbers import normalize
+from text.numbers import normalize_numbers


 def test_normalize_numbers():
-  assert normalize('1') == 'one'
-  assert normalize('15') == 'fifteen'
-  assert normalize('24') == 'twenty-four'
-  assert normalize('100') == 'one hundred'
-  assert normalize('101') == 'one hundred one'
-  assert normalize('456') == 'four hundred fifty-six'
-  assert normalize('1000') == 'one thousand'
-  assert normalize('1800') == 'eighteen hundred'
-  assert normalize('2,000') == 'two thousand'
-  assert normalize('3000') == 'three thousand'
-  assert normalize('18000') == 'eighteen thousand'
-  assert normalize('24,000') == 'twenty-four thousand'
-  assert normalize('124,001') == 'one hundred twenty-four thousand one'
-  assert normalize('6.4 sec') == 'six point four sec'
+  assert normalize_numbers('1') == 'one'
+  assert normalize_numbers('15') == 'fifteen'
+  assert normalize_numbers('24') == 'twenty-four'
+  assert normalize_numbers('100') == 'one hundred'
+  assert normalize_numbers('101') == 'one hundred one'
+  assert normalize_numbers('456') == 'four hundred fifty-six'
+  assert normalize_numbers('1000') == 'one thousand'
+  assert normalize_numbers('1800') == 'eighteen hundred'
+  assert normalize_numbers('2,000') == 'two thousand'
+  assert normalize_numbers('3000') == 'three thousand'
+  assert normalize_numbers('18000') == 'eighteen thousand'
+  assert normalize_numbers('24,000') == 'twenty-four thousand'
+  assert normalize_numbers('124,001') == 'one hundred twenty-four thousand one'
+  assert normalize_numbers('6.4 sec') == 'six point four sec'


 def test_normalize_ordinals():
-  assert normalize('1st') == 'first'
-  assert normalize('2nd') == 'second'
-  assert normalize('9th') == 'ninth'
-  assert normalize('243rd place') == 'two hundred and forty-third place'
+  assert normalize_numbers('1st') == 'first'
+  assert normalize_numbers('2nd') == 'second'
+  assert normalize_numbers('9th') == 'ninth'
+  assert normalize_numbers('243rd place') == 'two hundred and forty-third place'


 def test_normalize_dates():
-  assert normalize('1400') == 'fourteen hundred'
-  assert normalize('1901') == 'nineteen oh one'
-  assert normalize('1999') == 'nineteen ninety-nine'
-  assert normalize('2000') == 'two thousand'
-  assert normalize('2004') == 'two thousand four'
-  assert normalize('2010') == 'twenty ten'
-  assert normalize('2012') == 'twenty twelve'
-  assert normalize('2025') == 'twenty twenty-five'
-  assert normalize('September 11, 2001') == 'September eleven, two thousand one'
-  assert normalize('July 26, 1984.') == 'July twenty-six, nineteen eighty-four.'
+  assert normalize_numbers('1400') == 'fourteen hundred'
+  assert normalize_numbers('1901') == 'nineteen oh one'
+  assert normalize_numbers('1999') == 'nineteen ninety-nine'
+  assert normalize_numbers('2000') == 'two thousand'
+  assert normalize_numbers('2004') == 'two thousand four'
+  assert normalize_numbers('2010') == 'twenty ten'
+  assert normalize_numbers('2012') == 'twenty twelve'
+  assert normalize_numbers('2025') == 'twenty twenty-five'
+  assert normalize_numbers('September 11, 2001') == 'September eleven, two thousand one'
+  assert normalize_numbers('July 26, 1984.') == 'July twenty-six, nineteen eighty-four.'


 def test_normalize_money():
-  assert normalize('$0.00') == 'zero dollars'
-  assert normalize('$1') == 'one dollar'
-  assert normalize('$10') == 'ten dollars'
-  assert normalize('$.01') == 'one cent'
-  assert normalize('$0.25') == 'twenty-five cents'
-  assert normalize('$5.00') == 'five dollars'
-  assert normalize('$5.01') == 'five dollars, one cent'
-  assert normalize('$135.99.') == 'one hundred thirty-five dollars, ninety-nine cents.'
-  assert normalize('$40,000') == 'forty thousand dollars'
-  assert normalize('for £2500!') == 'for twenty-five hundred pounds!'
+  assert normalize_numbers('$0.00') == 'zero dollars'
+  assert normalize_numbers('$1') == 'one dollar'
+  assert normalize_numbers('$10') == 'ten dollars'
+  assert normalize_numbers('$.01') == 'one cent'
+  assert normalize_numbers('$0.25') == 'twenty-five cents'
+  assert normalize_numbers('$5.00') == 'five dollars'
+  assert normalize_numbers('$5.01') == 'five dollars, one cent'
+  assert normalize_numbers('$135.99.') == 'one hundred thirty-five dollars, ninety-nine cents.'
+  assert normalize_numbers('$40,000') == 'forty thousand dollars'
+  assert normalize_numbers('for £2500!') == 'for twenty-five hundred pounds!'
diff --git a/tests/text_test.py b/tests/text_test.py
new file mode 100644
index 0000000..a222ed6
--- /dev/null
+++ b/tests/text_test.py
@@ -0,0 +1,60 @@
+from text import cleaners, symbols, text_to_sequence, sequence_to_text
+from unidecode import unidecode
+
+
+def test_symbols():
+  assert len(symbols) >= 3
+  assert symbols[0] == '_'
+  assert symbols[1] == '~'
+
+
+def test_text_to_sequence():
+  assert text_to_sequence('', []) == [1]
+  assert text_to_sequence('Hi!', []) == [9, 36, 54, 1]
+  assert text_to_sequence('"A"_B', []) == [2, 3, 1]
+  assert text_to_sequence('A {AW1 S} B', []) == [2, 64, 83, 132, 64, 3, 1]
+  assert text_to_sequence('Hi', ['lowercase']) == [35, 36, 1]
+  assert text_to_sequence('A {AW1 S} B', ['english_pipeline']) == [28, 64, 83, 132, 64, 29, 1]
+
+
+def test_sequence_to_text():
+  assert sequence_to_text([]) == ''
+  assert sequence_to_text([1]) == '~'
+  assert sequence_to_text([9, 36, 54, 1]) == 'Hi!~'
+  assert sequence_to_text([2, 64, 83, 132, 64, 3]) == 'A {AW1 S} B'
+
+
+def test_collapse_whitespace():
+  assert cleaners.collapse_whitespace('') == ''
+  assert cleaners.collapse_whitespace('  ') == ' '
+  assert cleaners.collapse_whitespace('x') == 'x'
+  assert cleaners.collapse_whitespace(' x.  y,  \tz') == ' x. y, z'
+
+
+def test_convert_to_ascii():
+  assert cleaners.convert_to_ascii("raison d'être") == "raison d'etre"
+  assert cleaners.convert_to_ascii('grüß gott') == 'gruss gott'
+  assert cleaners.convert_to_ascii('안녕') == 'annyeong'
+  assert cleaners.convert_to_ascii('Здравствуйте') == 'Zdravstvuite'
+
+
+def test_lowercase():
+  assert cleaners.lowercase('Happy Birthday!') == 'happy birthday!'
+  assert cleaners.lowercase('CAFÉ') == 'café'
+
+
+def test_expand_abbreviations():
+  assert cleaners.expand_abbreviations('mr. and mrs. smith') == 'mister and misess smith'
+
+
+def test_expand_numbers():
+  assert cleaners.expand_numbers('3 apples and 44 pears') == 'three apples and forty-four pears'
+  assert cleaners.expand_numbers('$3.50 for gas.') == 'three dollars, fifty cents for gas.'
+
+
+def test_pipelines():
+  text = 'Mr. Müller ate 2 Apples'
+  assert cleaners.english_pipeline(text) == 'mister muller ate two apples'
+  assert cleaners.transliteration_pipeline(text) == 'mr. muller ate 2 apples'
+  assert cleaners.basic_pipeline(text) == 'mr. müller ate 2 apples'
diff --git a/tests/textinput_test.py b/tests/textinput_test.py
deleted file mode 100644
index b47143b..0000000
--- a/tests/textinput_test.py
+++ /dev/null
@@ -1,59 +0,0 @@
-from util.textinput import num_symbols, to_sequence, to_string
-
-
-def text_num_symbols():
-  assert num_symbols() == 147
-
-
-def test_to_sequence():
-  assert to_sequence('') == [1]
-  assert to_sequence('H', force_lowercase=False) == [9, 1]
-  assert to_sequence('H', force_lowercase=True) == [35, 1]
-  assert to_sequence('Hi.', force_lowercase=False) == [9, 36, 60, 1]
-
-
-def test_whitespace_nomalization():
-  assert round_trip('') == '~'
-  assert round_trip(' ') == '~'
-  assert round_trip('x') == 'x~'
-  assert round_trip(' x ') == 'x~'
-  assert round_trip(' x. y,z ') == 'x. y,z~'
-  assert round_trip('X: Y') == 'X: Y~'
-
-
-def test_valid_chars():
-  assert round_trip('x') == 'x~'
-  assert round_trip('Hello') == 'Hello~'
-  assert round_trip('3 apples and 44 bananas') == 'three apples and forty-four bananas~'
-  assert round_trip('$3.50 for gas.') == 'three dollars, fifty cents for gas.~'
-  assert round_trip('Hello, world!') == 'Hello, world!~'
-  assert round_trip("What (time-out)! He\'s going where?") == "What (time-out)! He\'s going where?~"
-
-
-def test_invalid_chars():
-  assert round_trip('^') == ' ~'
-  assert round_trip('A~^B') == 'A B~'
-  assert round_trip('"Finally," she said, "it ended."') == 'Finally, she said, it ended.~'
-
-
-def test_unicode():
-  assert round_trip('naïve café') == 'naive cafe~'
-  assert round_trip("raison d'être") == "raison d'etre~"
-
-
-def test_arpabet():
-  assert to_sequence('{AE0 D}') == [70, 91, 1]
-  assert round_trip('{AE0 D V ER1 S}') == '{AE0 D V ER1 S}~'
-  assert round_trip('{AE0 D V ER1 S} circumstances') == '{AE0 D V ER1 S} circumstances~'
-  assert round_trip('In {AE0 D V ER1 S} circumstances') == 'In {AE0 D V ER1 S} circumstances~'
-  assert round_trip('{AE0 D V ER1 S} {AE0 D S}') == '{AE0 D V ER1 S} {AE0 D S}~'
-  assert round_trip('X {AE0 D} Y {AE0 D} Z') == 'X {AE0 D} Y {AE0 D} Z~'
-
-
-def test_abbreviations():
-  assert round_trip('mr. rogers and dr. smith.') == 'mister rogers and doctor smith.~'
-  assert round_trip('hit it with a hammr.') == 'hit it with a hammr.~'
-
-
-def round_trip(x):
-  return to_string(to_sequence(x, force_lowercase=False, expand_abbreviations=True))
diff --git a/text/__init__.py b/text/__init__.py
new file mode 100644
index 0000000..44e8ccd
--- /dev/null
+++ b/text/__init__.py
@@ -0,0 +1,76 @@
+import re
+from text import cleaners
+from text.symbols import symbols
+
+
+# Mappings from symbol to numeric ID and vice versa:
+_symbol_to_id = {s: i for i, s in enumerate(symbols)}
+_id_to_symbol = {i: s for i, s in enumerate(symbols)}
+
+# Regular expression matching text enclosed in curly braces:
+_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
+
+
+def text_to_sequence(text, cleaner_names):
+  '''Converts a string of text to a sequence of IDs for the symbols in the text.
+
+  The text can optionally have ARPAbet sequences enclosed in curly braces embedded
+  in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
+
+  Args:
+    text: string to convert to a sequence
+    cleaner_names: names of the cleaner functions to run the text through
+
+  Returns:
+    List of integers corresponding to the symbols in the text
+  '''
+  sequence = []
+
+  # Check for curly braces and treat their contents as ARPAbet:
+  while len(text):
+    m = _curly_re.match(text)
+    if not m:
+      sequence += _characters_to_sequence(_clean_text(text, cleaner_names))
+      break
+    sequence += _characters_to_sequence(_clean_text(m.group(1), cleaner_names))
+    sequence += _arpabet_to_sequence(m.group(2))
+    text = m.group(3)
+
+  # Append EOS token
+  sequence.append(_symbol_to_id['~'])
+  return sequence
+
+
+def sequence_to_text(sequence):
+  '''Converts a sequence of IDs back to a string'''
+  result = ''
+  for symbol_id in sequence:
+    if symbol_id in _id_to_symbol:
+      s = _id_to_symbol[symbol_id]
+      # Enclose ARPAbet back in curly braces:
+      if len(s) > 1 and s[0] == '@':
+        s = '{%s}' % s[1:]
+      result += s
+  return result.replace('}{', ' ')
+
+
+def _clean_text(text, cleaner_names):
+  for name in cleaner_names:
+    cleaner = getattr(cleaners, name, None)
+    if not cleaner:
+      raise Exception('Unknown cleaner: %s' % name)
+    text = cleaner(text)
+  return text
+
+
+def _characters_to_sequence(chars):
+  return [_symbol_to_id[s] for s in chars if _should_keep_symbol(s)]
+
+
+def _arpabet_to_sequence(text):
+  arpabet_symbols = ['@' + s for s in text.split()]
+  return [_symbol_to_id[s] for s in arpabet_symbols if _should_keep_symbol(s)]
+
+
+def _should_keep_symbol(s):
+  return s in _symbol_to_id and s != '_' and s != '~'
diff --git a/text/cleaners.py b/text/cleaners.py
new file mode 100644
index 0000000..5eedca0
--- /dev/null
+++ b/text/cleaners.py
@@ -0,0 +1,88 @@
+'''
+Cleaners are transformations that run over the input text at both training and eval time.
+
+Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
+hyperparameter. Some cleaners are English-specific. You'll typically want to use:
+  1. "english_pipeline" for English text
+  2. "transliteration_pipeline" for non-English text that can be transliterated to ASCII using
+     the Unidecode library (https://pypi.python.org/pypi/Unidecode)
+  3. "basic_pipeline" if you do not want to transliterate (in this case, you should also update
+     the symbols in symbols.py to match your data).
+'''
+
+import re
+from unidecode import unidecode
+from .numbers import normalize_numbers
+
+
+# Regular expression matching whitespace:
+_whitespace_re = re.compile(r'\s+')
+
+# List of (regular expression, replacement) pairs for abbreviations:
+_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
+  ('mrs', 'misess'),
+  ('mr', 'mister'),
+  ('dr', 'doctor'),
+  ('st', 'saint'),
+  ('co', 'company'),
+  ('jr', 'junior'),
+  ('maj', 'major'),
+  ('gen', 'general'),
+  ('drs', 'doctors'),
+  ('rev', 'reverend'),
+  ('lt', 'lieutenant'),
+  ('hon', 'honorable'),
+  ('sgt', 'sergeant'),
+  ('capt', 'captain'),
+  ('esq', 'esquire'),
+  ('ltd', 'limited'),
+  ('col', 'colonel'),
+  ('ft', 'fort'),
+]]
+
+
+def expand_abbreviations(text):
+  for regex, replacement in _abbreviations:
+    text = re.sub(regex, replacement, text)
+  return text
+
+
+def expand_numbers(text):
+  return normalize_numbers(text)
+
+
+def lowercase(text):
+  return text.lower()
+
+
+def collapse_whitespace(text):
+  return re.sub(_whitespace_re, ' ', text)
+
+
+def convert_to_ascii(text):
+  return unidecode(text)
+
+
+def basic_pipeline(text):
+  '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
+  text = lowercase(text)
+  text = collapse_whitespace(text)
+  return text
+
+
+def transliteration_pipeline(text):
+  '''Pipeline for non-English text that transliterates to ASCII.'''
+  text = convert_to_ascii(text)
+  text = lowercase(text)
+  text = collapse_whitespace(text)
+  return text
+
+
+def english_pipeline(text):
+  '''Pipeline for English text, including number and abbreviation expansion.'''
+  text = convert_to_ascii(text)
+  text = lowercase(text)
+  text = expand_numbers(text)
+  text = expand_abbreviations(text)
+  text = collapse_whitespace(text)
+  return text
diff --git a/util/cmudict.py b/text/cmudict.py
similarity index 100%
rename from util/cmudict.py
rename to text/cmudict.py
diff --git a/util/numbers.py b/text/numbers.py
similarity index 98%
rename from util/numbers.py
rename to text/numbers.py
index 6c5f547..ba9eb74 100644
--- a/util/numbers.py
+++ b/text/numbers.py
@@ -59,7 +59,7 @@ def _expand_number(m):
     return _inflect.number_to_words(num, andword='')


-def normalize(text):
+def normalize_numbers(text):
   text = re.sub(_comma_number_re, _remove_commas, text)
   text = re.sub(_pounds_re, r'\1 pounds', text)
   text = re.sub(_dollars_re, _expand_dollars, text)
diff --git a/text/symbols.py b/text/symbols.py
new file mode 100644
index 0000000..ab3e64e
--- /dev/null
+++ b/text/symbols.py
@@ -0,0 +1,18 @@
+'''
+Defines the set of symbols used in text input to the model.
+
+The default works well for English. For non-English datasets, update _characters to be the set of
+characters in the dataset. The "cleaners" hyperparameter should also be changed to be
+"basic_pipeline" or a custom set of steps for the dataset (see cleaners.py for more info).
+'''
+from text import cmudict
+
+_pad = '_'
+_eos = '~'
+_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? '
+
+# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
+_arpabet = ['@' + s for s in cmudict.valid_symbols]
+
+# Export all symbols:
+symbols = [_pad, _eos] + list(_characters) + _arpabet
diff --git a/train.py b/train.py
index 7b49cb5..4bfc5ca 100644
--- a/train.py
+++ b/train.py
@@ -11,7 +11,8 @@ import traceback
 from datasets.datafeeder import DataFeeder
 from hparams import hparams, hparams_debug_string
 from models import create_model
-from util import audio, infolog, plot, textinput, ValueWindow
+from text import sequence_to_text
+from util import audio, infolog, plot, ValueWindow

 log = infolog.log
@@ -114,7 +115,7 @@ def train(log_dir, args):
           audio.save_wav(waveform, os.path.join(log_dir, 'step-%d-audio.wav' % step))
           plot.plot_alignment(alignment, os.path.join(log_dir, 'step-%d-align.png' % step),
             info='%s, %s, %s, step=%d, loss=%.5f' % (args.model, commit, time_string(), step, loss))
-          log('Input: %s' % textinput.to_string(input_seq))
+          log('Input: %s' % sequence_to_text(input_seq))

       except Exception as e:
         log('Exiting due to exception: %s' % e, slack=True)
diff --git a/util/textinput.py b/util/textinput.py
deleted file mode 100644
index 7b010f4..0000000
--- a/util/textinput.py
+++ /dev/null
@@ -1,105 +0,0 @@
-import re
-import unicodedata
-from util import cmudict, numbers
-
-
-# Input alphabet (63 symbols), plus ARPAbet (84 symbols):
-_pad = '_'
-_eos = '~'
-_uppercase = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
-_lowercase = 'abcdefghijklmnopqrstuvwxyz'
-_punctuation = '!\'(),-.:;?'
-_space = ' '
-
-_valid_input_chars = _uppercase + _lowercase + _punctuation + _space
-_trans_table = str.maketrans({chr(i): ' ' for i in range(256) if chr(i) not in _valid_input_chars})
-
-_normal_symbols = _pad + _eos + _valid_input_chars
-_num_normal_symbols = len(_normal_symbols)
-_char_to_id = {c: i for i, c in enumerate(_normal_symbols)}
-_id_to_char = {i: c for i, c in enumerate(_normal_symbols)}
-_arpabet_to_id = {sym: i + _num_normal_symbols for i, sym in enumerate(cmudict.valid_symbols)}
-_id_to_arpabet = {i + _num_normal_symbols: sym for i, sym in enumerate(cmudict.valid_symbols)}
-_arpabet_re = re.compile(r'(.*?)\{([A-Z0-2 ]+?)\}(.*)')
-_num_symbols = _num_normal_symbols + len(cmudict.valid_symbols)
-_whitespace_re = re.compile(r'\s+')
-
-
-def num_symbols():
-  '''Returns number of symbols in the alphabet.'''
-  return _num_symbols
-
-
-def to_sequence(text, force_lowercase=True, expand_abbreviations=True):
-  '''Converts a string of text to a sequence of IDs for the symbols in the text'''
-  text = text.strip()
-  text = text.replace('"', '')
-  text = unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode()
-
-  sequence = []
-  while len(text):
-    m = _arpabet_re.match(text)
-    if not m:
-      sequence += _text_to_sequence(text, force_lowercase, expand_abbreviations)
-      break
-    sequence += _text_to_sequence(m.group(1), force_lowercase, expand_abbreviations)
-    sequence += _arpabet_to_sequence(m.group(2))
-    text = m.group(3)
-  sequence.append(_char_to_id[_eos])
-  return sequence
-
-
-def to_string(sequence, remove_eos=False):
-  '''Returns the string for a sequence of characters.'''
-  s = ''
-  for sym in sequence:
-    if sym < _num_normal_symbols:
-      s += _id_to_char[sym]
-    elif sym < _num_symbols:
-      s += '{%s}' % _id_to_arpabet[sym]
-  s = s.replace('}{', ' ')
-  if remove_eos and s[-1] == _eos:
-    s = s[:-1]
-  return s
-
-
-def _text_to_sequence(text, force_lowercase, expand_abbreviations):
-  text = numbers.normalize(text)
-  text = text.translate(_trans_table)
-  if force_lowercase:
-    text = text.lower()
-  if expand_abbreviations:
-    text = _expand_abbreviations(text)
-  text = re.sub(_whitespace_re, ' ', text)
-  return [_char_to_id[c] for c in text]
-
-
-_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
-  ('mrs', 'misess'),
-  ('mr', 'mister'),
-  ('dr', 'doctor'),
-  ('st', 'saint'),
-  ('co', 'company'),
-  ('jr', 'junior'),
-  ('maj', 'major'),
-  ('gen', 'general'),
-  ('drs', 'doctors'),
-  ('rev', 'reverend'),
-  ('lt', 'lieutenant'),
-  ('hon', 'honorable'),
-  ('sgt', 'sergeant'),
-  ('capt', 'captain'),
-  ('esq', 'esquire'),
-  ('ltd', 'limited'),
-  ('col', 'colonel'),
-  ('ft', 'fort'),
-]]
-
-def _expand_abbreviations(text):
-  for regex, replacement in _abbreviations:
-    text = re.sub(regex, replacement, text)
-  return text
-
-
-def _arpabet_to_sequence(text):
-  return [_arpabet_to_id[s] for s in text.split() if s in _arpabet_to_id]
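
Usage sketch (an illustration of the refactored API, not part of the diff): the new entry points are text_to_sequence and sequence_to_text, with cleaner names parsed from the hparams string introduced in hparams.py. This assumes the text package above is on the import path.

    from hparams import hparams
    from text import text_to_sequence, sequence_to_text

    # Split the hparams string into cleaner function names, exactly as
    # datafeeder.py and synthesizer.py now do:
    cleaner_names = [x.strip() for x in hparams.cleaners.split(',')]

    # Curly-brace spans are treated as ARPAbet and bypass the cleaners;
    # only the surrounding text is cleaned:
    seq = text_to_sequence('Turn left on {HH AW1 S S T AH0 N} Street.', cleaner_names)

    # Round-trip for logging, as train.py now does; the trailing '~' is the
    # EOS symbol that text_to_sequence appends:
    print(sequence_to_text(seq))  # turn left on {HH AW1 S S T AH0 N} street.~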
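The relocated use_cmudict flag drives a small data-augmentation step in DataFeeder: with probability _p_cmudict per example, each dictionary word may be swapped for its ARPAbet pronunciation. A standalone sketch of that logic, assuming a CMUDict file is available locally (the path below is a placeholder):

    import random
    from text import cmudict

    cmu = cmudict.CMUDict('cmudict-0.7b')  # placeholder path, an assumption

    def maybe_get_arpabet(word):
      # Mirrors DataFeeder._maybe_get_arpabet: replace a word with its first
      # ARPAbet pronunciation about half the time, when the dictionary has one.
      arpabet = cmu.lookup(word)
      return '{%s}' % arpabet[0] if arpabet is not None and random.random() < 0.5 else word

    # e.g. 'in adverse circumstances' -> 'in {AE0 D V ER1 S} circumstances'
    augmented = ' '.join(maybe_get_arpabet(w) for w in 'in adverse circumstances'.split(' '))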
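Because _clean_text resolves each name in the cleaners hparam via getattr on the cleaners module, any function added to text/cleaners.py becomes selectable by name, and symbols.py only needs its _characters set extended to match the dataset. A hypothetical example (the german_pipeline name and the extra characters are illustrative, not part of this change):

    # text/cleaners.py (hypothetical addition): keep umlauts instead of
    # transliterating them, but still lowercase and collapse whitespace.
    def german_pipeline(text):
      text = lowercase(text)
      text = collapse_whitespace(text)
      return text

    # text/symbols.py would then need the extra characters, e.g.:
    # _characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzäöüß!\'(),-.:;? '

    # hparams.py:
    # cleaners='german_pipeline',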