From 932ab107ae571af612fa8d7f406ae13017dd4d58 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Mon, 21 Jun 2021 16:53:19 +0200
Subject: [PATCH] =?UTF-8?q?Docstring=20edit=20in=20`TTSDataset.py`=20?= =?UTF-8?q?=E2=9C=8D=EF=B8=8F?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 TTS/tts/datasets/TTSDataset.py | 131 +++++++++++++++++++++------------
 1 file changed, 82 insertions(+), 49 deletions(-)

diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py
index d0fbb553..0fc23231 100644
--- a/TTS/tts/datasets/TTSDataset.py
+++ b/TTS/tts/datasets/TTSDataset.py
@@ -2,6 +2,7 @@ import collections
 import os
 import random
 from multiprocessing import Pool
+from typing import Dict, List
 
 import numpy as np
 import torch
@@ -10,52 +11,82 @@ from torch.utils.data import Dataset
 
 from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
 from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
+from TTS.utils.audio import AudioProcessor
 
 
 class TTSDataset(Dataset):
     def __init__(
         self,
-        outputs_per_step,
-        text_cleaner,
-        compute_linear_spec,
-        ap,
-        meta_data,
-        tp=None,
-        add_blank=False,
-        batch_group_size=0,
-        min_seq_len=0,
-        max_seq_len=float("inf"),
-        use_phonemes=False,
-        phoneme_cache_path=None,
-        phoneme_language="en-us",
-        enable_eos_bos=False,
-        speaker_id_mapping=None,
-        d_vector_mapping=None,
-        use_noise_augment=False,
-        verbose=False,
+        outputs_per_step: int,
+        text_cleaner: list,
+        compute_linear_spec: bool,
+        ap: AudioProcessor,
+        meta_data: List[List],
+        characters: Dict = None,
+        add_blank: bool = False,
+        batch_group_size: int = 0,
+        min_seq_len: int = 0,
+        max_seq_len: int = float("inf"),
+        use_phonemes: bool = False,
+        phoneme_cache_path: str = None,
+        phoneme_language: str = "en-us",
+        enable_eos_bos: bool = False,
+        speaker_id_mapping: Dict = None,
+        d_vector_mapping: Dict = None,
+        use_noise_augment: bool = False,
+        verbose: bool = False,
     ):
-        """
+        """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
+
+        If you need something different, you can either override or create a new class as the dataset is
+        initialized by the model.
+
         Args:
-            outputs_per_step (int): number of time frames predicted per step.
-            text_cleaner (str): text cleaner used for the dataset.
+            outputs_per_step (int): Number of time frames predicted per step.
+
+            text_cleaner (list): List of text cleaners to clean the input text before converting to sequence IDs.
+
             compute_linear_spec (bool): compute linear spectrogram if True.
-            ap (TTS.tts.utils.AudioProcessor): audio processor object.
-            meta_data (list): list of dataset instances.
-            tp (dict): dict of custom text characters used for converting texts to sequences.
-            batch_group_size (int): (0) range of batch randomization after sorting
-                sequences by length.
-            min_seq_len (int): (0) minimum sequence length to be processed
-                by the loader.
-            max_seq_len (int): (float("inf")) maximum sequence length.
-            use_phonemes (bool): (true) if true, text converted to phonemes.
-            phoneme_cache_path (str): path to cache phoneme features.
-            phoneme_language (str): one the languages from
-                https://github.com/bootphon/phonemizer#languages
-            enable_eos_bos (bool): enable end of sentence and beginning of sentences characters.
-            speaker_id_mapping (dict): list of speaker ids to map speaker names to numerical ids.
-            d_vector_mapping (dict): dictionary of d-vectors that maps each audio file to a pre-computed d-vector.
-            use_noise_augment (bool): enable adding random noise to wav for augmentation.
-            verbose (bool): print diagnostic information.
+
+            ap (TTS.tts.utils.AudioProcessor): Audio processor object.
+
+            meta_data (list): List of dataset instances.
+
+            characters (dict): `dict` of custom text characters used for converting texts to sequences.
+
+            add_blank (bool): Add a special `blank` character after every other character. It helps some
+                models achieve better results. Defaults to False.
+
+            batch_group_size (int): Range of batch randomization after sorting
+                sequences by length. It shuffles each batch with bucketing to gather sequences of similar length
+                in a batch. Set 0 to disable. Defaults to 0.
+
+            min_seq_len (int): Minimum input sequence length to be processed
+                by the loader. Filter out input sequences that are shorter than this. Some models have a
+                minimum input length due to their architecture. Defaults to 0.
+
+            max_seq_len (int): Maximum input sequence length. Filter out input sequences that are longer than this.
+                It helps control VRAM usage for long input sequences. Models with RNN layers are especially
+                sensitive to input length. Defaults to `Inf`.
+
+            use_phonemes (bool): If True, the input text is converted to phonemes. Defaults to False.
+
+            phoneme_cache_path (str): Path to cache phoneme features. It writes computed phonemes to files to be
+                reused in the following iterations. Defaults to None.
+
+            phoneme_language (str): One of the languages supported by the phonemizer interface. Defaults to `en-us`.
+
+            enable_eos_bos (bool): Enable the `end of sentence` and `beginning of sentence` characters. Defaults
+                to False.
+
+            speaker_id_mapping (dict): Mapping of speaker names to IDs used by the embedding layer to compute
+                speaker embedding vectors. Defaults to None.
+
+            d_vector_mapping (dict): Mapping of wav files to pre-computed d-vectors. Defaults to None.
+
+            use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.
+
+            verbose (bool): Print diagnostic information. Defaults to False.
         """
         super().__init__()
         self.batch_group_size = batch_group_size
@@ -67,7 +98,7 @@ class TTSDataset(Dataset):
         self.min_seq_len = min_seq_len
         self.max_seq_len = max_seq_len
         self.ap = ap
-        self.tp = tp
+        self.characters = characters
         self.add_blank = add_blank
         self.use_phonemes = use_phonemes
         self.phoneme_cache_path = phoneme_cache_path
@@ -97,13 +128,13 @@ class TTSDataset(Dataset):
         return data
 
     @staticmethod
-    def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, tp, add_blank):
+    def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, characters, add_blank):
         """generate a phoneme sequence from text.
         since the usage is for subsequent caching, we never add bos and eos chars here.
         Instead we add those dynamically later; based on the config option."""
         phonemes = phoneme_to_sequence(
-            text, [cleaners], language=language, enable_eos_bos=False, tp=tp, add_blank=add_blank
+            text, [cleaners], language=language, enable_eos_bos=False, tp=characters, add_blank=add_blank
         )
         phonemes = np.asarray(phonemes, dtype=np.int32)
         np.save(cache_path, phonemes)
@@ -111,7 +142,7 @@ class TTSDataset(Dataset):
 
     @staticmethod
     def _load_or_generate_phoneme_sequence(
-        wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, tp, add_blank
+        wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, characters, add_blank
     ):
         file_name = os.path.splitext(os.path.basename(wav_file))[0]
 
@@ -122,15 +153,15 @@ class TTSDataset(Dataset):
             phonemes = np.load(cache_path)
         except FileNotFoundError:
             phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
-                text, cache_path, cleaners, language, tp, add_blank
+                text, cache_path, cleaners, language, characters, add_blank
             )
         except (ValueError, IOError):
             print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
             phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
-                text, cache_path, cleaners, language, tp, add_blank
+                text, cache_path, cleaners, language, characters, add_blank
             )
         if enable_eos_bos:
-            phonemes = pad_with_eos_bos(phonemes, tp=tp)
+            phonemes = pad_with_eos_bos(phonemes, tp=characters)
         phonemes = np.asarray(phonemes, dtype=np.int32)
         return phonemes
 
@@ -158,13 +189,14 @@ class TTSDataset(Dataset):
                 self.enable_eos_bos,
                 self.cleaners,
                 self.phoneme_language,
-                self.tp,
+                self.characters,
                 self.add_blank,
             )
 
         else:
             text = np.asarray(
-                text_to_sequence(text, [self.cleaners], tp=self.tp, add_blank=self.add_blank), dtype=np.int32
+                text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
+                dtype=np.int32,
             )
 
         assert text.size > 0, self.items[idx][1]
@@ -206,7 +238,8 @@ class TTSDataset(Dataset):
             for idx, item in enumerate(tqdm.tqdm(self.items)):
                 text, *_ = item
                 sequence = np.asarray(
-                    text_to_sequence(text, [self.cleaners], tp=self.tp, add_blank=self.add_blank), dtype=np.int32
+                    text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
+                    dtype=np.int32,
                 )
                 self.items[idx][0] = sequence
 
@@ -216,7 +249,7 @@ class TTSDataset(Dataset):
                 self.enable_eos_bos,
                 self.cleaners,
                 self.phoneme_language,
-                self.tp,
+                self.characters,
                 self.add_blank,
             ]
             if self.verbose:
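
Note (not part of the patch): a minimal usage sketch of the updated constructor, mainly to show the `characters` argument that replaces the old `tp`. The audio settings, file paths, and metadata layout below are illustrative assumptions rather than values taken from this diff.

from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.utils.audio import AudioProcessor

# Audio settings normally come from the model config; the values here are assumed.
ap = AudioProcessor(sample_rate=22050, num_mels=80, fft_size=1024, hop_length=256, win_length=1024)

# Each item is assumed to follow the [text, wav_file, speaker_name] layout produced
# by the dataset formatters.
meta_data = [["hello world.", "/data/wavs/0001.wav", "speaker0"]]

dataset = TTSDataset(
    outputs_per_step=1,
    text_cleaner="english_cleaners",  # wrapped into a list internally when calling text_to_sequence
    compute_linear_spec=False,
    ap=ap,
    meta_data=meta_data,
    characters=None,  # custom character dict; None keeps the default set (argument renamed from `tp`)
    use_phonemes=False,
)
print(len(dataset))  # number of loaded items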