Enable `custom_symbols` in text processing

Models can define their own custom symbol lists with a custom `make_symbols()` method.
pull/718/head
Eren Gölge 2021-08-07 21:46:10 +00:00
parent bd4e29b4dd
commit 003e5579e8
5 changed files with 134 additions and 36 deletions
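The hook works as follows: `synthesis()` checks whether the model exposes a `make_symbols()` callable and, if so, feeds its output into the text-to-sequence functions as `custom_symbols`. Below is a minimal, hedged sketch of a model providing such a list; the class name, the `config` attributes, and the symbol construction are illustrative assumptions, not part of this commit.

# A minimal sketch, assuming a hypothetical model class; only the
# `make_symbols(config)` hook name and the `hasattr` lookup in `synthesis()`
# come from this commit.
class MyTTSModel:
    @staticmethod
    def make_symbols(config):
        # Build a model-specific symbol list from the config. The construction
        # below (pad + punctuation + configured characters) is just an example.
        pad = "_"
        punctuations = "!'(),-.:;? "
        characters = config.characters.characters  # assumed config layout
        return [pad] + list(punctuations) + list(characters)

# synthesis() then picks the list up via the check added in this diff:
#     if hasattr(model, "make_symbols"):
#         custom_symbols = model.make_symbols(CONFIG)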


@ -23,7 +23,9 @@ class TTSDataset(Dataset):
ap: AudioProcessor,
meta_data: List[List],
characters: Dict = None,
custom_symbols: List = None,
add_blank: bool = False,
return_wav: bool = False,
batch_group_size: int = 0,
min_seq_len: int = 0,
max_seq_len: int = float("inf"),
@ -54,9 +56,14 @@ class TTSDataset(Dataset):
characters (dict): `dict` of custom text characters used for converting texts to sequences.
custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using their own
set of symbols need to pass them here. Defaults to `None`.
add_blank (bool): Add a special `blank` token between characters. It helps some
models achieve better results. Defaults to False.
return_wav (bool): Return the waveform of the sample. Defaults to False.
batch_group_size (int): Range of batch randomization after sorting
sequences by length. It shuffles each batch with bucketing to gather similar-length sequences in a
batch. Set 0 to disable. Defaults to 0.
@ -95,10 +102,12 @@ class TTSDataset(Dataset):
self.sample_rate = ap.sample_rate
self.cleaners = text_cleaner
self.compute_linear_spec = compute_linear_spec
self.return_wav = return_wav
self.min_seq_len = min_seq_len
self.max_seq_len = max_seq_len
self.ap = ap
self.characters = characters
self.custom_symbols = custom_symbols
self.add_blank = add_blank
self.use_phonemes = use_phonemes
self.phoneme_cache_path = phoneme_cache_path
@ -109,6 +118,7 @@ class TTSDataset(Dataset):
self.use_noise_augment = use_noise_augment
self.verbose = verbose
self.input_seq_computed = False
self.rescue_item_idx = 1
if use_phonemes and not os.path.isdir(phoneme_cache_path):
os.makedirs(phoneme_cache_path, exist_ok=True)
if self.verbose:
@ -128,13 +138,21 @@ class TTSDataset(Dataset):
return data
@staticmethod
def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, characters, add_blank):
def _generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
):
"""generate a phoneme sequence from text.
since the usage is for subsequent caching, we never add bos and
eos chars here. Instead we add those dynamically later; based on the
config option."""
phonemes = phoneme_to_sequence(
text, [cleaners], language=language, enable_eos_bos=False, tp=characters, add_blank=add_blank
text,
[cleaners],
language=language,
enable_eos_bos=False,
custom_symbols=custom_symbols,
tp=characters,
add_blank=add_blank,
)
phonemes = np.asarray(phonemes, dtype=np.int32)
np.save(cache_path, phonemes)
@ -142,7 +160,7 @@ class TTSDataset(Dataset):
@staticmethod
def _load_or_generate_phoneme_sequence(
wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, characters, add_blank
wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, custom_symbols, characters, add_blank
):
file_name = os.path.splitext(os.path.basename(wav_file))[0]
@ -153,12 +171,12 @@ class TTSDataset(Dataset):
phonemes = np.load(cache_path)
except FileNotFoundError:
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, characters, add_blank
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
)
except (ValueError, IOError):
print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, characters, add_blank
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
)
if enable_eos_bos:
phonemes = pad_with_eos_bos(phonemes, tp=characters)
@ -189,13 +207,19 @@ class TTSDataset(Dataset):
self.enable_eos_bos,
self.cleaners,
self.phoneme_language,
self.custom_symbols,
self.characters,
self.add_blank,
)
else:
text = np.asarray(
text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
text_to_sequence(
text,
[self.cleaners],
custom_symbols=self.custom_symbols,
tp=self.characters,
add_blank=self.add_blank,
),
dtype=np.int32,
)
@ -209,7 +233,7 @@ class TTSDataset(Dataset):
# return a different sample if the phonemized
# text is longer than the threshold
# TODO: find a better fix
return self.load_data(100)
return self.load_data(self.rescue_item_idx)
sample = {
"text": text,
@ -238,7 +262,13 @@ class TTSDataset(Dataset):
for idx, item in enumerate(tqdm.tqdm(self.items)):
text, *_ = item
sequence = np.asarray(
text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
text_to_sequence(
text,
[self.cleaners],
custom_symbols=self.custom_symbols,
tp=self.characters,
add_blank=self.add_blank,
),
dtype=np.int32,
)
self.items[idx][0] = sequence
@ -249,6 +279,7 @@ class TTSDataset(Dataset):
self.enable_eos_bos,
self.cleaners,
self.phoneme_language,
self.custom_symbols,
self.characters,
self.add_blank,
]
@ -347,6 +378,14 @@ class TTSDataset(Dataset):
mel_lengths = [m.shape[1] for m in mel]
# lengths adjusted by the reduction factor
mel_lengths_adjusted = [
m.shape[1] + (self.outputs_per_step - (m.shape[1] % self.outputs_per_step))
if m.shape[1] % self.outputs_per_step
else m.shape[1]
for m in mel
]
# compute 'stop token' targets
stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths]
@ -385,6 +424,20 @@ class TTSDataset(Dataset):
else:
linear = None
# format waveforms
wav_padded = None
if self.return_wav:
wav_lengths = [w.shape[0] for w in wav]
max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
wav_lengths = torch.LongTensor(wav_lengths)
wav_padded = torch.zeros(len(batch), 1, max_wav_len)
for i, w in enumerate(wav):
mel_length = mel_lengths_adjusted[i]
w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
w = w[: mel_length * self.ap.hop_length]
wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
wav_padded.transpose_(1, 2)
# collate attention alignments
if batch[0]["attn"] is not None:
attns = [batch[idx]["attn"].T for idx in ids_sorted_decreasing]
@ -409,6 +462,7 @@ class TTSDataset(Dataset):
d_vectors,
speaker_ids,
attns,
wav_padded,
)
raise TypeError(

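For completeness, here is a hedged sketch of wiring the same symbol list into `TTSDataset` on the training side. The surrounding variables (`model`, `config`, `ap`, `items`) are assumed to exist and the argument values are placeholders, but the keyword names match this diff and the updated test further below.

# Hedged sketch; variable names and values are placeholders.
custom_symbols = None
if hasattr(model, "make_symbols"):          # same hook as used by synthesis()
    custom_symbols = model.make_symbols(config)

dataset = TTSDataset(
    config.r,                        # outputs_per_step (placeholder)
    config.text_cleaner,
    compute_linear_spec=False,
    return_wav=True,                 # new: collate_fn also returns padded waveforms
    ap=ap,
    meta_data=items,
    characters=config.characters,
    custom_symbols=custom_symbols,   # new in this commit
    add_blank=config.add_blank,
)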

@ -15,7 +15,7 @@ if "tensorflow" in installed or "tensorflow-gpu" in installed:
import tensorflow as tf
def text_to_seq(text, CONFIG):
def text_to_seq(text, CONFIG, custom_symbols=None):
text_cleaner = [CONFIG.text_cleaner]
# text or phonemes to sequence vector
if CONFIG.use_phonemes:
@ -28,16 +28,14 @@ def text_to_seq(text, CONFIG):
tp=CONFIG.characters,
add_blank=CONFIG.add_blank,
use_espeak_phonemes=CONFIG.use_espeak_phonemes,
custom_symbols=custom_symbols,
),
dtype=np.int32,
)
else:
seq = np.asarray(
text_to_sequence(
text,
text_cleaner,
tp=CONFIG.characters,
add_blank=CONFIG.add_blank,
text, text_cleaner, tp=CONFIG.characters, add_blank=CONFIG.add_blank, custom_symbols=custom_symbols
),
dtype=np.int32,
)
@ -229,13 +227,16 @@ def synthesis(
"""
# GST processing
style_mel = None
custom_symbols = None
if CONFIG.has("gst") and CONFIG.gst and style_wav is not None:
if isinstance(style_wav, dict):
style_mel = style_wav
else:
style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
if hasattr(model, "make_symbols"):
custom_symbols = model.make_symbols(CONFIG)
# preprocess the given text
text_inputs = text_to_seq(text, CONFIG)
text_inputs = text_to_seq(text, CONFIG, custom_symbols=custom_symbols)
# pass tensors to backend
if backend == "torch":
if speaker_id is not None:
@ -274,15 +275,18 @@ def synthesis(
# convert outputs to numpy
# plot results
wav = None
if use_griffin_lim:
wav = inv_spectrogram(model_outputs, ap, CONFIG)
# trim silence
if do_trim_silence:
wav = trim_silence(wav, ap)
if hasattr(model, "END2END") and model.END2END:
wav = model_outputs.squeeze(0)
else:
if use_griffin_lim:
wav = inv_spectrogram(model_outputs, ap, CONFIG)
# trim silence
if do_trim_silence:
wav = trim_silence(wav, ap)
return_dict = {
"wav": wav,
"alignments": alignments,
"model_outputs": model_outputs,
"text_inputs": text_inputs,
"outputs": outputs,
}
return return_dict
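Two things worth noting in the synthesis changes above: models flagged with `END2END` return the waveform directly from `model_outputs`, so Griffin-Lim inversion and silence trimming are skipped; and `text_to_seq()` now resolves symbol IDs against the model's custom symbols when they are available. A small, hedged usage sketch (the input text and variable names are placeholders):

# Hedged usage sketch; `model`, `CONFIG`, and the input text are placeholders.
custom_symbols = model.make_symbols(CONFIG) if hasattr(model, "make_symbols") else None
seq = text_to_seq("Be a voice, not an echo.", CONFIG, custom_symbols=custom_symbols)
# `seq` is an np.int32 array of symbol IDs, resolved against `custom_symbols`
# when given, otherwise against the default character/phoneme tables.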


@ -2,10 +2,9 @@
# adapted from https://github.com/keithito/tacotron
import re
import unicodedata
from typing import Dict, List
import gruut
from packaging import version
from TTS.tts.utils.text import cleaners
from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes
@ -22,6 +21,7 @@ _id_to_phonemes = {i: s for i, s in enumerate(phonemes)}
_symbols = symbols
_phonemes = phonemes
# Regular expression matching text enclosed in curly braces:
_CURLY_RE = re.compile(r"(.*?)\{(.+?)\}(.*)")
@ -81,7 +81,7 @@ def text2phone(text, language, use_espeak_phonemes=False):
# Fix a few phonemes
ph = ph.translate(GRUUT_TRANS_TABLE)
print(" > Phonemes: {}".format(ph))
# print(" > Phonemes: {}".format(ph))
return ph
raise ValueError(f" [!] Language {language} is not supported for phonemization.")
@ -106,13 +106,38 @@ def pad_with_eos_bos(phoneme_sequence, tp=None):
def phoneme_to_sequence(
text, cleaner_names, language, enable_eos_bos=False, tp=None, add_blank=False, use_espeak_phonemes=False
):
text: str,
cleaner_names: List[str],
language: str,
enable_eos_bos: bool = False,
custom_symbols: List[str] = None,
tp: Dict = None,
add_blank: bool = False,
use_espeak_phonemes: bool = False,
) -> List[int]:
"""Converts a string of phonemes to a sequence of IDs.
Args:
text (str): string to convert to a sequence
cleaner_names (List[str]): names of the cleaner functions to run the text through
language (str): text language key for phonemization.
enable_eos_bos (bool): whether to append the end-of-sentence and beginning-of-sentence tokens.
custom_symbols (List[str]): custom symbol set to be used instead of the default phoneme set. Defaults to None.
tp (Dict): dictionary of character parameters to use a custom character set.
add_blank (bool): option to add a blank token between each token.
use_espeak_phonemes (bool): use espeak-based lexicons to convert phonemes to sequences.
Returns:
List[int]: List of integers corresponding to the symbols in the text
"""
# pylint: disable=global-statement
global _phonemes_to_id, _phonemes
if tp:
_, _phonemes = make_symbols(**tp)
_phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
elif custom_symbols is not None:
_phonemes = custom_symbols
_phonemes_to_id = {s: i for i, s in enumerate(custom_symbols)}
sequence = []
clean_text = _clean_text(text, cleaner_names)
@ -127,7 +152,6 @@ def phoneme_to_sequence(
sequence = pad_with_eos_bos(sequence, tp=tp)
if add_blank:
sequence = intersperse(sequence, len(_phonemes)) # add a blank token (new), whose id number is len(_phonemes)
return sequence
@ -149,27 +173,31 @@ def sequence_to_phoneme(sequence, tp=None, add_blank=False):
return result.replace("}{", " ")
def text_to_sequence(text, cleaner_names, tp=None, add_blank=False):
def text_to_sequence(
text: str, cleaner_names: List[str], custom_symbols: List[str] = None, tp: Dict = None, add_blank: bool = False
) -> List[int]:
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
Args:
text: string to convert to a sequence
cleaner_names: names of the cleaner functions to run the text through
tp: dictionary of character parameters to use a custom character set.
text (str): string to convert to a sequence
cleaner_names (List[str]): names of the cleaner functions to run the text through
custom_symbols (List[str]): custom symbol set to be used instead of the default character set. Defaults to None.
tp (Dict): dictionary of character parameters to use a custom character set.
add_blank (bool): option to add a blank token between each token.
Returns:
List of integers corresponding to the symbols in the text
List[int]: List of integers corresponding to the symbols in the text
"""
# pylint: disable=global-statement
global _symbol_to_id, _symbols
if tp:
_symbols, _ = make_symbols(**tp)
_symbol_to_id = {s: i for i, s in enumerate(_symbols)}
elif custom_symbols is not None:
_symbols = custom_symbols
_symbol_to_id = {s: i for i, s in enumerate(custom_symbols)}
sequence = []
# Check for curly braces and treat their contents as ARPAbet:
while text:
m = _CURLY_RE.match(text)

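To make the new `custom_symbols` argument concrete, here is a small, hedged example against the updated `text_to_sequence()`; the symbol list is made up, and `basic_cleaners` is just one of the available cleaner names.

# Illustrative only; the symbol list below is made up.
from TTS.tts.utils.text import text_to_sequence

my_symbols = ["_", " ", "!", ",", ".", "?"] + list("abcdefghijklmnopqrstuvwxyz")
seq = text_to_sequence(
    "hello world",
    ["basic_cleaners"],
    custom_symbols=my_symbols,   # replaces the module-level symbol table
    add_blank=False,
)
# The returned IDs index into `my_symbols`. Note that `tp` takes precedence:
# if both are given, the table is rebuilt via make_symbols(**tp) instead.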

@ -42,6 +42,7 @@ class TestTTSDataset(unittest.TestCase):
r,
c.text_cleaner,
compute_linear_spec=True,
return_wav=True,
ap=self.ap,
meta_data=items,
characters=c.characters,
@ -75,16 +76,26 @@ class TestTTSDataset(unittest.TestCase):
mel_lengths = data[5]
stop_target = data[6]
item_idx = data[7]
wavs = data[11]
neg_values = text_input[text_input < 0]
check_count = len(neg_values)
assert check_count == 0, " !! Negative values in text_input: {}".format(check_count)
# TODO: more assertions here
assert isinstance(speaker_name[0], str)
assert linear_input.shape[0] == c.batch_size
assert linear_input.shape[2] == self.ap.fft_size // 2 + 1
assert mel_input.shape[0] == c.batch_size
assert mel_input.shape[2] == c.audio["num_mels"]
assert (
wavs.shape[1] == mel_input.shape[1] * c.audio.hop_length
), f"wavs.shape: {wavs.shape[1]}, mel_input.shape: {mel_input.shape[1] * c.audio.hop_length}"
# make sure that the computed mels and the waveform match and are correctly computed
mel_new = self.ap.melspectrogram(wavs[0].squeeze().numpy())
ignore_seg = -(1 + c.audio.win_length // c.audio.hop_length)
mel_diff = (mel_new[:, : mel_input.shape[1]] - mel_input[0].T.numpy())[:, 0:ignore_seg]
assert abs(mel_diff.sum()) < 1e-5
# check normalization ranges
if self.ap.symmetric_norm:
assert mel_input.max() <= self.ap.max_norm


@ -27,6 +27,7 @@ config = AlignTTSConfig(
"Be a voice, not an echo.",
],
)
config.audio.do_trim_silence = True
config.audio.trim_db = 60
config.save_json(config_path)