mirror of https://github.com/MycroftAI/mimic2.git
Simplify text processing, make it easier to adapt to non-English data.
parent 0894cd9bf2
commit 479976b6c5
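
In one before/after sketch, the caller-facing change this commit makes (module paths as in the diff below; the per-flag keyword arguments are replaced by a single list of cleaner names):

# Before: util/textinput.py exposed per-flag text options
seq = textinput.to_sequence(text, force_lowercase=True, expand_abbreviations=True)

# After: the new text package takes a swappable cleaner pipeline
seq = text_to_sequence(text, ['english_pipeline'])
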
datasets/datafeeder.py

@@ -5,7 +5,7 @@ import tensorflow as tf
 import threading
 import time
 import traceback
-from util import cmudict, textinput
+from text import cmudict, text_to_sequence
 from util.infolog import log


@@ -21,6 +21,7 @@ class DataFeeder(threading.Thread):
     super(DataFeeder, self).__init__()
     self._coord = coordinator
     self._hparams = hparams
+    self._cleaner_names = [x.strip() for x in hparams.cleaners.split(',')]
     self._offset = 0

     # Load metadata:
@@ -107,15 +108,15 @@ class DataFeeder(threading.Thread):
     if self._cmudict and random.random() < _p_cmudict:
       text = ' '.join([self._maybe_get_arpabet(word) for word in text.split(' ')])

-    input_data = np.asarray(textinput.to_sequence(text), dtype=np.int32)
+    input_data = np.asarray(text_to_sequence(text, self._cleaner_names), dtype=np.int32)
     linear_target = np.load(os.path.join(self._datadir, meta[0]))
     mel_target = np.load(os.path.join(self._datadir, meta[1]))
     return (input_data, mel_target, linear_target, len(linear_target))


   def _maybe_get_arpabet(self, word):
-    pron = self._cmudict.lookup(word)
-    return '{%s}' % pron[0] if pron is not None and random.random() < 0.5 else word
+    arpabet = self._cmudict.lookup(word)
+    return '{%s}' % arpabet[0] if arpabet is not None and random.random() < 0.5 else word


 def _prepare_batch(batch, outputs_per_step):

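With use_cmudict enabled, the feeder rewrites a random subset of words into ARPAbet before encoding, and the curly-brace convention carries that through text_to_sequence. A minimal sketch of one possible outcome (the sentence is illustrative; the pronunciation string follows the docstring example in text/__init__.py below):

from text import text_to_sequence

text = 'Turn left on Houston Street.'
# One possible result of applying _maybe_get_arpabet to each word:
mixed = 'Turn left on {HH AW1 S S T AH0 N} Street.'
input_data = text_to_sequence(mixed, ['english_pipeline'])  # braces become ARPAbet symbol IDs
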
hparams.py

@@ -3,10 +3,9 @@ import tensorflow as tf

 # Default hyperparameters:
 hparams = tf.contrib.training.HParams(
-  # Text:
-  force_lowercase=True,
-  expand_abbreviations=True,
-  use_cmudict=False,
+  # Comma-separated list of cleaners to run on text prior to training and eval. For non-English
+  # text, you may want to use "basic_pipeline" or "transliteration_pipeline". See text/cleaners.py.
+  cleaners='english_pipeline',

   # Audio:
   num_mels=80,
@@ -28,6 +27,7 @@ hparams = tf.contrib.training.HParams(
   adam_beta2=0.999,
   initial_learning_rate=0.002,
   decay_learning_rate=True,
+  use_cmudict=False,  # Use CMUDict during training to learn pronunciation of ARPAbet phonemes

   # Eval:
   max_iters=200,

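A sketch of how a non-English run would now be configured (the override string is hypothetical; 'transliteration_pipeline' is one of the pipelines defined in text/cleaners.py below):

from hparams import hparams

# Swap the whole cleaning pipeline instead of toggling individual text flags:
hparams.parse('cleaners=transliteration_pipeline')
cleaner_names = [x.strip() for x in hparams.cleaners.split(',')]
# -> ['transliteration_pipeline'], as consumed by DataFeeder and Synthesizer
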
models/tacotron.py

@@ -1,7 +1,7 @@
 import tensorflow as tf
 from tensorflow.contrib.rnn import GRUCell, MultiRNNCell, OutputProjectionWrapper, ResidualWrapper
 from tensorflow.contrib.seq2seq import BasicDecoder, BahdanauAttention, AttentionWrapper
-from util import textinput
+from text.symbols import symbols
 from util.infolog import log
 from .helpers import TacoTestHelper, TacoTrainingHelper
 from .modules import encoder_cbhg, post_cbhg, prenet
@@ -38,7 +38,7 @@ class Tacotron():

       # Embeddings
       embedding_table = tf.get_variable(
-        'embedding', [textinput.num_symbols(), 256], dtype=tf.float32,
+        'embedding', [len(symbols), 256], dtype=tf.float32,
         initializer=tf.truncated_normal_initializer(stddev=0.5))
       embedded_inputs = tf.nn.embedding_lookup(embedding_table, inputs)  # [N, T_in, 256]

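Since the embedding table is sized with len(symbols), the model's input vocabulary now tracks text/symbols.py directly; a sketch of that dependency (shapes only, not the full graph):

import tensorflow as tf
from text.symbols import symbols

# Adding characters for a new language in symbols.py grows dimension 0 here;
# no other model change is needed.
embedding_table = tf.get_variable('embedding', [len(symbols), 256], dtype=tf.float32)
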
requirements.txt

@@ -7,3 +7,4 @@ scipy==0.19.0
 tensorflow==1.2.0
 tensorflow-gpu==1.2.0
 tqdm==4.11.2
+Unidecode==0.4.20

synthesizer.py

@@ -3,7 +3,8 @@ import numpy as np
 import tensorflow as tf
 from hparams import hparams
 from models import create_model
-from util import audio, textinput
+from text import text_to_sequence
+from util import audio


 class Synthesizer:
@@ -23,9 +24,8 @@ class Synthesizer:


   def synthesize(self, text):
-    seq = textinput.to_sequence(text,
-      force_lowercase=hparams.force_lowercase,
-      expand_abbreviations=hparams.expand_abbreviations)
+    cleaner_names = [x.strip() for x in hparams.cleaners.split(',')]
+    seq = text_to_sequence(text, cleaner_names)
     feed_dict = {
       self.model.inputs: [np.asarray(seq, dtype=np.int32)],
       self.model.input_lengths: np.asarray([len(seq)], dtype=np.int32)

tests/cmudict_test.py

@@ -1,5 +1,5 @@
 import io
-from util import cmudict
+from text import cmudict


 test_data = '''
@@ -37,5 +37,5 @@ def test_cmudict():
 def test_cmudict_no_keep_ambiguous():
   c = cmudict.CMUDict(io.StringIO(test_data), keep_ambiguous=False)
   assert len(c) == 5
-  assert c.lookup('ADVERSITY') == ['AE0 D V ER1 S IH0 T IY2']
+  assert c.lookup('adversity') == ['AE0 D V ER1 S IH0 T IY2']
   assert c.lookup('adverse') == None

tests/numbers_test.py

@@ -1,51 +1,51 @@
-from util.numbers import normalize
+from text.numbers import normalize_numbers


 def test_normalize_numbers():
-  assert normalize('1') == 'one'
-  assert normalize('15') == 'fifteen'
-  assert normalize('24') == 'twenty-four'
-  assert normalize('100') == 'one hundred'
-  assert normalize('101') == 'one hundred one'
-  assert normalize('456') == 'four hundred fifty-six'
-  assert normalize('1000') == 'one thousand'
-  assert normalize('1800') == 'eighteen hundred'
-  assert normalize('2,000') == 'two thousand'
-  assert normalize('3000') == 'three thousand'
-  assert normalize('18000') == 'eighteen thousand'
-  assert normalize('24,000') == 'twenty-four thousand'
-  assert normalize('124,001') == 'one hundred twenty-four thousand one'
-  assert normalize('6.4 sec') == 'six point four sec'
+  assert normalize_numbers('1') == 'one'
+  assert normalize_numbers('15') == 'fifteen'
+  assert normalize_numbers('24') == 'twenty-four'
+  assert normalize_numbers('100') == 'one hundred'
+  assert normalize_numbers('101') == 'one hundred one'
+  assert normalize_numbers('456') == 'four hundred fifty-six'
+  assert normalize_numbers('1000') == 'one thousand'
+  assert normalize_numbers('1800') == 'eighteen hundred'
+  assert normalize_numbers('2,000') == 'two thousand'
+  assert normalize_numbers('3000') == 'three thousand'
+  assert normalize_numbers('18000') == 'eighteen thousand'
+  assert normalize_numbers('24,000') == 'twenty-four thousand'
+  assert normalize_numbers('124,001') == 'one hundred twenty-four thousand one'
+  assert normalize_numbers('6.4 sec') == 'six point four sec'


 def test_normalize_ordinals():
-  assert normalize('1st') == 'first'
-  assert normalize('2nd') == 'second'
-  assert normalize('9th') == 'ninth'
-  assert normalize('243rd place') == 'two hundred and forty-third place'
+  assert normalize_numbers('1st') == 'first'
+  assert normalize_numbers('2nd') == 'second'
+  assert normalize_numbers('9th') == 'ninth'
+  assert normalize_numbers('243rd place') == 'two hundred and forty-third place'


 def test_normalize_dates():
-  assert normalize('1400') == 'fourteen hundred'
-  assert normalize('1901') == 'nineteen oh one'
-  assert normalize('1999') == 'nineteen ninety-nine'
-  assert normalize('2000') == 'two thousand'
-  assert normalize('2004') == 'two thousand four'
-  assert normalize('2010') == 'twenty ten'
-  assert normalize('2012') == 'twenty twelve'
-  assert normalize('2025') == 'twenty twenty-five'
-  assert normalize('September 11, 2001') == 'September eleven, two thousand one'
-  assert normalize('July 26, 1984.') == 'July twenty-six, nineteen eighty-four.'
+  assert normalize_numbers('1400') == 'fourteen hundred'
+  assert normalize_numbers('1901') == 'nineteen oh one'
+  assert normalize_numbers('1999') == 'nineteen ninety-nine'
+  assert normalize_numbers('2000') == 'two thousand'
+  assert normalize_numbers('2004') == 'two thousand four'
+  assert normalize_numbers('2010') == 'twenty ten'
+  assert normalize_numbers('2012') == 'twenty twelve'
+  assert normalize_numbers('2025') == 'twenty twenty-five'
+  assert normalize_numbers('September 11, 2001') == 'September eleven, two thousand one'
+  assert normalize_numbers('July 26, 1984.') == 'July twenty-six, nineteen eighty-four.'


 def test_normalize_money():
-  assert normalize('$0.00') == 'zero dollars'
-  assert normalize('$1') == 'one dollar'
-  assert normalize('$10') == 'ten dollars'
-  assert normalize('$.01') == 'one cent'
-  assert normalize('$0.25') == 'twenty-five cents'
-  assert normalize('$5.00') == 'five dollars'
-  assert normalize('$5.01') == 'five dollars, one cent'
-  assert normalize('$135.99.') == 'one hundred thirty-five dollars, ninety-nine cents.'
-  assert normalize('$40,000') == 'forty thousand dollars'
-  assert normalize('for £2500!') == 'for twenty-five hundred pounds!'
+  assert normalize_numbers('$0.00') == 'zero dollars'
+  assert normalize_numbers('$1') == 'one dollar'
+  assert normalize_numbers('$10') == 'ten dollars'
+  assert normalize_numbers('$.01') == 'one cent'
+  assert normalize_numbers('$0.25') == 'twenty-five cents'
+  assert normalize_numbers('$5.00') == 'five dollars'
+  assert normalize_numbers('$5.01') == 'five dollars, one cent'
+  assert normalize_numbers('$135.99.') == 'one hundred thirty-five dollars, ninety-nine cents.'
+  assert normalize_numbers('$40,000') == 'forty thousand dollars'
+  assert normalize_numbers('for £2500!') == 'for twenty-five hundred pounds!'

tests/text_test.py

@@ -0,0 +1,60 @@
+from text import cleaners, symbols, text_to_sequence, sequence_to_text
+from unidecode import unidecode
+
+
+def test_symbols():
+  assert len(symbols) >= 3
+  assert symbols[0] == '_'
+  assert symbols[1] == '~'
+
+
+def test_text_to_sequence():
+  assert text_to_sequence('', []) == [1]
+  assert text_to_sequence('Hi!', []) == [9, 36, 54, 1]
+  assert text_to_sequence('"A"_B', []) == [2, 3, 1]
+  assert text_to_sequence('A {AW1 S} B', []) == [2, 64, 83, 132, 64, 3, 1]
+  assert text_to_sequence('Hi', ['lowercase']) == [35, 36, 1]
+  assert text_to_sequence('A {AW1 S} B', ['english_pipeline']) == [28, 64, 83, 132, 64, 29, 1]
+
+
+def test_sequence_to_text():
+  assert sequence_to_text([]) == ''
+  assert sequence_to_text([1]) == '~'
+  assert sequence_to_text([9, 36, 54, 1]) == 'Hi!~'
+  assert sequence_to_text([2, 64, 83, 132, 64, 3]) == 'A {AW1 S} B'
+
+
+def test_collapse_whitespace():
+  assert cleaners.collapse_whitespace('') == ''
+  assert cleaners.collapse_whitespace(' ') == ' '
+  assert cleaners.collapse_whitespace('x') == 'x'
+  assert cleaners.collapse_whitespace(' x. y, \tz') == ' x. y, z'
+
+
+def test_convert_to_ascii():
+  assert cleaners.convert_to_ascii("raison d'être") == "raison d'etre"
+  assert cleaners.convert_to_ascii('grüß gott') == 'gruss gott'
+  assert cleaners.convert_to_ascii('안녕') == 'annyeong'
+  assert cleaners.convert_to_ascii('Здравствуйте') == 'Zdravstvuite'
+
+
+def test_lowercase():
+  assert cleaners.lowercase('Happy Birthday!') == 'happy birthday!'
+  assert cleaners.lowercase('CAFÉ') == 'café'
+
+
+def test_expand_abbreviations():
+  assert cleaners.expand_abbreviations('mr. and mrs. smith') == 'mister and misess smith'
+
+
+def test_expand_numbers():
+  assert cleaners.expand_numbers('3 apples and 44 pears') == 'three apples and forty-four pears'
+  assert cleaners.expand_numbers('$3.50 for gas.') == 'three dollars, fifty cents for gas.'
+
+
+def test_pipelines():
+  text = 'Mr. Müller ate 2 Apples'
+  assert cleaners.english_pipeline(text) == 'mister muller ate two apples'
+  assert cleaners.transliteration_pipeline(text) == 'mr. muller ate 2 apples'
+  assert cleaners.basic_pipeline(text) == 'mr. müller ate 2 apples'

tests/textinput_test.py

@@ -1,59 +0,0 @@
-from util.textinput import num_symbols, to_sequence, to_string
-
-
-def text_num_symbols():
-  assert num_symbols() == 147
-
-
-def test_to_sequence():
-  assert to_sequence('') == [1]
-  assert to_sequence('H', force_lowercase=False) == [9, 1]
-  assert to_sequence('H', force_lowercase=True) == [35, 1]
-  assert to_sequence('Hi.', force_lowercase=False) == [9, 36, 60, 1]
-
-
-def test_whitespace_nomalization():
-  assert round_trip('') == '~'
-  assert round_trip(' ') == '~'
-  assert round_trip('x') == 'x~'
-  assert round_trip(' x ') == 'x~'
-  assert round_trip(' x. y,z ') == 'x. y,z~'
-  assert round_trip('X: Y') == 'X: Y~'
-
-
-def test_valid_chars():
-  assert round_trip('x') == 'x~'
-  assert round_trip('Hello') == 'Hello~'
-  assert round_trip('3 apples and 44 bananas') == 'three apples and forty-four bananas~'
-  assert round_trip('$3.50 for gas.') == 'three dollars, fifty cents for gas.~'
-  assert round_trip('Hello, world!') == 'Hello, world!~'
-  assert round_trip("What (time-out)! He\'s going where?") == "What (time-out)! He\'s going where?~"
-
-
-def test_invalid_chars():
-  assert round_trip('^') == ' ~'
-  assert round_trip('A~^B') == 'A B~'
-  assert round_trip('"Finally," she said, "it ended."') == 'Finally, she said, it ended.~'
-
-
-def test_unicode():
-  assert round_trip('naïve café') == 'naive cafe~'
-  assert round_trip("raison d'être") == "raison d'etre~"
-
-
-def test_arpabet():
-  assert to_sequence('{AE0 D}') == [70, 91, 1]
-  assert round_trip('{AE0 D V ER1 S}') == '{AE0 D V ER1 S}~'
-  assert round_trip('{AE0 D V ER1 S} circumstances') == '{AE0 D V ER1 S} circumstances~'
-  assert round_trip('In {AE0 D V ER1 S} circumstances') == 'In {AE0 D V ER1 S} circumstances~'
-  assert round_trip('{AE0 D V ER1 S} {AE0 D S}') == '{AE0 D V ER1 S} {AE0 D S}~'
-  assert round_trip('X {AE0 D} Y {AE0 D} Z') == 'X {AE0 D} Y {AE0 D} Z~'
-
-
-def test_abbreviations():
-  assert round_trip('mr. rogers and dr. smith.') == 'mister rogers and doctor smith.~'
-  assert round_trip('hit it with a hammr.') == 'hit it with a hammr.~'
-
-
-def round_trip(x):
-  return to_string(to_sequence(x, force_lowercase=False, expand_abbreviations=True))

text/__init__.py

@@ -0,0 +1,76 @@
+import re
+from text import cleaners
+from text.symbols import symbols
+
+
+# Mappings from symbol to numeric ID and vice versa:
+_symbol_to_id = {s: i for i, s in enumerate(symbols)}
+_id_to_symbol = {i: s for i, s in enumerate(symbols)}
+
+# Regular expression matching text enclosed in curly braces:
+_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
+
+
+def text_to_sequence(text, cleaner_names):
+  '''Converts a string of text to a sequence of IDs for the symbols in the text.
+
+  The text can optionally have ARPAbet sequences enclosed in curly braces embedded
+  in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
+
+  Args:
+    text: string to convert to a sequence
+    cleaner_names: names of the cleaner functions to run the text through
+
+  Returns:
+    List of integers corresponding to the symbols in the text
+  '''
+  sequence = []
+
+  # Check for curly braces and treat their contents as ARPAbet:
+  while len(text):
+    m = _curly_re.match(text)
+    if not m:
+      sequence += _characters_to_sequence(_clean_text(text, cleaner_names))
+      break
+    sequence += _characters_to_sequence(_clean_text(m.group(1), cleaner_names))
+    sequence += _arpabet_to_sequence(m.group(2))
+    text = m.group(3)
+
+  # Append EOS token
+  sequence.append(_symbol_to_id['~'])
+  return sequence
+
+
+def sequence_to_text(sequence):
+  '''Converts a sequence of IDs back to a string'''
+  result = ''
+  for symbol_id in sequence:
+    if symbol_id in _id_to_symbol:
+      s = _id_to_symbol[symbol_id]
+      # Enclose ARPAbet back in curly braces:
+      if len(s) > 1 and s[0] == '@':
+        s = '{%s}' % s[1:]
+      result += s
+  return result.replace('}{', ' ')
+
+
+def _clean_text(text, cleaner_names):
+  for name in cleaner_names:
+    cleaner = getattr(cleaners, name)
+    if not cleaner:
+      raise Exception('Unknown cleaner: %s' % name)
+    text = cleaner(text)
+  return text
+
+
+def _characters_to_sequence(chars):
+  return [_symbol_to_id[s] for s in chars if _should_keep_symbol(s)]
+
+
+def _arpabet_to_sequence(text):
+  arpabet_symbols = ['@' + s for s in text.split()]
+  return [_symbol_to_id[s] for s in arpabet_symbols if _should_keep_symbol(s)]
+
+
+def _should_keep_symbol(s):
+  return s in _symbol_to_id and s != '_' and s != '~'

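A quick round-trip under the new API, as a sketch (the lowercasing comes from english_pipeline, and the trailing '~' is the EOS symbol restored by sequence_to_text):

from text import text_to_sequence, sequence_to_text

seq = text_to_sequence('Turn left on {HH AW1 S S T AH0 N} Street.', ['english_pipeline'])
print(sequence_to_text(seq))  # turn left on {HH AW1 S S T AH0 N} street.~
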
text/cleaners.py

@@ -0,0 +1,88 @@
+'''
+Cleaners are transformations that run over the input text at both training and eval time.
+
+Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
+hyperparameter. Some cleaners are English-specific. You'll typically want to use:
+  1. "english_pipeline" for English text
+  2. "transliteration_pipeline" for non-English text that can be transliterated to ASCII using
+     the Unidecode library (https://pypi.python.org/pypi/Unidecode)
+  3. "basic_pipeline" if you do not want to transliterate (in this case, you should also update
+     the symbols in symbols.py to match your data).
+'''
+
+import re
+from unidecode import unidecode
+from .numbers import normalize_numbers
+
+
+# Regular expression matching whitespace:
+_whitespace_re = re.compile(r'\s+')
+
+# List of (regular expression, replacement) pairs for abbreviations:
+_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
+  ('mrs', 'misess'),
+  ('mr', 'mister'),
+  ('dr', 'doctor'),
+  ('st', 'saint'),
+  ('co', 'company'),
+  ('jr', 'junior'),
+  ('maj', 'major'),
+  ('gen', 'general'),
+  ('drs', 'doctors'),
+  ('rev', 'reverend'),
+  ('lt', 'lieutenant'),
+  ('hon', 'honorable'),
+  ('sgt', 'sergeant'),
+  ('capt', 'captain'),
+  ('esq', 'esquire'),
+  ('ltd', 'limited'),
+  ('col', 'colonel'),
+  ('ft', 'fort'),
+]]
+
+
+def expand_abbreviations(text):
+  for regex, replacement in _abbreviations:
+    text = re.sub(regex, replacement, text)
+  return text
+
+
+def expand_numbers(text):
+  return normalize_numbers(text)
+
+
+def lowercase(text):
+  return text.lower()
+
+
+def collapse_whitespace(text):
+  return re.sub(_whitespace_re, ' ', text)
+
+
+def convert_to_ascii(text):
+  return unidecode(text)
+
+
+def basic_pipeline(text):
+  '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
+  text = lowercase(text)
+  text = collapse_whitespace(text)
+  return text
+
+
+def transliteration_pipeline(text):
+  '''Pipeline for non-English text that transliterates to ASCII.'''
+  text = convert_to_ascii(text)
+  text = lowercase(text)
+  text = collapse_whitespace(text)
+  return text
+
+
+def english_pipeline(text):
+  '''Pipeline for English text, including number and abbreviation expansion.'''
+  text = convert_to_ascii(text)
+  text = lowercase(text)
+  text = expand_numbers(text)
+  text = expand_abbreviations(text)
+  text = collapse_whitespace(text)
+  return text

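Side by side, the three pipelines differ exactly as the new test in tests/text_test.py asserts:

from text import cleaners

text = 'Mr. Müller ate 2 Apples'
cleaners.english_pipeline(text)          # 'mister muller ate two apples'
cleaners.transliteration_pipeline(text)  # 'mr. muller ate 2 apples'
cleaners.basic_pipeline(text)            # 'mr. müller ate 2 apples'
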
text/numbers.py

@@ -59,7 +59,7 @@ def _expand_number(m):
   return _inflect.number_to_words(num, andword='')


-def normalize(text):
+def normalize_numbers(text):
   text = re.sub(_comma_number_re, _remove_commas, text)
   text = re.sub(_pounds_re, r'\1 pounds', text)
   text = re.sub(_dollars_re, _expand_dollars, text)

text/symbols.py

@@ -0,0 +1,18 @@
+'''
+Defines the set of symbols used in text input to the model.
+
+The default works well for English. For non-English datasets, update _characters to be the set of
+characters in the dataset. The "cleaners" hyperparameter should also be changed to be
+"basic_pipeline" or a custom set of steps for the dataset (see cleaners.py for more info).
+'''
+from text import cmudict
+
+_pad = '_'
+_eos = '~'
+_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? '
+
+# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
+_arpabet = ['@' + s for s in cmudict.valid_symbols]
+
+# Export all symbols:
+symbols = [_pad, _eos] + list(_characters) + _arpabet

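For a new language, only _characters (plus the "cleaners" hyperparameter) should need to change. A hypothetical Spanish-leaning variant that keeps the pad and EOS symbols but drops the ARPAbet block, since CMUDict covers only English:

_pad = '_'
_eos = '~'
_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzáéíóúñü¿¡!\'(),-.:;? '

# No CMUDict outside English, so no '@'-prefixed ARPAbet symbols:
symbols = [_pad, _eos] + list(_characters)
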
train.py

@@ -11,7 +11,8 @@ import traceback
 from datasets.datafeeder import DataFeeder
 from hparams import hparams, hparams_debug_string
 from models import create_model
-from util import audio, infolog, plot, textinput, ValueWindow
+from text import sequence_to_text
+from util import audio, infolog, plot, ValueWindow
 log = infolog.log


@@ -114,7 +115,7 @@ def train(log_dir, args):
           audio.save_wav(waveform, os.path.join(log_dir, 'step-%d-audio.wav' % step))
           plot.plot_alignment(alignment, os.path.join(log_dir, 'step-%d-align.png' % step),
             info='%s, %s, %s, step=%d, loss=%.5f' % (args.model, commit, time_string(), step, loss))
-          log('Input: %s' % textinput.to_string(input_seq))
+          log('Input: %s' % sequence_to_text(input_seq))

       except Exception as e:
         log('Exiting due to exception: %s' % e, slack=True)

util/textinput.py

@@ -1,105 +0,0 @@
-import re
-import unicodedata
-from util import cmudict, numbers
-
-
-# Input alphabet (63 symbols), plus ARPAbet (84 symbols):
-_pad = '_'
-_eos = '~'
-_uppercase = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
-_lowercase = 'abcdefghijklmnopqrstuvwxyz'
-_punctuation = '!\'(),-.:;?'
-_space = ' '
-
-_valid_input_chars = _uppercase + _lowercase + _punctuation + _space
-_trans_table = str.maketrans({chr(i): ' ' for i in range(256) if chr(i) not in _valid_input_chars})
-
-_normal_symbols = _pad + _eos + _valid_input_chars
-_num_normal_symbols = len(_normal_symbols)
-_char_to_id = {c: i for i, c in enumerate(_normal_symbols)}
-_id_to_char = {i: c for i, c in enumerate(_normal_symbols)}
-_arpabet_to_id = {sym: i + _num_normal_symbols for i, sym in enumerate(cmudict.valid_symbols)}
-_id_to_arpabet = {i + _num_normal_symbols: sym for i, sym in enumerate(cmudict.valid_symbols)}
-_arpabet_re = re.compile(r'(.*?)\{([A-Z0-2 ]+?)\}(.*)')
-_num_symbols = _num_normal_symbols + len(cmudict.valid_symbols)
-_whitespace_re = re.compile(r'\s+')
-
-
-def num_symbols():
-  '''Returns number of symbols in the alphabet.'''
-  return _num_symbols
-
-
-def to_sequence(text, force_lowercase=True, expand_abbreviations=True):
-  '''Converts a string of text to a sequence of IDs for the symbols in the text'''
-  text = text.strip()
-  text = text.replace('"', '')
-  text = unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode()
-
-  sequence = []
-  while len(text):
-    m = _arpabet_re.match(text)
-    if not m:
-      sequence += _text_to_sequence(text, force_lowercase, expand_abbreviations)
-      break
-    sequence += _text_to_sequence(m.group(1), force_lowercase, expand_abbreviations)
-    sequence += _arpabet_to_sequence(m.group(2))
-    text = m.group(3)
-  sequence.append(_char_to_id[_eos])
-  return sequence
-
-
-def to_string(sequence, remove_eos=False):
-  '''Returns the string for a sequence of characters.'''
-  s = ''
-  for sym in sequence:
-    if sym < _num_normal_symbols:
-      s += _id_to_char[sym]
-    elif sym < _num_symbols:
-      s += '{%s}' % _id_to_arpabet[sym]
-  s = s.replace('}{', ' ')
-  if remove_eos and s[-1] == _eos:
-    s = s[:-1]
-  return s
-
-
-def _text_to_sequence(text, force_lowercase, expand_abbreviations):
-  text = numbers.normalize(text)
-  text = text.translate(_trans_table)
-  if force_lowercase:
-    text = text.lower()
-  if expand_abbreviations:
-    text = _expand_abbreviations(text)
-  text = re.sub(_whitespace_re, ' ', text)
-  return [_char_to_id[c] for c in text]
-
-
-_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
-  ('mrs', 'misess'),
-  ('mr', 'mister'),
-  ('dr', 'doctor'),
-  ('st', 'saint'),
-  ('co', 'company'),
-  ('jr', 'junior'),
-  ('maj', 'major'),
-  ('gen', 'general'),
-  ('drs', 'doctors'),
-  ('rev', 'reverend'),
-  ('lt', 'lieutenant'),
-  ('hon', 'honorable'),
-  ('sgt', 'sergeant'),
-  ('capt', 'captain'),
-  ('esq', 'esquire'),
-  ('ltd', 'limited'),
-  ('col', 'colonel'),
-  ('ft', 'fort'),
-]]
-
-def _expand_abbreviations(text):
-  for regex, replacement in _abbreviations:
-    text = re.sub(regex, replacement, text)
-  return text
-
-
-def _arpabet_to_sequence(text):
-  return [_arpabet_to_id[s] for s in text.split() if s in _arpabet_to_id]