Add option to sort input sequnce by audio len

pull/725/head
Eren Gölge 2021-08-30 08:02:59 +00:00
parent 695a6439d3
commit f186856e5d
8 changed files with 29 additions and 12 deletions

View File

@ -97,7 +97,7 @@ Example run:
enable_eos_bos=C.enable_eos_bos_chars,
)
dataset.sort_items()
dataset.sort_and_filter_items(C.get("sort_by_audio_len", default=False))
loader = DataLoader(
dataset,
batch_size=args.batch_size,

View File

@ -46,7 +46,7 @@ def setup_loader(ap, r, verbose=False):
if c.use_phonemes and c.compute_input_seq_cache:
# precompute phonemes to have a better estimate of sequence lengths.
dataset.compute_input_seq(c.num_loader_workers)
dataset.sort_items()
dataset.sort_and_filter_items(c.get("sort_by_audio_len", default=False))
loader = DataLoader(
dataset,

View File

@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
import importlib
import logging
import multiprocessing
import os
import platform

View File

@ -120,8 +120,9 @@ class VitsConfig(BaseTTSConfig):
compute_linear_spec: bool = True
# overrides
min_seq_len: int = 32
max_seq_len: int = 1000
sort_by_audio_len: bool = True
min_seq_len: int = 0
max_seq_len: int = 500000
r: int = 1 # DO NOT CHANGE
add_blank: bool = True

View File

@ -69,7 +69,7 @@ class TTSDataset(Dataset):
batch. Set 0 to disable. Defaults to 0.
min_seq_len (int): Minimum input sequence length to be processed
by the loader. Filter out input sequences that are shorter than this. Some models have a
by sort_inputs`. Filter out input sequences that are shorter than this. Some models have a
minimum input length due to its architecture. Defaults to 0.
max_seq_len (int): Maximum input sequence length. Filter out input sequences that are longer than this.
@ -302,10 +302,23 @@ class TTSDataset(Dataset):
for idx, p in enumerate(phonemes):
self.items[idx][0] = p
def sort_items(self):
r"""Sort instances based on text length in ascending order"""
lengths = np.array([len(ins[0]) for ins in self.items])
def sort_and_filter_items(self, by_audio_len=False):
r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
range.
Args:
by_audio_len (bool): if True, sort by audio length else by text length.
"""
# compute the target sequence length
if by_audio_len:
lengths = []
for item in self.items:
lengths.append(os.path.getsize(item[1]))
lengths = np.array(lengths)
else:
lengths = np.array([len(ins[0]) for ins in self.items])
# sort items based on the sequence length in ascending order
idxs = np.argsort(lengths)
new_items = []
ignored = []
@ -315,7 +328,10 @@ class TTSDataset(Dataset):
ignored.append(idx)
else:
new_items.append(self.items[idx])
# shuffle batch groups
# create batches with similar length items
# the larger the `batch_group_size`, the higher the length variety in a batch.
if self.batch_group_size > 0:
for i in range(len(new_items) // self.batch_group_size):
offset = i * self.batch_group_size
@ -325,6 +341,7 @@ class TTSDataset(Dataset):
new_items[offset:end_offset] = temp_items
self.items = new_items
# logging
if self.verbose:
print(" | > Max length sequence: {}".format(np.max(lengths)))
print(" | > Min length sequence: {}".format(np.min(lengths)))

View File

@ -243,7 +243,7 @@ class BaseTTS(BaseModel):
dist.barrier()
# sort input sequences from short to long
dataset.sort_items()
dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False))
# sampler for DDP
sampler = DistributedSampler(dataset) if num_gpus > 1 else None

View File

@ -43,7 +43,7 @@ config = VitsConfig(
print_step=25,
print_eval=True,
mixed_precision=True,
max_seq_len=5000,
max_seq_len=500000,
output_path=output_path,
datasets=[dataset_config],
)

View File

@ -124,7 +124,7 @@ class TestTTSDataset(unittest.TestCase):
avg_length = mel_lengths.numpy().mean()
assert avg_length >= last_length
dataloader.dataset.sort_items()
dataloader.dataset.sort_and_filter_items()
is_items_reordered = False
for idx, item in enumerate(dataloader.dataset.items):
if item != frames[idx]: