diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index 60b514c2..337dcfa5 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -258,7 +258,7 @@ class TTSDataset(Dataset): return audio_lengths, text_lengths @staticmethod - def sort_and_filter_by_length(lengths:List[int], min_len:int, max_len:int): + def sort_and_filter_by_length(lengths: List[int], min_len: int, max_len: int): idxs = np.argsort(lengths) # ascending order ignore_idx = [] keep_idx = [] @@ -271,7 +271,7 @@ class TTSDataset(Dataset): return ignore_idx, keep_idx @staticmethod - def create_buckets(samples, batch_group_size:int): + def create_buckets(samples, batch_group_size: int): for i in range(len(samples) // batch_group_size): offset = i * batch_group_size end_offset = offset + batch_group_size @@ -286,8 +286,12 @@ class TTSDataset(Dataset): """ # sort items based on the sequence length in ascending order - text_ignore_idx, text_keep_idx = self.sort_and_filter_by_length(self.text_lengths, self.min_text_len, self.max_text_len) - audio_ignore_idx, audio_keep_idx = self.sort_and_filter_by_length(self.audio_lengths, self.min_audio_len, self.max_audio_len) + text_ignore_idx, text_keep_idx = self.sort_and_filter_by_length( + self.text_lengths, self.min_text_len, self.max_text_len + ) + audio_ignore_idx, audio_keep_idx = self.sort_and_filter_by_length( + self.audio_lengths, self.min_audio_len, self.max_audio_len + ) keep_idx = list(set(audio_keep_idx) | set(text_keep_idx)) ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx)) diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py index f9c44a7d..24ce51f1 100644 --- a/TTS/tts/utils/text/characters.py +++ b/TTS/tts/utils/text/characters.py @@ -209,7 +209,7 @@ class BaseCharacters: ), f" [!] There are duplicate characters in the character set. {set([x for x in self.vocab if self.vocab.count(x) > 1])}" def char_to_id(self, char: str) -> int: - return self._char_to_id[char] + return self._char_to_id[char] def id_to_char(self, idx: int) -> str: return self._id_to_char[idx] @@ -298,8 +298,8 @@ class IPAPhonemes(BaseCharacters): ) else: return IPAPhonemes( - **config.characters if config.characters is not None else {}, - ) + **config.characters if config.characters is not None else {}, + ) class Graphemes(BaseCharacters): diff --git a/TTS/tts/utils/text/phonemizers/__init__.py b/TTS/tts/utils/text/phonemizers/__init__.py index b00f7f5e..0da5875e 100644 --- a/TTS/tts/utils/text/phonemizers/__init__.py +++ b/TTS/tts/utils/text/phonemizers/__init__.py @@ -52,4 +52,4 @@ def get_phonemizer_by_name(name: str, **kwargs) -> BasePhonemizer: if __name__ == "__main__": - print(DEF_LANG_TO_PHONEMIZER) \ No newline at end of file + print(DEF_LANG_TO_PHONEMIZER)