Fix dataset preprocessing

pull/1324/head
Eren Gölge 2022-01-25 09:28:48 +00:00
parent 34c4be5e49
commit 1932401e8d
1 changed file with 32 additions and 35 deletions

View File

@@ -1,4 +1,5 @@
import collections
from email.mime import audio
import os
import random
from typing import Dict, List, Union
@@ -140,8 +141,6 @@ class TTSDataset(Dataset):
self.pitch_computed = False
self.tokenizer = tokenizer
self.audio_lengths, self.text_lengths = self.compute_lengths(self.samples)
if self.tokenizer.use_phonemes:
self.phoneme_dataset = PhonemeDataset(
self.samples, self.tokenizer, phoneme_cache_path, precompute_num_workers=precompute_num_workers
@@ -253,16 +252,14 @@ class TTSDataset(Dataset):
return sample
@staticmethod
def compute_lengths(samples):
audio_lengths = []
text_lengths = []
def _compute_lengths(samples):
new_samples = []
for item in samples:
text, wav_file, *_ = _parse_sample(item)
audio_lengths.append(os.path.getsize(wav_file) / 16 * 8) # assuming 16bit audio
text_lengths.append(len(text))
audio_lengths = np.array(audio_lengths)
text_lengths = np.array(text_lengths)
return audio_lengths, text_lengths
audio_length = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio
text_lenght = len(text)
new_samples += [item + [audio_length, text_lenght]]
return new_samples
@staticmethod
def filter_by_length(lengths: List[int], min_len: int, max_len: int):
@@ -278,8 +275,9 @@ class TTSDataset(Dataset):
return ignore_idx, keep_idx
@staticmethod
def sort_by_length(lengths: List[int]):
idxs = np.argsort(lengths) # ascending order
def sort_by_length(samples: List[List]):
audio_lengths = [s[-2] for s in samples]
idxs = np.argsort(audio_lengths) # ascending order
return idxs
@staticmethod
@@ -293,39 +291,38 @@ class TTSDataset(Dataset):
samples[offset:end_offset] = temp_items
return samples
def select_samples_by_idx(self, idxs):
samples = []
audio_lengths = []
text_lengths = []
def _select_samples_by_idx(self, idxs, samples):
samples_new = []
for idx in idxs:
samples.append(self.samples[idx])
audio_lengths.append(self.audio_lengths[idx])
text_lengths.append(self.text_lengths[idx])
return samples, audio_lengths, text_lengths
samples_new.append(samples[idx])
return samples_new
def preprocess_samples(self):
r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
range.
"""
samples = self._compute_lengths(self.samples)
# sort items based on the sequence length in ascending order
text_ignore_idx, text_keep_idx = self.filter_by_length(self.text_lengths, self.min_text_len, self.max_text_len)
text_lengths = [i[-1] for i in samples]
audio_lengths = [i[-2] for i in samples]
text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len)
audio_ignore_idx, audio_keep_idx = self.filter_by_length(
self.audio_lengths, self.min_audio_len, self.max_audio_len
audio_lengths, self.min_audio_len, self.max_audio_len
)
keep_idx = list(set(audio_keep_idx) | set(text_keep_idx))
keep_idx = list(set(audio_keep_idx) & set(text_keep_idx))
ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx))
samples, audio_lengths, _ = self.select_samples_by_idx(keep_idx)
samples = self._select_samples_by_idx(keep_idx, samples)
sorted_idxs = self.sort_by_length(audio_lengths)
sorted_idxs = self.sort_by_length(samples)
if self.start_by_longest:
longest_idxs = sorted_idxs[-1]
sorted_idxs[-1] = sorted_idxs[0]
sorted_idxs[0] = longest_idxs
samples, audio_lengths, text_lengtsh = self.select_samples_by_idx(sorted_idxs)
samples = self._select_samples_by_idx(sorted_idxs, samples)
if len(samples) == 0:
raise RuntimeError(" [!] No samples left")
@@ -337,19 +334,19 @@ class TTSDataset(Dataset):
samples = self.create_buckets(samples, self.batch_group_size)
# update items to the new sorted items
self.samples = samples
self.audio_lengths = audio_lengths
self.text_lengths = text_lengtsh
audio_lengths = [s[-2] for s in samples]
text_lengths = [s[-1] for s in samples]
self.samples = [s[:-2] for s in samples]
if self.verbose:
print(" | > Preprocessing samples")
print(" | > Max text length: {}".format(np.max(self.text_lengths)))
print(" | > Min text length: {}".format(np.min(self.text_lengths)))
print(" | > Avg text length: {}".format(np.mean(self.text_lengths)))
print(" | > Max text length: {}".format(np.max(text_lengths)))
print(" | > Min text length: {}".format(np.min(text_lengths)))
print(" | > Avg text length: {}".format(np.mean(text_lengths)))
print(" | ")
print(" | > Max audio length: {}".format(np.max(self.audio_lengths)))
print(" | > Min audio length: {}".format(np.min(self.audio_lengths)))
print(" | > Avg audio length: {}".format(np.mean(self.audio_lengths)))
print(" | > Max audio length: {}".format(np.max(audio_lengths)))
print(" | > Min audio length: {}".format(np.min(audio_lengths)))
print(" | > Avg audio length: {}".format(np.mean(audio_lengths)))
print(f" | > Num. instances discarded samples: {len(ignore_idx)}")
print(" | > Batch group size: {}.".format(self.batch_group_size))