Docstring edit in `TTSDataset.py` ✍️

pull/506/head
Eren Gölge 2021-06-21 16:53:19 +02:00
parent cfa5041db7
commit 932ab107ae
1 changed file with 82 additions and 49 deletions


@ -2,6 +2,7 @@ import collections
import os
import random
from multiprocessing import Pool
from typing import Dict, List
import numpy as np
import torch
@ -10,52 +11,82 @@ from torch.utils.data import Dataset
from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
from TTS.utils.audio import AudioProcessor
class TTSDataset(Dataset):
def __init__(
self,
outputs_per_step,
text_cleaner,
compute_linear_spec,
ap,
meta_data,
tp=None,
add_blank=False,
batch_group_size=0,
min_seq_len=0,
max_seq_len=float("inf"),
use_phonemes=False,
phoneme_cache_path=None,
phoneme_language="en-us",
enable_eos_bos=False,
speaker_id_mapping=None,
d_vector_mapping=None,
use_noise_augment=False,
verbose=False,
outputs_per_step: int,
text_cleaner: list,
compute_linear_spec: bool,
ap: AudioProcessor,
meta_data: List[List],
characters: Dict = None,
add_blank: bool = False,
batch_group_size: int = 0,
min_seq_len: int = 0,
max_seq_len: int = float("inf"),
use_phonemes: bool = False,
phoneme_cache_path: str = None,
phoneme_language: str = "en-us",
enable_eos_bos: bool = False,
speaker_id_mapping: Dict = None,
d_vector_mapping: Dict = None,
use_noise_augment: bool = False,
verbose: bool = False,
):
"""
"""Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
If you need something different, you can either override or create a new class as the dataset is
initialized by the model.
Args:
outputs_per_step (int): number of time frames predicted per step.
text_cleaner (str): text cleaner used for the dataset.
outputs_per_step (int): Number of time frames predicted per step.
text_cleaner (list): List of text cleaners to clean the input text before converting to sequence IDs.
compute_linear_spec (bool): compute linear spectrogram if True.
ap (TTS.tts.utils.AudioProcessor): audio processor object.
meta_data (list): list of dataset instances.
tp (dict): dict of custom text characters used for converting texts to sequences.
batch_group_size (int): (0) range of batch randomization after sorting
sequences by length.
min_seq_len (int): (0) minimum sequence length to be processed
by the loader.
max_seq_len (int): (float("inf")) maximum sequence length.
use_phonemes (bool): (true) if true, text converted to phonemes.
phoneme_cache_path (str): path to cache phoneme features.
phoneme_language (str): one the languages from
https://github.com/bootphon/phonemizer#languages
enable_eos_bos (bool): enable end of sentence and beginning of sentences characters.
speaker_id_mapping (dict): list of speaker ids to map speaker names to numerical ids.
d_vector_mapping (dict): dictionary of d-vectors that maps each audio file to a pre-computed d-vector.
use_noise_augment (bool): enable adding random noise to wav for augmentation.
verbose (bool): print diagnostic information.
ap (TTS.tts.utils.AudioProcessor): Audio processor object.
meta_data (list): List of dataset instances.
characters (dict): `dict` of custom text characters used for converting texts to sequences.
add_blank (bool): Add a special `blank` character after each character. It helps some
models achieve better results. Defaults to False.
batch_group_size (int): Range of batch randomization after sorting
sequences by length. It shuffles each batch with bucketing to gather similar-length sequences in a
batch. Set 0 to disable. Defaults to 0.
min_seq_len (int): Minimum input sequence length to be processed
by the loader. Filters out input sequences shorter than this. Some models have a
minimum input length due to their architecture. Defaults to 0.
max_seq_len (int): Maximum input sequence length. Filters out input sequences longer than this.
It helps control VRAM usage for long input sequences. Models with
RNN layers are especially sensitive to input length. Defaults to `Inf`.
use_phonemes (bool): If True, the input text is converted to phonemes. Defaults to False.
phoneme_cache_path (str): Path to cache phoneme features. Computed phonemes are written to files and
reused in later iterations. Defaults to None.
phoneme_language (str): One of the languages supported by the phonemizer interface. Defaults to `en-us`.
enable_eos_bos (bool): Enable the `end of sentence` and `beginning of sentence` characters. Defaults
to False.
speaker_id_mapping (dict): Mapping of speaker names to IDs, used by the embedding layer to compute
speaker embedding vectors. Defaults to None.
d_vector_mapping (dict): Mapping of wav files to computed d-vectors. Defaults to None.
use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.
verbose (bool): Print diagnostic information. Defaults to False.
"""
super().__init__()
self.batch_group_size = batch_group_size
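For readers skimming the new signature, here is a minimal instantiation sketch of the updated constructor. The import path, the cleaner name, the `AudioProcessor` arguments, and the metadata layout are illustrative assumptions, not taken from this commit:

```python
# Hypothetical usage of the updated constructor (note `characters`, formerly `tp`).
# Assumes each meta_data item starts with the raw text followed by the wav path,
# mirroring the `text, *_ = item` unpacking used later in this file.
from TTS.utils.audio import AudioProcessor
from TTS.tts.datasets.TTSDataset import TTSDataset  # module path is an assumption

ap = AudioProcessor(sample_rate=22050)  # in practice built from the model config's audio settings

meta_data = [
    ["Hello world.", "wavs/sample_0001.wav", "speaker_0"],  # placeholder item
]

dataset = TTSDataset(
    outputs_per_step=1,
    text_cleaner="english_cleaners",  # assumed cleaner name; the loader wraps it as `[self.cleaners]`
    compute_linear_spec=False,
    ap=ap,
    meta_data=meta_data,
    characters=None,  # None falls back to the default character set
    add_blank=False,
    min_seq_len=1,
    max_seq_len=500,
    use_phonemes=False,
    verbose=True,
)
```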
@ -67,7 +98,7 @@ class TTSDataset(Dataset):
self.min_seq_len = min_seq_len
self.max_seq_len = max_seq_len
self.ap = ap
self.tp = tp
self.characters = characters
self.add_blank = add_blank
self.use_phonemes = use_phonemes
self.phoneme_cache_path = phoneme_cache_path
@ -97,13 +128,13 @@ class TTSDataset(Dataset):
return data
@staticmethod
def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, tp, add_blank):
def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, characters, add_blank):
"""generate a phoneme sequence from text.
since the usage is for subsequent caching, we never add bos and
eos chars here. Instead we add those dynamically later; based on the
config option."""
phonemes = phoneme_to_sequence(
text, [cleaners], language=language, enable_eos_bos=False, tp=tp, add_blank=add_blank
text, [cleaners], language=language, enable_eos_bos=False, tp=characters, add_blank=add_blank
)
phonemes = np.asarray(phonemes, dtype=np.int32)
np.save(cache_path, phonemes)
@ -111,7 +142,7 @@ class TTSDataset(Dataset):
@staticmethod
def _load_or_generate_phoneme_sequence(
wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, tp, add_blank
wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, characters, add_blank
):
file_name = os.path.splitext(os.path.basename(wav_file))[0]
@ -122,15 +153,15 @@ class TTSDataset(Dataset):
phonemes = np.load(cache_path)
except FileNotFoundError:
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, tp, add_blank
text, cache_path, cleaners, language, characters, add_blank
)
except (ValueError, IOError):
print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, tp, add_blank
text, cache_path, cleaners, language, characters, add_blank
)
if enable_eos_bos:
phonemes = pad_with_eos_bos(phonemes, tp=tp)
phonemes = pad_with_eos_bos(phonemes, tp=characters)
phonemes = np.asarray(phonemes, dtype=np.int32)
return phonemes
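The try/except above is a compute-once, cache-to-disk pattern: load the cached `.npy` file if it exists and is readable, otherwise recompute and save. A self-contained sketch of the same idea, with a stand-in `compute_fn` in place of the phonemizer call:

```python
import numpy as np

def load_or_compute(cache_path: str, compute_fn) -> np.ndarray:
    """Return the cached array at `cache_path` if readable; otherwise
    compute it with `compute_fn`, write it to disk, and return it."""
    try:
        return np.load(cache_path)
    except FileNotFoundError:
        pass  # first run: nothing cached yet
    except (ValueError, IOError):
        print(f" [!] failed loading cache {cache_path}. Recomputing.")
    result = np.asarray(compute_fn(), dtype=np.int32)
    np.save(cache_path, result)  # keeps the given name since it already ends with .npy
    return result

# Example: cache a fake "phoneme sequence" for one utterance.
phonemes = load_or_compute("sample_0001_phoneme.npy", lambda: [3, 17, 42, 7])
```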
@ -158,13 +189,14 @@ class TTSDataset(Dataset):
self.enable_eos_bos,
self.cleaners,
self.phoneme_language,
self.tp,
self.characters,
self.add_blank,
)
else:
text = np.asarray(
text_to_sequence(text, [self.cleaners], tp=self.tp, add_blank=self.add_blank), dtype=np.int32
text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
dtype=np.int32,
)
assert text.size > 0, self.items[idx][1]
@ -206,7 +238,8 @@ class TTSDataset(Dataset):
for idx, item in enumerate(tqdm.tqdm(self.items)):
text, *_ = item
sequence = np.asarray(
text_to_sequence(text, [self.cleaners], tp=self.tp, add_blank=self.add_blank), dtype=np.int32
text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
dtype=np.int32,
)
self.items[idx][0] = sequence
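The sequences computed here are typically used for the length-based sorting and the `batch_group_size` bucketing described in the docstring above: after sorting by length, items are shuffled only within consecutive buckets, so batches keep similar-length sequences while still varying between epochs. A minimal sketch of that bucketed shuffle (not the library's exact implementation), assuming the items are already sorted:

```python
import random
from typing import List

def bucketed_shuffle(items: List, batch_group_size: int) -> List:
    """Shuffle `items` within consecutive buckets of `batch_group_size`."""
    if batch_group_size <= 0:
        return list(items)  # bucketing disabled
    items = list(items)
    for start in range(0, len(items), batch_group_size):
        bucket = items[start:start + batch_group_size]
        random.shuffle(bucket)
        items[start:start + batch_group_size] = bucket
    return items

# Example: 10 length-sorted items shuffled in buckets of 4.
print(bucketed_shuffle(list(range(10)), batch_group_size=4))
```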
@ -216,7 +249,7 @@ class TTSDataset(Dataset):
self.enable_eos_bos,
self.cleaners,
self.phoneme_language,
self.tp,
self.characters,
self.add_blank,
]
if self.verbose: