Simplify text processing and make it easier to adapt to non-English data.

pull/2/head
Keith Ito 2017-09-04 14:46:24 -07:00
parent 0894cd9bf2
commit 479976b6c5
16 changed files with 303 additions and 222 deletions

View File

@@ -5,7 +5,7 @@ import tensorflow as tf
 import threading
 import time
 import traceback
-from util import cmudict, textinput
+from text import cmudict, text_to_sequence
 from util.infolog import log
@@ -21,6 +21,7 @@ class DataFeeder(threading.Thread):
     super(DataFeeder, self).__init__()
     self._coord = coordinator
     self._hparams = hparams
+    self._cleaner_names = [x.strip() for x in hparams.cleaners.split(',')]
     self._offset = 0
     # Load metadata:
@@ -107,15 +108,15 @@ class DataFeeder(threading.Thread):
     if self._cmudict and random.random() < _p_cmudict:
       text = ' '.join([self._maybe_get_arpabet(word) for word in text.split(' ')])
-    input_data = np.asarray(textinput.to_sequence(text), dtype=np.int32)
+    input_data = np.asarray(text_to_sequence(text, self._cleaner_names), dtype=np.int32)
     linear_target = np.load(os.path.join(self._datadir, meta[0]))
     mel_target = np.load(os.path.join(self._datadir, meta[1]))
     return (input_data, mel_target, linear_target, len(linear_target))

   def _maybe_get_arpabet(self, word):
-    pron = self._cmudict.lookup(word)
-    return '{%s}' % pron[0] if pron is not None and random.random() < 0.5 else word
+    arpabet = self._cmudict.lookup(word)
+    return '{%s}' % arpabet[0] if arpabet is not None and random.random() < 0.5 else word

 def _prepare_batch(batch, outputs_per_step):

View File

@@ -3,10 +3,9 @@ import tensorflow as tf
 # Default hyperparameters:
 hparams = tf.contrib.training.HParams(
   # Text:
-  force_lowercase=True,
-  expand_abbreviations=True,
-  use_cmudict=False,
+  # Comma-separated list of cleaners to run on text prior to training and eval. For non-English
+  # text, you may want to use "basic_pipeline" or "transliteration_pipeline". See text/cleaners.py.
+  cleaners='english_pipeline',

   # Audio:
   num_mels=80,
@@ -28,6 +27,7 @@ hparams = tf.contrib.training.HParams(
   adam_beta2=0.999,
   initial_learning_rate=0.002,
   decay_learning_rate=True,
+  use_cmudict=False,  # Use CMUDict during training to learn pronunciation of ARPAbet phonemes

   # Eval:
   max_iters=200,
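
Switching a run to non-English data now comes down to overriding this one hyperparameter. A minimal sketch (assuming the hparams object defined above; HParams.parse is the standard override mechanism, and it splits on commas, so a multi-cleaner list is easiest to set directly in this file):

from hparams import hparams

# Transliterate input to ASCII via Unidecode, keeping the default English symbol set:
hparams.parse('cleaners=transliteration_pipeline')

# Or keep the original script (first update _characters in text/symbols.py to match):
hparams.parse('cleaners=basic_pipeline')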

View File

@@ -1,7 +1,7 @@
 import tensorflow as tf
 from tensorflow.contrib.rnn import GRUCell, MultiRNNCell, OutputProjectionWrapper, ResidualWrapper
 from tensorflow.contrib.seq2seq import BasicDecoder, BahdanauAttention, AttentionWrapper
-from util import textinput
+from text.symbols import symbols
 from util.infolog import log
 from .helpers import TacoTestHelper, TacoTrainingHelper
 from .modules import encoder_cbhg, post_cbhg, prenet
@@ -38,7 +38,7 @@ class Tacotron():
     # Embeddings
     embedding_table = tf.get_variable(
-      'embedding', [textinput.num_symbols(), 256], dtype=tf.float32,
+      'embedding', [len(symbols), 256], dtype=tf.float32,
       initializer=tf.truncated_normal_initializer(stddev=0.5))
     embedded_inputs = tf.nn.embedding_lookup(embedding_table, inputs)  # [N, T_in, 256]

View File

@@ -7,3 +7,4 @@ scipy==0.19.0
 tensorflow==1.2.0
 tensorflow-gpu==1.2.0
 tqdm==4.11.2
+Unidecode==0.4.20
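
The new Unidecode dependency is what backs convert_to_ascii and the transliteration_pipeline cleaner; a quick illustration (outputs as exercised in tests/text_test.py):

from unidecode import unidecode

print(unidecode('grüß gott'))     # 'gruss gott'
print(unidecode('Здравствуйте'))  # 'Zdravstvuite'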

View File

@@ -3,7 +3,8 @@ import numpy as np
 import tensorflow as tf
 from hparams import hparams
 from models import create_model
-from util import audio, textinput
+from text import text_to_sequence
+from util import audio

 class Synthesizer:
@@ -23,9 +24,8 @@ class Synthesizer:
   def synthesize(self, text):
-    seq = textinput.to_sequence(text,
-      force_lowercase=hparams.force_lowercase,
-      expand_abbreviations=hparams.expand_abbreviations)
+    cleaner_names = [x.strip() for x in hparams.cleaners.split(',')]
+    seq = text_to_sequence(text, cleaner_names)
     feed_dict = {
       self.model.inputs: [np.asarray(seq, dtype=np.int32)],
       self.model.input_lengths: np.asarray([len(seq)], dtype=np.int32)

View File

@@ -1,5 +1,5 @@
 import io
-from util import cmudict
+from text import cmudict

 test_data = '''
@@ -37,5 +37,5 @@ def test_cmudict():
 def test_cmudict_no_keep_ambiguous():
   c = cmudict.CMUDict(io.StringIO(test_data), keep_ambiguous=False)
   assert len(c) == 5
-  assert c.lookup('ADVERSITY') == ['AE0 D V ER1 S IH0 T IY2']
+  assert c.lookup('adversity') == ['AE0 D V ER1 S IH0 T IY2']
   assert c.lookup('adverse') == None

View File

@@ -1,51 +1,51 @@
-from util.numbers import normalize
+from text.numbers import normalize_numbers

 def test_normalize_numbers():
-  assert normalize('1') == 'one'
-  assert normalize('15') == 'fifteen'
-  assert normalize('24') == 'twenty-four'
-  assert normalize('100') == 'one hundred'
-  assert normalize('101') == 'one hundred one'
-  assert normalize('456') == 'four hundred fifty-six'
-  assert normalize('1000') == 'one thousand'
-  assert normalize('1800') == 'eighteen hundred'
-  assert normalize('2,000') == 'two thousand'
-  assert normalize('3000') == 'three thousand'
-  assert normalize('18000') == 'eighteen thousand'
-  assert normalize('24,000') == 'twenty-four thousand'
-  assert normalize('124,001') == 'one hundred twenty-four thousand one'
-  assert normalize('6.4 sec') == 'six point four sec'
+  assert normalize_numbers('1') == 'one'
+  assert normalize_numbers('15') == 'fifteen'
+  assert normalize_numbers('24') == 'twenty-four'
+  assert normalize_numbers('100') == 'one hundred'
+  assert normalize_numbers('101') == 'one hundred one'
+  assert normalize_numbers('456') == 'four hundred fifty-six'
+  assert normalize_numbers('1000') == 'one thousand'
+  assert normalize_numbers('1800') == 'eighteen hundred'
+  assert normalize_numbers('2,000') == 'two thousand'
+  assert normalize_numbers('3000') == 'three thousand'
+  assert normalize_numbers('18000') == 'eighteen thousand'
+  assert normalize_numbers('24,000') == 'twenty-four thousand'
+  assert normalize_numbers('124,001') == 'one hundred twenty-four thousand one'
+  assert normalize_numbers('6.4 sec') == 'six point four sec'

 def test_normalize_ordinals():
-  assert normalize('1st') == 'first'
-  assert normalize('2nd') == 'second'
-  assert normalize('9th') == 'ninth'
-  assert normalize('243rd place') == 'two hundred and forty-third place'
+  assert normalize_numbers('1st') == 'first'
+  assert normalize_numbers('2nd') == 'second'
+  assert normalize_numbers('9th') == 'ninth'
+  assert normalize_numbers('243rd place') == 'two hundred and forty-third place'

 def test_normalize_dates():
-  assert normalize('1400') == 'fourteen hundred'
-  assert normalize('1901') == 'nineteen oh one'
-  assert normalize('1999') == 'nineteen ninety-nine'
-  assert normalize('2000') == 'two thousand'
-  assert normalize('2004') == 'two thousand four'
-  assert normalize('2010') == 'twenty ten'
-  assert normalize('2012') == 'twenty twelve'
-  assert normalize('2025') == 'twenty twenty-five'
-  assert normalize('September 11, 2001') == 'September eleven, two thousand one'
-  assert normalize('July 26, 1984.') == 'July twenty-six, nineteen eighty-four.'
+  assert normalize_numbers('1400') == 'fourteen hundred'
+  assert normalize_numbers('1901') == 'nineteen oh one'
+  assert normalize_numbers('1999') == 'nineteen ninety-nine'
+  assert normalize_numbers('2000') == 'two thousand'
+  assert normalize_numbers('2004') == 'two thousand four'
+  assert normalize_numbers('2010') == 'twenty ten'
+  assert normalize_numbers('2012') == 'twenty twelve'
+  assert normalize_numbers('2025') == 'twenty twenty-five'
+  assert normalize_numbers('September 11, 2001') == 'September eleven, two thousand one'
+  assert normalize_numbers('July 26, 1984.') == 'July twenty-six, nineteen eighty-four.'

 def test_normalize_money():
-  assert normalize('$0.00') == 'zero dollars'
-  assert normalize('$1') == 'one dollar'
-  assert normalize('$10') == 'ten dollars'
-  assert normalize('$.01') == 'one cent'
-  assert normalize('$0.25') == 'twenty-five cents'
-  assert normalize('$5.00') == 'five dollars'
-  assert normalize('$5.01') == 'five dollars, one cent'
-  assert normalize('$135.99.') == 'one hundred thirty-five dollars, ninety-nine cents.'
-  assert normalize('$40,000') == 'forty thousand dollars'
-  assert normalize('for £2500!') == 'for twenty-five hundred pounds!'
+  assert normalize_numbers('$0.00') == 'zero dollars'
+  assert normalize_numbers('$1') == 'one dollar'
+  assert normalize_numbers('$10') == 'ten dollars'
+  assert normalize_numbers('$.01') == 'one cent'
+  assert normalize_numbers('$0.25') == 'twenty-five cents'
+  assert normalize_numbers('$5.00') == 'five dollars'
+  assert normalize_numbers('$5.01') == 'five dollars, one cent'
+  assert normalize_numbers('$135.99.') == 'one hundred thirty-five dollars, ninety-nine cents.'
+  assert normalize_numbers('$40,000') == 'forty thousand dollars'
+  assert normalize_numbers('for £2500!') == 'for twenty-five hundred pounds!'

tests/text_test.py Normal file (+60)
View File

@@ -0,0 +1,60 @@
from text import cleaners, symbols, text_to_sequence, sequence_to_text
from unidecode import unidecode
def test_symbols():
assert len(symbols) >= 3
assert symbols[0] == '_'
assert symbols[1] == '~'
def test_text_to_sequence():
assert text_to_sequence('', []) == [1]
assert text_to_sequence('Hi!', []) == [9, 36, 54, 1]
assert text_to_sequence('"A"_B', []) == [2, 3, 1]
assert text_to_sequence('A {AW1 S} B', []) == [2, 64, 83, 132, 64, 3, 1]
assert text_to_sequence('Hi', ['lowercase']) == [35, 36, 1]
assert text_to_sequence('A {AW1 S} B', ['english_pipeline']) == [28, 64, 83, 132, 64, 29, 1]
def test_sequence_to_text():
assert sequence_to_text([]) == ''
assert sequence_to_text([1]) == '~'
assert sequence_to_text([9, 36, 54, 1]) == 'Hi!~'
assert sequence_to_text([2, 64, 83, 132, 64, 3]) == 'A {AW1 S} B'
def test_collapse_whitespace():
assert cleaners.collapse_whitespace('') == ''
assert cleaners.collapse_whitespace(' ') == ' '
assert cleaners.collapse_whitespace('x') == 'x'
assert cleaners.collapse_whitespace(' x. y, \tz') == ' x. y, z'
def test_convert_to_ascii():
assert cleaners.convert_to_ascii("raison d'être") == "raison d'etre"
assert cleaners.convert_to_ascii('grüß gott') == 'gruss gott'
assert cleaners.convert_to_ascii('안녕') == 'annyeong'
assert cleaners.convert_to_ascii('Здравствуйте') == 'Zdravstvuite'
def test_lowercase():
assert cleaners.lowercase('Happy Birthday!') == 'happy birthday!'
assert cleaners.lowercase('CAFÉ') == 'café'
def test_expand_abbreviations():
assert cleaners.expand_abbreviations('mr. and mrs. smith') == 'mister and misess smith'
def test_expand_numbers():
assert cleaners.expand_numbers('3 apples and 44 pears') == 'three apples and forty-four pears'
assert cleaners.expand_numbers('$3.50 for gas.') == 'three dollars, fifty cents for gas.'
def test_pipelines():
text = 'Mr. Müller ate 2 Apples'
assert cleaners.english_pipeline(text) == 'mister muller ate two apples'
assert cleaners.transliteration_pipeline(text) == 'mr. muller ate 2 apples'
assert cleaners.basic_pipeline(text) == 'mr. müller ate 2 apples'

View File

@@ -1,59 +0,0 @@
from util.textinput import num_symbols, to_sequence, to_string
def text_num_symbols():
assert num_symbols() == 147
def test_to_sequence():
assert to_sequence('') == [1]
assert to_sequence('H', force_lowercase=False) == [9, 1]
assert to_sequence('H', force_lowercase=True) == [35, 1]
assert to_sequence('Hi.', force_lowercase=False) == [9, 36, 60, 1]
def test_whitespace_nomalization():
assert round_trip('') == '~'
assert round_trip(' ') == '~'
assert round_trip('x') == 'x~'
assert round_trip(' x ') == 'x~'
assert round_trip(' x. y,z ') == 'x. y,z~'
assert round_trip('X: Y') == 'X: Y~'
def test_valid_chars():
assert round_trip('x') == 'x~'
assert round_trip('Hello') == 'Hello~'
assert round_trip('3 apples and 44 bananas') == 'three apples and forty-four bananas~'
assert round_trip('$3.50 for gas.') == 'three dollars, fifty cents for gas.~'
assert round_trip('Hello, world!') == 'Hello, world!~'
assert round_trip("What (time-out)! He\'s going where?") == "What (time-out)! He\'s going where?~"
def test_invalid_chars():
assert round_trip('^') == ' ~'
assert round_trip('A~^B') == 'A B~'
assert round_trip('"Finally," she said, "it ended."') == 'Finally, she said, it ended.~'
def test_unicode():
assert round_trip('naïve café') == 'naive cafe~'
assert round_trip("raison d'être") == "raison d'etre~"
def test_arpabet():
assert to_sequence('{AE0 D}') == [70, 91, 1]
assert round_trip('{AE0 D V ER1 S}') == '{AE0 D V ER1 S}~'
assert round_trip('{AE0 D V ER1 S} circumstances') == '{AE0 D V ER1 S} circumstances~'
assert round_trip('In {AE0 D V ER1 S} circumstances') == 'In {AE0 D V ER1 S} circumstances~'
assert round_trip('{AE0 D V ER1 S} {AE0 D S}') == '{AE0 D V ER1 S} {AE0 D S}~'
assert round_trip('X {AE0 D} Y {AE0 D} Z') == 'X {AE0 D} Y {AE0 D} Z~'
def test_abbreviations():
assert round_trip('mr. rogers and dr. smith.') == 'mister rogers and doctor smith.~'
assert round_trip('hit it with a hammr.') == 'hit it with a hammr.~'
def round_trip(x):
return to_string(to_sequence(x, force_lowercase=False, expand_abbreviations=True))

text/__init__.py Normal file (+76)
View File

@@ -0,0 +1,76 @@
import re
from text import cleaners
from text.symbols import symbols
# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
# Regular expression matching text enclosed in curly braces:
_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
def text_to_sequence(text, cleaner_names):
'''Converts a string of text to a sequence of IDs for the symbols in the text.
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
Args:
text: string to convert to a sequence
cleaner_names: names of the cleaner functions to run the text through
Returns:
List of integers corresponding to the symbols in the text
'''
sequence = []
# Check for curly braces and treat their contents as ARPAbet:
while len(text):
m = _curly_re.match(text)
if not m:
sequence += _characters_to_sequence(_clean_text(text, cleaner_names))
break
sequence += _characters_to_sequence(_clean_text(m.group(1), cleaner_names))
sequence += _arpabet_to_sequence(m.group(2))
text = m.group(3)
# Append EOS token
sequence.append(_symbol_to_id['~'])
return sequence
def sequence_to_text(sequence):
'''Converts a sequence of IDs back to a string'''
result = ''
for symbol_id in sequence:
if symbol_id in _id_to_symbol:
s = _id_to_symbol[symbol_id]
# Enclose ARPAbet back in curly braces:
if len(s) > 1 and s[0] == '@':
s = '{%s}' % s[1:]
result += s
return result.replace('}{', ' ')
def _clean_text(text, cleaner_names):
for name in cleaner_names:
cleaner = getattr(cleaners, name, None)
if not cleaner:
raise Exception('Unknown cleaner: %s' % name)
text = cleaner(text)
return text
def _characters_to_sequence(chars):
return [_symbol_to_id[s] for s in chars if _should_keep_symbol(s)]
def _arpabet_to_sequence(text):
arpabet_symbols = ['@' + s for s in text.split()]
return [_symbol_to_id[s] for s in arpabet_symbols if _should_keep_symbol(s)]
def _should_keep_symbol(s):
return s in _symbol_to_id and s != '_' and s != '~'
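
For reference, a round trip through the new module (the exact IDs are taken from tests/text_test.py and depend on the symbol table in text/symbols.py):

from text import text_to_sequence, sequence_to_text

seq = text_to_sequence('A {AW1 S} B', ['english_pipeline'])
print(seq)                    # [28, 64, 83, 132, 64, 29, 1] -> 'a', ' ', @AW1, @S, ' ', 'b', EOS
print(sequence_to_text(seq))  # 'a {AW1 S} b~'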

text/cleaners.py Normal file (+88)
View File

@@ -0,0 +1,88 @@
'''
Cleaners are transformations that run over the input text at both training and eval time.
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
1. "english_pipeline" for English text
2. "transliteration_pipeline" for non-English text that can be transliterated to ASCII using
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
3. "basic_pipeline" if you do not want to transliterate (in this case, you should also update
the symbols in symbols.py to match your data).
'''
import re
from unidecode import unidecode
from .numbers import normalize_numbers
# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
]]
def expand_abbreviations(text):
for regex, replacement in _abbreviations:
text = re.sub(regex, replacement, text)
return text
def expand_numbers(text):
return normalize_numbers(text)
def lowercase(text):
return text.lower()
def collapse_whitespace(text):
return re.sub(_whitespace_re, ' ', text)
def convert_to_ascii(text):
return unidecode(text)
def basic_pipeline(text):
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
text = lowercase(text)
text = collapse_whitespace(text)
return text
def transliteration_pipeline(text):
'''Pipeline for non-English text that transliterates to ASCII.'''
text = convert_to_ascii(text)
text = lowercase(text)
text = collapse_whitespace(text)
return text
def english_pipeline(text):
'''Pipeline for English text, including number and abbreviation expansion.'''
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_numbers(text)
text = expand_abbreviations(text)
text = collapse_whitespace(text)
return text
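
Supporting another language is then a matter of composing these helpers into a new pipeline here and naming it in the cleaners hyperparameter. A hypothetical sketch for German (the abbreviation pairs are illustrative, not part of this commit):

# Hypothetical: expand a few common German abbreviations, then transliterate.
_german_abbreviations = [(re.compile('\\b%s\\.' % abbr, re.IGNORECASE), full) for abbr, full in [
  ('nr', 'nummer'),
  ('str', 'strasse'),
]]

def german_pipeline(text):
  for regex, replacement in _german_abbreviations:
    text = re.sub(regex, replacement, text)
  text = convert_to_ascii(text)  # Unidecode maps 'ü' -> 'u', 'ß' -> 'ss'
  text = lowercase(text)
  text = collapse_whitespace(text)
  return text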

View File

@@ -59,7 +59,7 @@ def _expand_number(m):
   return _inflect.number_to_words(num, andword='')

-def normalize(text):
+def normalize_numbers(text):
   text = re.sub(_comma_number_re, _remove_commas, text)
   text = re.sub(_pounds_re, r'\1 pounds', text)
   text = re.sub(_dollars_re, _expand_dollars, text)

text/symbols.py Normal file (+18)
View File

@@ -0,0 +1,18 @@
'''
Defines the set of symbols used in text input to the model.
The default works well for English. For non-English datasets, update _characters to be the set of
characters in the dataset. The "cleaners" hyperparameter should also be changed to be
"basic_pipeline" or a custom set of steps for the dataset (see cleaners.py for more info).
'''
from text import cmudict
_pad = '_'
_eos = '~'
_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? '
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
_arpabet = ['@' + s for s in cmudict.valid_symbols]
# Export all symbols:
symbols = [_pad, _eos] + list(_characters) + _arpabet
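
Per the docstring above, a non-English dataset only needs a different _characters string. For example, a hypothetical German setup that keeps umlauts (paired with the basic_pipeline cleaner, which skips transliteration):

_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÄÖÜäöüß!\'(),-.:;? '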

View File

@@ -11,7 +11,8 @@ import traceback
 from datasets.datafeeder import DataFeeder
 from hparams import hparams, hparams_debug_string
 from models import create_model
-from util import audio, infolog, plot, textinput, ValueWindow
+from text import sequence_to_text
+from util import audio, infolog, plot, ValueWindow

 log = infolog.log
@@ -114,7 +115,7 @@ def train(log_dir, args):
           audio.save_wav(waveform, os.path.join(log_dir, 'step-%d-audio.wav' % step))
           plot.plot_alignment(alignment, os.path.join(log_dir, 'step-%d-align.png' % step),
             info='%s, %s, %s, step=%d, loss=%.5f' % (args.model, commit, time_string(), step, loss))
-          log('Input: %s' % textinput.to_string(input_seq))
+          log('Input: %s' % sequence_to_text(input_seq))
     except Exception as e:
       log('Exiting due to exception: %s' % e, slack=True)

View File

@@ -1,105 +0,0 @@
import re
import unicodedata
from util import cmudict, numbers
# Input alphabet (63 symbols), plus ARPAbet (84 symbols):
_pad = '_'
_eos = '~'
_uppercase = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
_lowercase = 'abcdefghijklmnopqrstuvwxyz'
_punctuation = '!\'(),-.:;?'
_space = ' '
_valid_input_chars = _uppercase + _lowercase + _punctuation + _space
_trans_table = str.maketrans({chr(i): ' ' for i in range(256) if chr(i) not in _valid_input_chars})
_normal_symbols = _pad + _eos + _valid_input_chars
_num_normal_symbols = len(_normal_symbols)
_char_to_id = {c: i for i, c in enumerate(_normal_symbols)}
_id_to_char = {i: c for i, c in enumerate(_normal_symbols)}
_arpabet_to_id = {sym: i + _num_normal_symbols for i, sym in enumerate(cmudict.valid_symbols)}
_id_to_arpabet = {i + _num_normal_symbols: sym for i, sym in enumerate(cmudict.valid_symbols)}
_arpabet_re = re.compile(r'(.*?)\{([A-Z0-2 ]+?)\}(.*)')
_num_symbols = _num_normal_symbols + len(cmudict.valid_symbols)
_whitespace_re = re.compile(r'\s+')
def num_symbols():
'''Returns number of symbols in the alphabet.'''
return _num_symbols
def to_sequence(text, force_lowercase=True, expand_abbreviations=True):
'''Converts a string of text to a sequence of IDs for the symbols in the text'''
text = text.strip()
text = text.replace('"', '')
text = unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode()
sequence = []
while len(text):
m = _arpabet_re.match(text)
if not m:
sequence += _text_to_sequence(text, force_lowercase, expand_abbreviations)
break
sequence += _text_to_sequence(m.group(1), force_lowercase, expand_abbreviations)
sequence += _arpabet_to_sequence(m.group(2))
text = m.group(3)
sequence.append(_char_to_id[_eos])
return sequence
def to_string(sequence, remove_eos=False):
'''Returns the string for a sequence of characters.'''
s = ''
for sym in sequence:
if sym < _num_normal_symbols:
s += _id_to_char[sym]
elif sym < _num_symbols:
s += '{%s}' % _id_to_arpabet[sym]
s = s.replace('}{', ' ')
if remove_eos and s[-1] == _eos:
s = s[:-1]
return s
def _text_to_sequence(text, force_lowercase, expand_abbreviations):
text = numbers.normalize(text)
text = text.translate(_trans_table)
if force_lowercase:
text = text.lower()
if expand_abbreviations:
text = _expand_abbreviations(text)
text = re.sub(_whitespace_re, ' ', text)
return [_char_to_id[c] for c in text]
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
]]
def _expand_abbreviations(text):
for regex, replacement in _abbreviations:
text = re.sub(regex, replacement, text)
return text
def _arpabet_to_sequence(text):
return [_arpabet_to_id[s] for s in text.split() if s in _arpabet_to_id]