Docstring edit in `TTSDataset.py` ✍️

pull/506/head
Eren Gölge 2021-06-21 16:53:19 +02:00
parent cfa5041db7
commit 932ab107ae
1 changed file with 82 additions and 49 deletions

@@ -2,6 +2,7 @@ import collections
 import os
 import random
 from multiprocessing import Pool
+from typing import Dict, List

 import numpy as np
 import torch
@@ -10,52 +11,82 @@ from torch.utils.data import Dataset
 from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
 from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
+from TTS.utils.audio import AudioProcessor


 class TTSDataset(Dataset):
     def __init__(
         self,
-        outputs_per_step,
-        text_cleaner,
-        compute_linear_spec,
-        ap,
-        meta_data,
-        tp=None,
-        add_blank=False,
-        batch_group_size=0,
-        min_seq_len=0,
-        max_seq_len=float("inf"),
-        use_phonemes=False,
-        phoneme_cache_path=None,
-        phoneme_language="en-us",
-        enable_eos_bos=False,
-        speaker_id_mapping=None,
-        d_vector_mapping=None,
-        use_noise_augment=False,
-        verbose=False,
+        outputs_per_step: int,
+        text_cleaner: list,
+        compute_linear_spec: bool,
+        ap: AudioProcessor,
+        meta_data: List[List],
+        characters: Dict = None,
+        add_blank: bool = False,
+        batch_group_size: int = 0,
+        min_seq_len: int = 0,
+        max_seq_len: int = float("inf"),
+        use_phonemes: bool = False,
+        phoneme_cache_path: str = None,
+        phoneme_language: str = "en-us",
+        enable_eos_bos: bool = False,
+        speaker_id_mapping: Dict = None,
+        d_vector_mapping: Dict = None,
+        use_noise_augment: bool = False,
+        verbose: bool = False,
     ):
-        """
-        Args:
-            outputs_per_step (int): number of time frames predicted per step.
-            text_cleaner (str): text cleaner used for the dataset.
-            compute_linear_spec (bool): compute linear spectrogram if True.
-            ap (TTS.tts.utils.AudioProcessor): audio processor object.
-            meta_data (list): list of dataset instances.
-            tp (dict): dict of custom text characters used for converting texts to sequences.
-            batch_group_size (int): (0) range of batch randomization after sorting
-                sequences by length.
-            min_seq_len (int): (0) minimum sequence length to be processed
-                by the loader.
-            max_seq_len (int): (float("inf")) maximum sequence length.
-            use_phonemes (bool): (true) if true, text converted to phonemes.
-            phoneme_cache_path (str): path to cache phoneme features.
-            phoneme_language (str): one the languages from
-                https://github.com/bootphon/phonemizer#languages
-            enable_eos_bos (bool): enable end of sentence and beginning of sentences characters.
-            speaker_id_mapping (dict): list of speaker ids to map speaker names to numerical ids.
-            d_vector_mapping (dict): dictionary of d-vectors that maps each audio file to a pre-computed d-vector.
-            use_noise_augment (bool): enable adding random noise to wav for augmentation.
-            verbose (bool): print diagnostic information.
+        """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
+
+        If you need something different, you can either override or create a new class as the dataset is
+        initialized by the model.
+
+        Args:
+            outputs_per_step (int): Number of time frames predicted per step.
+
+            text_cleaner (list): List of text cleaners to clean the input text before converting to sequence IDs.
+
+            compute_linear_spec (bool): compute linear spectrogram if True.
+
+            ap (TTS.tts.utils.AudioProcessor): Audio processor object.
+
+            meta_data (list): List of dataset instances.
+
+            characters (dict): `dict` of custom text characters used for converting texts to sequences.
+
+            add_blank (bool): Add a special `blank` character after every other character. It helps some
+                models achieve better results. Defaults to False.
+
+            batch_group_size (int): Range of batch randomization after sorting
+                sequences by length. It shuffles each batch with bucketing to gather similar-length sequences in a
+                batch. Set 0 to disable. Defaults to 0.
+
+            min_seq_len (int): Minimum input sequence length to be processed
+                by the loader. Filter out input sequences that are shorter than this. Some models have a
+                minimum input length due to their architecture. Defaults to 0.
+
+            max_seq_len (int): Maximum input sequence length. Filter out input sequences that are longer than this.
+                It helps to control VRAM usage against long input sequences. Models with RNN layers are
+                especially sensitive to input length. Defaults to `Inf`.
+
+            use_phonemes (bool): If True, the input text is converted to phonemes. Defaults to False.
+
+            phoneme_cache_path (str): Path to cache phoneme features. Computed phonemes are written to files and
+                reused in the coming iterations. Defaults to None.
+
+            phoneme_language (str): One of the languages supported by the phonemizer interface. Defaults to `en-us`.
+
+            enable_eos_bos (bool): Enable the `end of sentence` and the `beginning of sentence` characters. Defaults
+                to False.
+
+            speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the
+                embedding layer. Defaults to None.
+
+            d_vector_mapping (dict): Mapping of wav files to computed d-vectors. Defaults to None.
+
+            use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.
+
+            verbose (bool): Print diagnostic information. Defaults to False.
         """
         super().__init__()
         self.batch_group_size = batch_group_size
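
For orientation, here is a minimal sketch of how the updated signature might be called after this change. Only the keyword names and type hints come from the diff above; the module path of `TTSDataset`, the cleaner name, and the `ap`/`meta_data` placeholders are assumptions for illustration, not part of this commit.

# Hedged sketch: the renamed keyword (`characters`, formerly `tp`) in use.
from TTS.utils.audio import AudioProcessor           # import added by this commit
from TTS.tts.datasets.TTSDataset import TTSDataset   # assumed module path

ap: AudioProcessor = ...   # placeholder: an AudioProcessor built from the project's audio config
meta_data: list = ...      # placeholder: dataset items produced by a dataset formatter

dataset = TTSDataset(
    outputs_per_step=1,
    text_cleaner="english_cleaners",  # assumed cleaner name; wrapped internally as [self.cleaners]
    compute_linear_spec=False,
    ap=ap,
    meta_data=meta_data,
    characters=None,                  # was `tp=None` before this commit
    use_phonemes=False,
    min_seq_len=1,
    max_seq_len=float("inf"),
    verbose=True,
)
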
@@ -67,7 +98,7 @@ class TTSDataset(Dataset):
         self.min_seq_len = min_seq_len
         self.max_seq_len = max_seq_len
         self.ap = ap
-        self.tp = tp
+        self.characters = characters
         self.add_blank = add_blank
         self.use_phonemes = use_phonemes
         self.phoneme_cache_path = phoneme_cache_path
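
The new `batch_group_size` docstring above describes shuffling the length-sorted items in buckets so each batch gathers sequences of similar length. Below is a small, self-contained sketch of that idea; it is not the class's own implementation, which is outside this diff.

import random

def bucket_shuffle(sorted_items, batch_group_size):
    # Shuffle within fixed-size groups of an already length-sorted list, so that
    # nearby (similar-length) items stay together while the order is still randomized.
    if batch_group_size <= 0:
        return list(sorted_items)
    items = list(sorted_items)
    for start in range(0, len(items), batch_group_size):
        group = items[start:start + batch_group_size]
        random.shuffle(group)
        items[start:start + batch_group_size] = group
    return items

# Example: lengths sorted ascending, then shuffled in buckets of 4.
lengths = sorted([3, 7, 2, 9, 5, 6, 8, 4])
print(bucket_shuffle(lengths, batch_group_size=4))
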
@@ -97,13 +128,13 @@ class TTSDataset(Dataset):
         return data

     @staticmethod
-    def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, tp, add_blank):
+    def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, characters, add_blank):
         """generate a phoneme sequence from text.
         since the usage is for subsequent caching, we never add bos and
         eos chars here. Instead we add those dynamically later; based on the
         config option."""
         phonemes = phoneme_to_sequence(
-            text, [cleaners], language=language, enable_eos_bos=False, tp=tp, add_blank=add_blank
+            text, [cleaners], language=language, enable_eos_bos=False, tp=characters, add_blank=add_blank
         )
         phonemes = np.asarray(phonemes, dtype=np.int32)
         np.save(cache_path, phonemes)
@@ -111,7 +142,7 @@ class TTSDataset(Dataset):

     @staticmethod
     def _load_or_generate_phoneme_sequence(
-        wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, tp, add_blank
+        wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, characters, add_blank
     ):
         file_name = os.path.splitext(os.path.basename(wav_file))[0]
@@ -122,15 +153,15 @@ class TTSDataset(Dataset):
             phonemes = np.load(cache_path)
         except FileNotFoundError:
             phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
-                text, cache_path, cleaners, language, tp, add_blank
+                text, cache_path, cleaners, language, characters, add_blank
             )
         except (ValueError, IOError):
             print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
             phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
-                text, cache_path, cleaners, language, tp, add_blank
+                text, cache_path, cleaners, language, characters, add_blank
             )
         if enable_eos_bos:
-            phonemes = pad_with_eos_bos(phonemes, tp=tp)
+            phonemes = pad_with_eos_bos(phonemes, tp=characters)
         phonemes = np.asarray(phonemes, dtype=np.int32)
         return phonemes
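
The two static methods above form a load-or-compute cache: try `np.load` on the cached `.npy`, and on a missing or unreadable file recompute the phoneme sequence and `np.save` it. Below is a generic, hedged sketch of the same pattern with a dummy compute function; the real cache filename scheme and phonemizer call stay in the class.

import numpy as np

def load_or_compute(cache_path, compute_fn):
    # Try the cached array first; fall back to recomputing (and re-caching) on a
    # missing or unreadable file, mirroring the try/except chain above.
    try:
        return np.load(cache_path)
    except FileNotFoundError:
        pass
    except (ValueError, IOError):
        print(f" [!] failed loading {cache_path}. Recomputing.")
    result = np.asarray(compute_fn(), dtype=np.int32)
    np.save(cache_path, result)
    return result

# Usage with a dummy compute function standing in for phoneme_to_sequence:
ids = load_or_compute("example_phonemes.npy", lambda: [1, 2, 3])
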
@@ -158,13 +189,14 @@ class TTSDataset(Dataset):
                 self.enable_eos_bos,
                 self.cleaners,
                 self.phoneme_language,
-                self.tp,
+                self.characters,
                 self.add_blank,
             )
         else:
             text = np.asarray(
-                text_to_sequence(text, [self.cleaners], tp=self.tp, add_blank=self.add_blank), dtype=np.int32
+                text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
+                dtype=np.int32,
             )

         assert text.size > 0, self.items[idx][1]
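
When `use_phonemes` is off, the text goes straight through `text_to_sequence` with the custom `characters` set (the renamed `tp` keyword). Here is a minimal call outside the class, assuming the default character set applies when `tp=None` and that `"basic_cleaners"` is an available cleaner name.

import numpy as np
from TTS.tts.utils.text import text_to_sequence

# Same call shape as in the loader above; cleaner name and tp=None are assumptions.
ids = np.asarray(
    text_to_sequence("Hello world.", ["basic_cleaners"], tp=None, add_blank=False),
    dtype=np.int32,
)
print(ids.shape)
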
@@ -206,7 +238,8 @@ class TTSDataset(Dataset):
             for idx, item in enumerate(tqdm.tqdm(self.items)):
                 text, *_ = item
                 sequence = np.asarray(
-                    text_to_sequence(text, [self.cleaners], tp=self.tp, add_blank=self.add_blank), dtype=np.int32
+                    text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
+                    dtype=np.int32,
                 )
                 self.items[idx][0] = sequence
@@ -216,7 +249,7 @@ class TTSDataset(Dataset):
                 self.enable_eos_bos,
                 self.cleaners,
                 self.phoneme_language,
-                self.tp,
+                self.characters,
                 self.add_blank,
             ]
             if self.verbose:
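
The `func_args` list assembled above, together with the `from multiprocessing import Pool` import at the top of the file, suggests the phoneme branch hands per-item work to a process pool. Below is a generic, hedged sketch of that pattern with a dummy worker; the class's real worker and the exact contents of `func_args` are outside the lines shown here.

from functools import partial
from multiprocessing import Pool

def _worker(item, shared_args):
    # Dummy stand-in for the per-item phoneme computation.
    text = item[0]
    return len(text), shared_args

if __name__ == "__main__":
    items = [["hello there.", "a.wav"], ["general kenobi.", "b.wav"]]
    shared_args = ("en-us", False)  # stand-ins for the shared arguments assembled above
    with Pool(processes=2) as pool:
        results = pool.map(partial(_worker, shared_args=shared_args), items)
    print(results)
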