mirror of https://github.com/coqui-ai/TTS.git
Docstring edit in `TTSDataset.py` ✍️
parent
cfa5041db7
commit
932ab107ae
|
@ -2,6 +2,7 @@ import collections
|
|||
import os
|
||||
import random
|
||||
from multiprocessing import Pool
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -10,52 +11,82 @@ from torch.utils.data import Dataset
|
|||
|
||||
from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
|
||||
from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
||||
class TTSDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
outputs_per_step,
|
||||
text_cleaner,
|
||||
compute_linear_spec,
|
||||
ap,
|
||||
meta_data,
|
||||
tp=None,
|
||||
add_blank=False,
|
||||
batch_group_size=0,
|
||||
min_seq_len=0,
|
||||
max_seq_len=float("inf"),
|
||||
use_phonemes=False,
|
||||
phoneme_cache_path=None,
|
||||
phoneme_language="en-us",
|
||||
enable_eos_bos=False,
|
||||
speaker_id_mapping=None,
|
||||
d_vector_mapping=None,
|
||||
use_noise_augment=False,
|
||||
verbose=False,
|
||||
outputs_per_step: int,
|
||||
text_cleaner: list,
|
||||
compute_linear_spec: bool,
|
||||
ap: AudioProcessor,
|
||||
meta_data: List[List],
|
||||
characters: Dict = None,
|
||||
add_blank: bool = False,
|
||||
batch_group_size: int = 0,
|
||||
min_seq_len: int = 0,
|
||||
max_seq_len: int = float("inf"),
|
||||
use_phonemes: bool = False,
|
||||
phoneme_cache_path: str = None,
|
||||
phoneme_language: str = "en-us",
|
||||
enable_eos_bos: bool = False,
|
||||
speaker_id_mapping: Dict = None,
|
||||
d_vector_mapping: Dict = None,
|
||||
use_noise_augment: bool = False,
|
||||
verbose: bool = False,
|
||||
):
|
||||
"""
|
||||
"""Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
|
||||
|
||||
If you need something different, you can either override or create a new class as the dataset is
|
||||
initialized by the model.
|
||||
|
||||
Args:
|
||||
outputs_per_step (int): number of time frames predicted per step.
|
||||
text_cleaner (str): text cleaner used for the dataset.
|
||||
outputs_per_step (int): Number of time frames predicted per step.
|
||||
|
||||
text_cleaner (list): List of text cleaners to clean the input text before converting to sequence IDs.
|
||||
|
||||
compute_linear_spec (bool): compute linear spectrogram if True.
|
||||
ap (TTS.tts.utils.AudioProcessor): audio processor object.
|
||||
meta_data (list): list of dataset instances.
|
||||
tp (dict): dict of custom text characters used for converting texts to sequences.
|
||||
batch_group_size (int): (0) range of batch randomization after sorting
|
||||
sequences by length.
|
||||
min_seq_len (int): (0) minimum sequence length to be processed
|
||||
by the loader.
|
||||
max_seq_len (int): (float("inf")) maximum sequence length.
|
||||
use_phonemes (bool): (true) if true, text converted to phonemes.
|
||||
phoneme_cache_path (str): path to cache phoneme features.
|
||||
phoneme_language (str): one the languages from
|
||||
https://github.com/bootphon/phonemizer#languages
|
||||
enable_eos_bos (bool): enable end of sentence and beginning of sentences characters.
|
||||
speaker_id_mapping (dict): list of speaker ids to map speaker names to numerical ids.
|
||||
d_vector_mapping (dict): dictionary of d-vectors that maps each audio file to a pre-computed d-vector.
|
||||
use_noise_augment (bool): enable adding random noise to wav for augmentation.
|
||||
verbose (bool): print diagnostic information.
|
||||
|
||||
ap (TTS.tts.utils.AudioProcessor): Audio processor object.
|
||||
|
||||
meta_data (list): List of dataset instances.
|
||||
|
||||
characters (dict): `dict` of custom text characters used for converting texts to sequences.
|
||||
|
||||
add_blank (bool): Add a special `blank` character after every other character. It helps some
|
||||
models achieve better results. Defaults to false.
|
||||
|
||||
batch_group_size (int): Range of batch randomization after sorting
|
||||
sequences by length. It shuffles each batch with bucketing to gather similar lenght sequences in a
|
||||
batch. Set 0 to disable. Defaults to 0.
|
||||
|
||||
min_seq_len (int): Minimum input sequence length to be processed
|
||||
by the loader. Filter out input sequences that are shorter than this. Some models have a
|
||||
minimum input length due to its architecture. Defaults to 0.
|
||||
|
||||
max_seq_len (int): Maximum input sequence length. Filter out input sequences that are longer than this.
|
||||
It helps for controlling the VRAM usage against long input sequences. Especially models with
|
||||
RNN layers are sensitive to input length. Defaults to `Inf`.
|
||||
|
||||
use_phonemes (bool): If true, input text converted to phonemes. Defaults to false.
|
||||
|
||||
phoneme_cache_path (str): Path to cache phoneme features. It writes computed phonemes to files to use in
|
||||
the coming iterations. Defaults to None.
|
||||
|
||||
phoneme_language (str): One the languages from supported by the phonemizer interface. Defaults to `en-us`.
|
||||
|
||||
enable_eos_bos (bool): Enable the `end of sentence` and the `beginning of sentences characters`. Defaults
|
||||
to False.
|
||||
|
||||
speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the
|
||||
embedding layer. Defaults to None.
|
||||
|
||||
d_vector_mapping (dict): Mapping of wav files to computed d-vectors. Defaults to None.
|
||||
|
||||
use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.
|
||||
|
||||
verbose (bool): Print diagnostic information. Defaults to false.
|
||||
"""
|
||||
super().__init__()
|
||||
self.batch_group_size = batch_group_size
|
||||
|
@ -67,7 +98,7 @@ class TTSDataset(Dataset):
|
|||
self.min_seq_len = min_seq_len
|
||||
self.max_seq_len = max_seq_len
|
||||
self.ap = ap
|
||||
self.tp = tp
|
||||
self.characters = characters
|
||||
self.add_blank = add_blank
|
||||
self.use_phonemes = use_phonemes
|
||||
self.phoneme_cache_path = phoneme_cache_path
|
||||
|
@ -97,13 +128,13 @@ class TTSDataset(Dataset):
|
|||
return data
|
||||
|
||||
@staticmethod
|
||||
def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, tp, add_blank):
|
||||
def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, characters, add_blank):
|
||||
"""generate a phoneme sequence from text.
|
||||
since the usage is for subsequent caching, we never add bos and
|
||||
eos chars here. Instead we add those dynamically later; based on the
|
||||
config option."""
|
||||
phonemes = phoneme_to_sequence(
|
||||
text, [cleaners], language=language, enable_eos_bos=False, tp=tp, add_blank=add_blank
|
||||
text, [cleaners], language=language, enable_eos_bos=False, tp=characters, add_blank=add_blank
|
||||
)
|
||||
phonemes = np.asarray(phonemes, dtype=np.int32)
|
||||
np.save(cache_path, phonemes)
|
||||
|
@ -111,7 +142,7 @@ class TTSDataset(Dataset):
|
|||
|
||||
@staticmethod
|
||||
def _load_or_generate_phoneme_sequence(
|
||||
wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, tp, add_blank
|
||||
wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, characters, add_blank
|
||||
):
|
||||
file_name = os.path.splitext(os.path.basename(wav_file))[0]
|
||||
|
||||
|
@ -122,15 +153,15 @@ class TTSDataset(Dataset):
|
|||
phonemes = np.load(cache_path)
|
||||
except FileNotFoundError:
|
||||
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
|
||||
text, cache_path, cleaners, language, tp, add_blank
|
||||
text, cache_path, cleaners, language, characters, add_blank
|
||||
)
|
||||
except (ValueError, IOError):
|
||||
print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
|
||||
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
|
||||
text, cache_path, cleaners, language, tp, add_blank
|
||||
text, cache_path, cleaners, language, characters, add_blank
|
||||
)
|
||||
if enable_eos_bos:
|
||||
phonemes = pad_with_eos_bos(phonemes, tp=tp)
|
||||
phonemes = pad_with_eos_bos(phonemes, tp=characters)
|
||||
phonemes = np.asarray(phonemes, dtype=np.int32)
|
||||
return phonemes
|
||||
|
||||
|
@ -158,13 +189,14 @@ class TTSDataset(Dataset):
|
|||
self.enable_eos_bos,
|
||||
self.cleaners,
|
||||
self.phoneme_language,
|
||||
self.tp,
|
||||
self.characters,
|
||||
self.add_blank,
|
||||
)
|
||||
|
||||
else:
|
||||
text = np.asarray(
|
||||
text_to_sequence(text, [self.cleaners], tp=self.tp, add_blank=self.add_blank), dtype=np.int32
|
||||
text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
|
||||
dtype=np.int32,
|
||||
)
|
||||
|
||||
assert text.size > 0, self.items[idx][1]
|
||||
|
@ -206,7 +238,8 @@ class TTSDataset(Dataset):
|
|||
for idx, item in enumerate(tqdm.tqdm(self.items)):
|
||||
text, *_ = item
|
||||
sequence = np.asarray(
|
||||
text_to_sequence(text, [self.cleaners], tp=self.tp, add_blank=self.add_blank), dtype=np.int32
|
||||
text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
|
||||
dtype=np.int32,
|
||||
)
|
||||
self.items[idx][0] = sequence
|
||||
|
||||
|
@ -216,7 +249,7 @@ class TTSDataset(Dataset):
|
|||
self.enable_eos_bos,
|
||||
self.cleaners,
|
||||
self.phoneme_language,
|
||||
self.tp,
|
||||
self.characters,
|
||||
self.add_blank,
|
||||
]
|
||||
if self.verbose:
|
||||
|
|
Loading…
Reference in New Issue