mirror of https://github.com/coqui-ai/TTS.git
Fix dataset preprocessing
parent 34c4be5e49
commit 1932401e8d
@@ -1,4 +1,5 @@
 import collections
+from email.mime import audio
 import os
 import random
 from typing import Dict, List, Union
@@ -140,8 +141,6 @@ class TTSDataset(Dataset):
         self.pitch_computed = False
         self.tokenizer = tokenizer

-        self.audio_lengths, self.text_lengths = self.compute_lengths(self.samples)
-
         if self.tokenizer.use_phonemes:
             self.phoneme_dataset = PhonemeDataset(
                 self.samples, self.tokenizer, phoneme_cache_path, precompute_num_workers=precompute_num_workers
@@ -253,16 +252,14 @@ class TTSDataset(Dataset):
         return sample

     @staticmethod
-    def compute_lengths(samples):
-        audio_lengths = []
-        text_lengths = []
+    def _compute_lengths(samples):
+        new_samples = []
         for item in samples:
             text, wav_file, *_ = _parse_sample(item)
-            audio_lengths.append(os.path.getsize(wav_file) / 16 * 8)  # assuming 16bit audio
-            text_lengths.append(len(text))
-        audio_lengths = np.array(audio_lengths)
-        text_lengths = np.array(text_lengths)
-        return audio_lengths, text_lengths
+            audio_length = os.path.getsize(wav_file) / 16 * 8  # assuming 16bit audio
+            text_lenght = len(text)
+            new_samples += [item + [audio_length, text_lenght]]
+        return new_samples

     @staticmethod
     def filter_by_length(lengths: List[int], min_len: int, max_len: int):
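The renamed `_compute_lengths` no longer builds separate `audio_lengths`/`text_lengths` arrays; it appends the two lengths to each sample row so they travel with the sample through filtering and sorting. The audio length is estimated from the WAV file size under a 16-bit assumption: bytes / 16 * 8 equals bytes / 2, i.e. roughly the number of 2-byte samples (the header is ignored). A minimal standalone sketch of the same idea, assuming samples are plain [text, wav_file, ...] lists; `annotate_lengths` is an illustrative name, not part of the library:

    import os
    from typing import List

    def annotate_lengths(samples: List[list]) -> List[list]:
        """Append [audio_length, text_length] to each [text, wav_file, ...] row."""
        annotated = []
        for item in samples:
            text, wav_file = item[0], item[1]
            # file size in bytes / 16 * 8 == bytes / 2 ~= number of 16-bit samples
            audio_length = os.path.getsize(wav_file) / 16 * 8
            text_length = len(text)
            annotated.append(item + [audio_length, text_length])
        return annotated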
@@ -278,8 +275,9 @@ class TTSDataset(Dataset):
         return ignore_idx, keep_idx

     @staticmethod
-    def sort_by_length(lengths: List[int]):
-        idxs = np.argsort(lengths)  # ascending order
+    def sort_by_length(samples: List[List]):
+        audio_lengths = [s[-2] for s in samples]
+        idxs = np.argsort(audio_lengths)  # ascending order
         return idxs

     @staticmethod
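With the lengths stored in each row, `sort_by_length` now receives the annotated samples and sorts on the second-to-last field (the audio length) instead of a separate lengths array. A small illustration with made-up rows, assuming the [text, wav_file, audio_length, text_length] layout sketched above:

    import numpy as np

    rows = [
        ["a much longer sentence", "b.wav", 96000, 22],
        ["short text", "a.wav", 24000, 10],
    ]
    audio_lengths = [r[-2] for r in rows]  # second-to-last field: audio length
    order = np.argsort(audio_lengths)      # ascending, shortest audio first
    print(order.tolist())                  # [1, 0]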
@@ -293,39 +291,38 @@ class TTSDataset(Dataset):
             samples[offset:end_offset] = temp_items
         return samples

-    def select_samples_by_idx(self, idxs):
-        samples = []
-        audio_lengths = []
-        text_lengths = []
+    def _select_samples_by_idx(self, idxs, samples):
+        samples_new = []
         for idx in idxs:
-            samples.append(self.samples[idx])
-            audio_lengths.append(self.audio_lengths[idx])
-            text_lengths.append(self.text_lengths[idx])
-        return samples, audio_lengths, text_lengths
+            samples_new.append(samples[idx])
+        return samples_new

     def preprocess_samples(self):
         r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
         range.
         """
+        samples = self._compute_lengths(self.samples)

         # sort items based on the sequence length in ascending order
-        text_ignore_idx, text_keep_idx = self.filter_by_length(self.text_lengths, self.min_text_len, self.max_text_len)
+        text_lengths = [i[-1] for i in samples]
+        audio_lengths = [i[-2] for i in samples]
+        text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len)
         audio_ignore_idx, audio_keep_idx = self.filter_by_length(
-            self.audio_lengths, self.min_audio_len, self.max_audio_len
+            audio_lengths, self.min_audio_len, self.max_audio_len
         )
-        keep_idx = list(set(audio_keep_idx) | set(text_keep_idx))
+        keep_idx = list(set(audio_keep_idx) & set(text_keep_idx))
         ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx))

-        samples, audio_lengths, _ = self.select_samples_by_idx(keep_idx)
+        samples = self._select_samples_by_idx(keep_idx, samples)

-        sorted_idxs = self.sort_by_length(audio_lengths)
+        sorted_idxs = self.sort_by_length(samples)

         if self.start_by_longest:
             longest_idxs = sorted_idxs[-1]
             sorted_idxs[-1] = sorted_idxs[0]
             sorted_idxs[0] = longest_idxs

-        samples, audio_lengths, text_lengtsh = self.select_samples_by_idx(sorted_idxs)
+        samples = self._select_samples_by_idx(sorted_idxs, samples)

         if len(samples) == 0:
             raise RuntimeError(" [!] No samples left")
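A key change in `preprocess_samples` is that `keep_idx` switches from a set union to a set intersection: a sample is now kept only when it passes both the text-length and the audio-length filters, rather than when it passes either one. A tiny worked example with made-up index sets:

    # hypothetical keep indices returned by filter_by_length
    audio_keep_idx = {0, 1, 2}  # audio length within bounds
    text_keep_idx = {1, 2, 3}   # text length within bounds

    print(sorted(audio_keep_idx | text_keep_idx))  # [0, 1, 2, 3] -> old behaviour, keeps partially invalid samples
    print(sorted(audio_keep_idx & text_keep_idx))  # [1, 2]       -> new behaviour, keeps only fully valid samples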
@@ -337,19 +334,19 @@ class TTSDataset(Dataset):
             samples = self.create_buckets(samples, self.batch_group_size)

         # update items to the new sorted items
-        self.samples = samples
-        self.audio_lengths = audio_lengths
-        self.text_lengths = text_lengtsh
+        audio_lengths = [s[-2] for s in samples]
+        text_lengths = [s[-1] for s in samples]
+        self.samples = [s[:-2] for s in samples]

         if self.verbose:
             print(" | > Preprocessing samples")
-            print(" | > Max text length: {}".format(np.max(self.text_lengths)))
-            print(" | > Min text length: {}".format(np.min(self.text_lengths)))
-            print(" | > Avg text length: {}".format(np.mean(self.text_lengths)))
+            print(" | > Max text length: {}".format(np.max(text_lengths)))
+            print(" | > Min text length: {}".format(np.min(text_lengths)))
+            print(" | > Avg text length: {}".format(np.mean(text_lengths)))
             print(" | ")
-            print(" | > Max audio length: {}".format(np.max(self.audio_lengths)))
-            print(" | > Min audio length: {}".format(np.min(self.audio_lengths)))
-            print(" | > Avg audio length: {}".format(np.mean(self.audio_lengths)))
+            print(" | > Max audio length: {}".format(np.max(audio_lengths)))
+            print(" | > Min audio length: {}".format(np.min(audio_lengths)))
+            print(" | > Avg audio length: {}".format(np.mean(audio_lengths)))
             print(f" | > Num. instances discarded samples: {len(ignore_idx)}")
             print(" | > Batch group size: {}.".format(self.batch_group_size))