mirror of https://github.com/coqui-ai/TTS.git
Docstring edit in `TTSDataset.py` ✍️
parent
cfa5041db7
commit
932ab107ae
|
@ -2,6 +2,7 @@ import collections
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -10,52 +11,82 @@ from torch.utils.data import Dataset
|
||||||
|
|
||||||
from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
|
from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
|
||||||
from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
|
from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
|
||||||
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
|
||||||
|
|
||||||
class TTSDataset(Dataset):
|
class TTSDataset(Dataset):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
outputs_per_step,
|
outputs_per_step: int,
|
||||||
text_cleaner,
|
text_cleaner: list,
|
||||||
compute_linear_spec,
|
compute_linear_spec: bool,
|
||||||
ap,
|
ap: AudioProcessor,
|
||||||
meta_data,
|
meta_data: List[List],
|
||||||
tp=None,
|
characters: Dict = None,
|
||||||
add_blank=False,
|
add_blank: bool = False,
|
||||||
batch_group_size=0,
|
batch_group_size: int = 0,
|
||||||
min_seq_len=0,
|
min_seq_len: int = 0,
|
||||||
max_seq_len=float("inf"),
|
max_seq_len: int = float("inf"),
|
||||||
use_phonemes=False,
|
use_phonemes: bool = False,
|
||||||
phoneme_cache_path=None,
|
phoneme_cache_path: str = None,
|
||||||
phoneme_language="en-us",
|
phoneme_language: str = "en-us",
|
||||||
enable_eos_bos=False,
|
enable_eos_bos: bool = False,
|
||||||
speaker_id_mapping=None,
|
speaker_id_mapping: Dict = None,
|
||||||
d_vector_mapping=None,
|
d_vector_mapping: Dict = None,
|
||||||
use_noise_augment=False,
|
use_noise_augment: bool = False,
|
||||||
verbose=False,
|
verbose: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
|
||||||
|
|
||||||
|
If you need something different, you can either override or create a new class as the dataset is
|
||||||
|
initialized by the model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
outputs_per_step (int): number of time frames predicted per step.
|
outputs_per_step (int): Number of time frames predicted per step.
|
||||||
text_cleaner (str): text cleaner used for the dataset.
|
|
||||||
|
text_cleaner (list): List of text cleaners to clean the input text before converting to sequence IDs.
|
||||||
|
|
||||||
compute_linear_spec (bool): compute linear spectrogram if True.
|
compute_linear_spec (bool): compute linear spectrogram if True.
|
||||||
ap (TTS.tts.utils.AudioProcessor): audio processor object.
|
|
||||||
meta_data (list): list of dataset instances.
|
ap (TTS.tts.utils.AudioProcessor): Audio processor object.
|
||||||
tp (dict): dict of custom text characters used for converting texts to sequences.
|
|
||||||
batch_group_size (int): (0) range of batch randomization after sorting
|
meta_data (list): List of dataset instances.
|
||||||
sequences by length.
|
|
||||||
min_seq_len (int): (0) minimum sequence length to be processed
|
characters (dict): `dict` of custom text characters used for converting texts to sequences.
|
||||||
by the loader.
|
|
||||||
max_seq_len (int): (float("inf")) maximum sequence length.
|
add_blank (bool): Add a special `blank` character after every other character. It helps some
|
||||||
use_phonemes (bool): (true) if true, text converted to phonemes.
|
models achieve better results. Defaults to false.
|
||||||
phoneme_cache_path (str): path to cache phoneme features.
|
|
||||||
phoneme_language (str): one the languages from
|
batch_group_size (int): Range of batch randomization after sorting
|
||||||
https://github.com/bootphon/phonemizer#languages
|
sequences by length. It shuffles each batch with bucketing to gather similar length sequences in a
|
||||||
enable_eos_bos (bool): enable end of sentence and beginning of sentences characters.
|
batch. Set 0 to disable. Defaults to 0.
|
||||||
speaker_id_mapping (dict): list of speaker ids to map speaker names to numerical ids.
|
|
||||||
d_vector_mapping (dict): dictionary of d-vectors that maps each audio file to a pre-computed d-vector.
|
min_seq_len (int): Minimum input sequence length to be processed
|
||||||
use_noise_augment (bool): enable adding random noise to wav for augmentation.
|
by the loader. Filter out input sequences that are shorter than this. Some models have a
|
||||||
verbose (bool): print diagnostic information.
|
minimum input length due to their architecture. Defaults to 0.
|
||||||
|
|
||||||
|
max_seq_len (int): Maximum input sequence length. Filter out input sequences that are longer than this.
|
||||||
|
It helps for controlling the VRAM usage against long input sequences. Especially models with
|
||||||
|
RNN layers are sensitive to input length. Defaults to `Inf`.
|
||||||
|
|
||||||
|
use_phonemes (bool): If true, input text is converted to phonemes. Defaults to false.
|
||||||
|
|
||||||
|
phoneme_cache_path (str): Path to cache phoneme features. It writes computed phonemes to files to use in
|
||||||
|
the coming iterations. Defaults to None.
|
||||||
|
|
||||||
|
phoneme_language (str): One of the languages supported by the phonemizer interface. Defaults to `en-us`.
|
||||||
|
|
||||||
|
enable_eos_bos (bool): Enable the `end of sentence` and the `beginning of sentence` characters. Defaults
|
||||||
|
to False.
|
||||||
|
|
||||||
|
speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the
|
||||||
|
embedding layer. Defaults to None.
|
||||||
|
|
||||||
|
d_vector_mapping (dict): Mapping of wav files to computed d-vectors. Defaults to None.
|
||||||
|
|
||||||
|
use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.
|
||||||
|
|
||||||
|
verbose (bool): Print diagnostic information. Defaults to false.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.batch_group_size = batch_group_size
|
self.batch_group_size = batch_group_size
|
||||||
|
@ -67,7 +98,7 @@ class TTSDataset(Dataset):
|
||||||
self.min_seq_len = min_seq_len
|
self.min_seq_len = min_seq_len
|
||||||
self.max_seq_len = max_seq_len
|
self.max_seq_len = max_seq_len
|
||||||
self.ap = ap
|
self.ap = ap
|
||||||
self.tp = tp
|
self.characters = characters
|
||||||
self.add_blank = add_blank
|
self.add_blank = add_blank
|
||||||
self.use_phonemes = use_phonemes
|
self.use_phonemes = use_phonemes
|
||||||
self.phoneme_cache_path = phoneme_cache_path
|
self.phoneme_cache_path = phoneme_cache_path
|
||||||
|
@ -97,13 +128,13 @@ class TTSDataset(Dataset):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, tp, add_blank):
|
def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, characters, add_blank):
|
||||||
"""generate a phoneme sequence from text.
|
"""generate a phoneme sequence from text.
|
||||||
since the usage is for subsequent caching, we never add bos and
|
since the usage is for subsequent caching, we never add bos and
|
||||||
eos chars here. Instead we add those dynamically later; based on the
|
eos chars here. Instead we add those dynamically later; based on the
|
||||||
config option."""
|
config option."""
|
||||||
phonemes = phoneme_to_sequence(
|
phonemes = phoneme_to_sequence(
|
||||||
text, [cleaners], language=language, enable_eos_bos=False, tp=tp, add_blank=add_blank
|
text, [cleaners], language=language, enable_eos_bos=False, tp=characters, add_blank=add_blank
|
||||||
)
|
)
|
||||||
phonemes = np.asarray(phonemes, dtype=np.int32)
|
phonemes = np.asarray(phonemes, dtype=np.int32)
|
||||||
np.save(cache_path, phonemes)
|
np.save(cache_path, phonemes)
|
||||||
|
@ -111,7 +142,7 @@ class TTSDataset(Dataset):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_or_generate_phoneme_sequence(
|
def _load_or_generate_phoneme_sequence(
|
||||||
wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, tp, add_blank
|
wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, characters, add_blank
|
||||||
):
|
):
|
||||||
file_name = os.path.splitext(os.path.basename(wav_file))[0]
|
file_name = os.path.splitext(os.path.basename(wav_file))[0]
|
||||||
|
|
||||||
|
@ -122,15 +153,15 @@ class TTSDataset(Dataset):
|
||||||
phonemes = np.load(cache_path)
|
phonemes = np.load(cache_path)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
|
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
|
||||||
text, cache_path, cleaners, language, tp, add_blank
|
text, cache_path, cleaners, language, characters, add_blank
|
||||||
)
|
)
|
||||||
except (ValueError, IOError):
|
except (ValueError, IOError):
|
||||||
print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
|
print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
|
||||||
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
|
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
|
||||||
text, cache_path, cleaners, language, tp, add_blank
|
text, cache_path, cleaners, language, characters, add_blank
|
||||||
)
|
)
|
||||||
if enable_eos_bos:
|
if enable_eos_bos:
|
||||||
phonemes = pad_with_eos_bos(phonemes, tp=tp)
|
phonemes = pad_with_eos_bos(phonemes, tp=characters)
|
||||||
phonemes = np.asarray(phonemes, dtype=np.int32)
|
phonemes = np.asarray(phonemes, dtype=np.int32)
|
||||||
return phonemes
|
return phonemes
|
||||||
|
|
||||||
|
@ -158,13 +189,14 @@ class TTSDataset(Dataset):
|
||||||
self.enable_eos_bos,
|
self.enable_eos_bos,
|
||||||
self.cleaners,
|
self.cleaners,
|
||||||
self.phoneme_language,
|
self.phoneme_language,
|
||||||
self.tp,
|
self.characters,
|
||||||
self.add_blank,
|
self.add_blank,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
text = np.asarray(
|
text = np.asarray(
|
||||||
text_to_sequence(text, [self.cleaners], tp=self.tp, add_blank=self.add_blank), dtype=np.int32
|
text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
|
||||||
|
dtype=np.int32,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert text.size > 0, self.items[idx][1]
|
assert text.size > 0, self.items[idx][1]
|
||||||
|
@ -206,7 +238,8 @@ class TTSDataset(Dataset):
|
||||||
for idx, item in enumerate(tqdm.tqdm(self.items)):
|
for idx, item in enumerate(tqdm.tqdm(self.items)):
|
||||||
text, *_ = item
|
text, *_ = item
|
||||||
sequence = np.asarray(
|
sequence = np.asarray(
|
||||||
text_to_sequence(text, [self.cleaners], tp=self.tp, add_blank=self.add_blank), dtype=np.int32
|
text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
|
||||||
|
dtype=np.int32,
|
||||||
)
|
)
|
||||||
self.items[idx][0] = sequence
|
self.items[idx][0] = sequence
|
||||||
|
|
||||||
|
@ -216,7 +249,7 @@ class TTSDataset(Dataset):
|
||||||
self.enable_eos_bos,
|
self.enable_eos_bos,
|
||||||
self.cleaners,
|
self.cleaners,
|
||||||
self.phoneme_language,
|
self.phoneme_language,
|
||||||
self.tp,
|
self.characters,
|
||||||
self.add_blank,
|
self.add_blank,
|
||||||
]
|
]
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
|
|
Loading…
Reference in New Issue