mirror of https://github.com/coqui-ai/TTS.git
Enable `custom_symbols` in text processing
Models can define their own custom symbols lists with custom `make_symbols()`pull/718/head
parent
bd4e29b4dd
commit
003e5579e8
|
@ -23,7 +23,9 @@ class TTSDataset(Dataset):
|
|||
ap: AudioProcessor,
|
||||
meta_data: List[List],
|
||||
characters: Dict = None,
|
||||
custom_symbols: List = None,
|
||||
add_blank: bool = False,
|
||||
return_wav: bool = False,
|
||||
batch_group_size: int = 0,
|
||||
min_seq_len: int = 0,
|
||||
max_seq_len: int = float("inf"),
|
||||
|
@ -54,9 +56,14 @@ class TTSDataset(Dataset):
|
|||
|
||||
characters (dict): `dict` of custom text characters used for converting texts to sequences.
|
||||
|
||||
custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own
|
||||
set of symbols need to pass it here. Defaults to `None`.
|
||||
|
||||
add_blank (bool): Add a special `blank` character after every other character. It helps some
|
||||
models achieve better results. Defaults to false.
|
||||
|
||||
return_wav (bool): Return the waveform of the sample. Defaults to False.
|
||||
|
||||
batch_group_size (int): Range of batch randomization after sorting
|
||||
sequences by length. It shuffles each batch with bucketing to gather similar lenght sequences in a
|
||||
batch. Set 0 to disable. Defaults to 0.
|
||||
|
@ -95,10 +102,12 @@ class TTSDataset(Dataset):
|
|||
self.sample_rate = ap.sample_rate
|
||||
self.cleaners = text_cleaner
|
||||
self.compute_linear_spec = compute_linear_spec
|
||||
self.return_wav = return_wav
|
||||
self.min_seq_len = min_seq_len
|
||||
self.max_seq_len = max_seq_len
|
||||
self.ap = ap
|
||||
self.characters = characters
|
||||
self.custom_symbols = custom_symbols
|
||||
self.add_blank = add_blank
|
||||
self.use_phonemes = use_phonemes
|
||||
self.phoneme_cache_path = phoneme_cache_path
|
||||
|
@ -109,6 +118,7 @@ class TTSDataset(Dataset):
|
|||
self.use_noise_augment = use_noise_augment
|
||||
self.verbose = verbose
|
||||
self.input_seq_computed = False
|
||||
self.rescue_item_idx = 1
|
||||
if use_phonemes and not os.path.isdir(phoneme_cache_path):
|
||||
os.makedirs(phoneme_cache_path, exist_ok=True)
|
||||
if self.verbose:
|
||||
|
@ -128,13 +138,21 @@ class TTSDataset(Dataset):
|
|||
return data
|
||||
|
||||
@staticmethod
|
||||
def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, characters, add_blank):
|
||||
def _generate_and_cache_phoneme_sequence(
|
||||
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
|
||||
):
|
||||
"""generate a phoneme sequence from text.
|
||||
since the usage is for subsequent caching, we never add bos and
|
||||
eos chars here. Instead we add those dynamically later; based on the
|
||||
config option."""
|
||||
phonemes = phoneme_to_sequence(
|
||||
text, [cleaners], language=language, enable_eos_bos=False, tp=characters, add_blank=add_blank
|
||||
text,
|
||||
[cleaners],
|
||||
language=language,
|
||||
enable_eos_bos=False,
|
||||
custom_symbols=custom_symbols,
|
||||
tp=characters,
|
||||
add_blank=add_blank,
|
||||
)
|
||||
phonemes = np.asarray(phonemes, dtype=np.int32)
|
||||
np.save(cache_path, phonemes)
|
||||
|
@ -142,7 +160,7 @@ class TTSDataset(Dataset):
|
|||
|
||||
@staticmethod
|
||||
def _load_or_generate_phoneme_sequence(
|
||||
wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, characters, add_blank
|
||||
wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, custom_symbols, characters, add_blank
|
||||
):
|
||||
file_name = os.path.splitext(os.path.basename(wav_file))[0]
|
||||
|
||||
|
@ -153,12 +171,12 @@ class TTSDataset(Dataset):
|
|||
phonemes = np.load(cache_path)
|
||||
except FileNotFoundError:
|
||||
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
|
||||
text, cache_path, cleaners, language, characters, add_blank
|
||||
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
|
||||
)
|
||||
except (ValueError, IOError):
|
||||
print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
|
||||
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
|
||||
text, cache_path, cleaners, language, characters, add_blank
|
||||
text, cache_path, cleaners, language, custom_symbols, characters, add_blank
|
||||
)
|
||||
if enable_eos_bos:
|
||||
phonemes = pad_with_eos_bos(phonemes, tp=characters)
|
||||
|
@ -189,13 +207,19 @@ class TTSDataset(Dataset):
|
|||
self.enable_eos_bos,
|
||||
self.cleaners,
|
||||
self.phoneme_language,
|
||||
self.custom_symbols,
|
||||
self.characters,
|
||||
self.add_blank,
|
||||
)
|
||||
|
||||
else:
|
||||
text = np.asarray(
|
||||
text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
|
||||
text_to_sequence(
|
||||
text,
|
||||
[self.cleaners],
|
||||
custom_symbols=self.custom_symbols,
|
||||
tp=self.characters,
|
||||
add_blank=self.add_blank,
|
||||
),
|
||||
dtype=np.int32,
|
||||
)
|
||||
|
||||
|
@ -209,7 +233,7 @@ class TTSDataset(Dataset):
|
|||
# return a different sample if the phonemized
|
||||
# text is longer than the threshold
|
||||
# TODO: find a better fix
|
||||
return self.load_data(100)
|
||||
return self.load_data(self.rescue_item_idx)
|
||||
|
||||
sample = {
|
||||
"text": text,
|
||||
|
@ -238,7 +262,13 @@ class TTSDataset(Dataset):
|
|||
for idx, item in enumerate(tqdm.tqdm(self.items)):
|
||||
text, *_ = item
|
||||
sequence = np.asarray(
|
||||
text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
|
||||
text_to_sequence(
|
||||
text,
|
||||
[self.cleaners],
|
||||
custom_symbols=self.custom_symbols,
|
||||
tp=self.characters,
|
||||
add_blank=self.add_blank,
|
||||
),
|
||||
dtype=np.int32,
|
||||
)
|
||||
self.items[idx][0] = sequence
|
||||
|
@ -249,6 +279,7 @@ class TTSDataset(Dataset):
|
|||
self.enable_eos_bos,
|
||||
self.cleaners,
|
||||
self.phoneme_language,
|
||||
self.custom_symbols,
|
||||
self.characters,
|
||||
self.add_blank,
|
||||
]
|
||||
|
@ -347,6 +378,14 @@ class TTSDataset(Dataset):
|
|||
|
||||
mel_lengths = [m.shape[1] for m in mel]
|
||||
|
||||
# lengths adjusted by the reduction factor
|
||||
mel_lengths_adjusted = [
|
||||
m.shape[1] + (self.outputs_per_step - (m.shape[1] % self.outputs_per_step))
|
||||
if m.shape[1] % self.outputs_per_step
|
||||
else m.shape[1]
|
||||
for m in mel
|
||||
]
|
||||
|
||||
# compute 'stop token' targets
|
||||
stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths]
|
||||
|
||||
|
@ -385,6 +424,20 @@ class TTSDataset(Dataset):
|
|||
else:
|
||||
linear = None
|
||||
|
||||
# format waveforms
|
||||
wav_padded = None
|
||||
if self.return_wav:
|
||||
wav_lengths = [w.shape[0] for w in wav]
|
||||
max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
|
||||
wav_lengths = torch.LongTensor(wav_lengths)
|
||||
wav_padded = torch.zeros(len(batch), 1, max_wav_len)
|
||||
for i, w in enumerate(wav):
|
||||
mel_length = mel_lengths_adjusted[i]
|
||||
w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
|
||||
w = w[: mel_length * self.ap.hop_length]
|
||||
wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
|
||||
wav_padded.transpose_(1, 2)
|
||||
|
||||
# collate attention alignments
|
||||
if batch[0]["attn"] is not None:
|
||||
attns = [batch[idx]["attn"].T for idx in ids_sorted_decreasing]
|
||||
|
@ -409,6 +462,7 @@ class TTSDataset(Dataset):
|
|||
d_vectors,
|
||||
speaker_ids,
|
||||
attns,
|
||||
wav_padded,
|
||||
)
|
||||
|
||||
raise TypeError(
|
||||
|
|
|
@ -15,7 +15,7 @@ if "tensorflow" in installed or "tensorflow-gpu" in installed:
|
|||
import tensorflow as tf
|
||||
|
||||
|
||||
def text_to_seq(text, CONFIG):
|
||||
def text_to_seq(text, CONFIG, custom_symbols=None):
|
||||
text_cleaner = [CONFIG.text_cleaner]
|
||||
# text ot phonemes to sequence vector
|
||||
if CONFIG.use_phonemes:
|
||||
|
@ -28,16 +28,14 @@ def text_to_seq(text, CONFIG):
|
|||
tp=CONFIG.characters,
|
||||
add_blank=CONFIG.add_blank,
|
||||
use_espeak_phonemes=CONFIG.use_espeak_phonemes,
|
||||
custom_symbols=custom_symbols,
|
||||
),
|
||||
dtype=np.int32,
|
||||
)
|
||||
else:
|
||||
seq = np.asarray(
|
||||
text_to_sequence(
|
||||
text,
|
||||
text_cleaner,
|
||||
tp=CONFIG.characters,
|
||||
add_blank=CONFIG.add_blank,
|
||||
text, text_cleaner, tp=CONFIG.characters, add_blank=CONFIG.add_blank, custom_symbols=custom_symbols
|
||||
),
|
||||
dtype=np.int32,
|
||||
)
|
||||
|
@ -229,13 +227,16 @@ def synthesis(
|
|||
"""
|
||||
# GST processing
|
||||
style_mel = None
|
||||
custom_symbols = None
|
||||
if CONFIG.has("gst") and CONFIG.gst and style_wav is not None:
|
||||
if isinstance(style_wav, dict):
|
||||
style_mel = style_wav
|
||||
else:
|
||||
style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
|
||||
if hasattr(model, "make_symbols"):
|
||||
custom_symbols = model.make_symbols(CONFIG)
|
||||
# preprocess the given text
|
||||
text_inputs = text_to_seq(text, CONFIG)
|
||||
text_inputs = text_to_seq(text, CONFIG, custom_symbols=custom_symbols)
|
||||
# pass tensors to backend
|
||||
if backend == "torch":
|
||||
if speaker_id is not None:
|
||||
|
@ -274,15 +275,18 @@ def synthesis(
|
|||
# convert outputs to numpy
|
||||
# plot results
|
||||
wav = None
|
||||
if use_griffin_lim:
|
||||
wav = inv_spectrogram(model_outputs, ap, CONFIG)
|
||||
# trim silence
|
||||
if do_trim_silence:
|
||||
wav = trim_silence(wav, ap)
|
||||
if hasattr(model, "END2END") and model.END2END:
|
||||
wav = model_outputs.squeeze(0)
|
||||
else:
|
||||
if use_griffin_lim:
|
||||
wav = inv_spectrogram(model_outputs, ap, CONFIG)
|
||||
# trim silence
|
||||
if do_trim_silence:
|
||||
wav = trim_silence(wav, ap)
|
||||
return_dict = {
|
||||
"wav": wav,
|
||||
"alignments": alignments,
|
||||
"model_outputs": model_outputs,
|
||||
"text_inputs": text_inputs,
|
||||
"outputs": outputs,
|
||||
}
|
||||
return return_dict
|
||||
|
|
|
@ -2,10 +2,9 @@
|
|||
# adapted from https://github.com/keithito/tacotron
|
||||
|
||||
import re
|
||||
import unicodedata
|
||||
from typing import Dict, List
|
||||
|
||||
import gruut
|
||||
from packaging import version
|
||||
|
||||
from TTS.tts.utils.text import cleaners
|
||||
from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes
|
||||
|
@ -22,6 +21,7 @@ _id_to_phonemes = {i: s for i, s in enumerate(phonemes)}
|
|||
|
||||
_symbols = symbols
|
||||
_phonemes = phonemes
|
||||
|
||||
# Regular expression matching text enclosed in curly braces:
|
||||
_CURLY_RE = re.compile(r"(.*?)\{(.+?)\}(.*)")
|
||||
|
||||
|
@ -81,7 +81,7 @@ def text2phone(text, language, use_espeak_phonemes=False):
|
|||
# Fix a few phonemes
|
||||
ph = ph.translate(GRUUT_TRANS_TABLE)
|
||||
|
||||
print(" > Phonemes: {}".format(ph))
|
||||
# print(" > Phonemes: {}".format(ph))
|
||||
return ph
|
||||
|
||||
raise ValueError(f" [!] Language {language} is not supported for phonemization.")
|
||||
|
@ -106,13 +106,38 @@ def pad_with_eos_bos(phoneme_sequence, tp=None):
|
|||
|
||||
|
||||
def phoneme_to_sequence(
|
||||
text, cleaner_names, language, enable_eos_bos=False, tp=None, add_blank=False, use_espeak_phonemes=False
|
||||
):
|
||||
text: str,
|
||||
cleaner_names: List[str],
|
||||
language: str,
|
||||
enable_eos_bos: bool = False,
|
||||
custom_symbols: List[str] = None,
|
||||
tp: Dict = None,
|
||||
add_blank: bool = False,
|
||||
use_espeak_phonemes: bool = False,
|
||||
) -> List[int]:
|
||||
"""Converts a string of phonemes to a sequence of IDs.
|
||||
|
||||
Args:
|
||||
text (str): string to convert to a sequence
|
||||
cleaner_names (List[str]): names of the cleaner functions to run the text through
|
||||
language (str): text language key for phonemization.
|
||||
enable_eos_bos (bool): whether to append the end-of-sentence and beginning-of-sentence tokens.
|
||||
tp (Dict): dictionary of character parameters to use a custom character set.
|
||||
add_blank (bool): option to add a blank token between each token.
|
||||
use_espeak_phonemes (bool): use espeak based lexicons to convert phonemes to sequenc
|
||||
|
||||
Returns:
|
||||
List[int]: List of integers corresponding to the symbols in the text
|
||||
"""
|
||||
# pylint: disable=global-statement
|
||||
global _phonemes_to_id, _phonemes
|
||||
|
||||
if tp:
|
||||
_, _phonemes = make_symbols(**tp)
|
||||
_phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
|
||||
elif custom_symbols is not None:
|
||||
_phonemes = custom_symbols
|
||||
_phonemes_to_id = {s: i for i, s in enumerate(custom_symbols)}
|
||||
|
||||
sequence = []
|
||||
clean_text = _clean_text(text, cleaner_names)
|
||||
|
@ -127,7 +152,6 @@ def phoneme_to_sequence(
|
|||
sequence = pad_with_eos_bos(sequence, tp=tp)
|
||||
if add_blank:
|
||||
sequence = intersperse(sequence, len(_phonemes)) # add a blank token (new), whose id number is len(_phonemes)
|
||||
|
||||
return sequence
|
||||
|
||||
|
||||
|
@ -149,27 +173,31 @@ def sequence_to_phoneme(sequence, tp=None, add_blank=False):
|
|||
return result.replace("}{", " ")
|
||||
|
||||
|
||||
def text_to_sequence(text, cleaner_names, tp=None, add_blank=False):
|
||||
def text_to_sequence(
|
||||
text: str, cleaner_names: List[str], custom_symbols: List[str] = None, tp: Dict = None, add_blank: bool = False
|
||||
) -> List[int]:
|
||||
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
||||
|
||||
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
|
||||
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
|
||||
|
||||
Args:
|
||||
text: string to convert to a sequence
|
||||
cleaner_names: names of the cleaner functions to run the text through
|
||||
tp: dictionary of character parameters to use a custom character set.
|
||||
text (str): string to convert to a sequence
|
||||
cleaner_names (List[str]): names of the cleaner functions to run the text through
|
||||
tp (Dict): dictionary of character parameters to use a custom character set.
|
||||
add_blank (bool): option to add a blank token between each token.
|
||||
|
||||
Returns:
|
||||
List of integers corresponding to the symbols in the text
|
||||
List[int]: List of integers corresponding to the symbols in the text
|
||||
"""
|
||||
# pylint: disable=global-statement
|
||||
global _symbol_to_id, _symbols
|
||||
if tp:
|
||||
_symbols, _ = make_symbols(**tp)
|
||||
_symbol_to_id = {s: i for i, s in enumerate(_symbols)}
|
||||
elif custom_symbols is not None:
|
||||
_symbols = custom_symbols
|
||||
_symbol_to_id = {s: i for i, s in enumerate(custom_symbols)}
|
||||
|
||||
sequence = []
|
||||
|
||||
# Check for curly braces and treat their contents as ARPAbet:
|
||||
while text:
|
||||
m = _CURLY_RE.match(text)
|
||||
|
|
|
@ -42,6 +42,7 @@ class TestTTSDataset(unittest.TestCase):
|
|||
r,
|
||||
c.text_cleaner,
|
||||
compute_linear_spec=True,
|
||||
return_wav=True,
|
||||
ap=self.ap,
|
||||
meta_data=items,
|
||||
characters=c.characters,
|
||||
|
@ -75,16 +76,26 @@ class TestTTSDataset(unittest.TestCase):
|
|||
mel_lengths = data[5]
|
||||
stop_target = data[6]
|
||||
item_idx = data[7]
|
||||
wavs = data[11]
|
||||
|
||||
neg_values = text_input[text_input < 0]
|
||||
check_count = len(neg_values)
|
||||
assert check_count == 0, " !! Negative values in text_input: {}".format(check_count)
|
||||
# TODO: more assertion here
|
||||
assert isinstance(speaker_name[0], str)
|
||||
assert linear_input.shape[0] == c.batch_size
|
||||
assert linear_input.shape[2] == self.ap.fft_size // 2 + 1
|
||||
assert mel_input.shape[0] == c.batch_size
|
||||
assert mel_input.shape[2] == c.audio["num_mels"]
|
||||
assert (
|
||||
wavs.shape[1] == mel_input.shape[1] * c.audio.hop_length
|
||||
), f"wavs.shape: {wavs.shape[1]}, mel_input.shape: {mel_input.shape[1] * c.audio.hop_length}"
|
||||
|
||||
# make sure that the computed mels and the waveform match and correctly computed
|
||||
mel_new = self.ap.melspectrogram(wavs[0].squeeze().numpy())
|
||||
ignore_seg = -(1 + c.audio.win_length // c.audio.hop_length)
|
||||
mel_diff = (mel_new[:, : mel_input.shape[1]] - mel_input[0].T.numpy())[:, 0:ignore_seg]
|
||||
assert abs(mel_diff.sum()) < 1e-5
|
||||
|
||||
# check normalization ranges
|
||||
if self.ap.symmetric_norm:
|
||||
assert mel_input.max() <= self.ap.max_norm
|
||||
|
|
|
@ -27,6 +27,7 @@ config = AlignTTSConfig(
|
|||
"Be a voice, not an echo.",
|
||||
],
|
||||
)
|
||||
|
||||
config.audio.do_trim_silence = True
|
||||
config.audio.trim_db = 60
|
||||
config.save_json(config_path)
|
||||
|
|
Loading…
Reference in New Issue