mirror of https://github.com/coqui-ai/TTS.git
Add option to sort input sequnce by audio len
parent
695a6439d3
commit
f186856e5d
|
@ -97,7 +97,7 @@ Example run:
|
|||
enable_eos_bos=C.enable_eos_bos_chars,
|
||||
)
|
||||
|
||||
dataset.sort_items()
|
||||
dataset.sort_and_filter_items(C.get("sort_by_audio_len", default=False))
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
|
|
|
@ -46,7 +46,7 @@ def setup_loader(ap, r, verbose=False):
|
|||
if c.use_phonemes and c.compute_input_seq_cache:
|
||||
# precompute phonemes to have a better estimate of sequence lengths.
|
||||
dataset.compute_input_seq(c.num_loader_workers)
|
||||
dataset.sort_items()
|
||||
dataset.sort_and_filter_items(c.get("sort_by_audio_len", default=False))
|
||||
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import platform
|
||||
|
|
|
@ -120,8 +120,9 @@ class VitsConfig(BaseTTSConfig):
|
|||
compute_linear_spec: bool = True
|
||||
|
||||
# overrides
|
||||
min_seq_len: int = 32
|
||||
max_seq_len: int = 1000
|
||||
sort_by_audio_len: bool = True
|
||||
min_seq_len: int = 0
|
||||
max_seq_len: int = 500000
|
||||
r: int = 1 # DO NOT CHANGE
|
||||
add_blank: bool = True
|
||||
|
||||
|
|
|
@ -69,7 +69,7 @@ class TTSDataset(Dataset):
|
|||
batch. Set 0 to disable. Defaults to 0.
|
||||
|
||||
min_seq_len (int): Minimum input sequence length to be processed
|
||||
by the loader. Filter out input sequences that are shorter than this. Some models have a
|
||||
by sort_inputs`. Filter out input sequences that are shorter than this. Some models have a
|
||||
minimum input length due to its architecture. Defaults to 0.
|
||||
|
||||
max_seq_len (int): Maximum input sequence length. Filter out input sequences that are longer than this.
|
||||
|
@ -302,10 +302,23 @@ class TTSDataset(Dataset):
|
|||
for idx, p in enumerate(phonemes):
|
||||
self.items[idx][0] = p
|
||||
|
||||
def sort_items(self):
|
||||
r"""Sort instances based on text length in ascending order"""
|
||||
lengths = np.array([len(ins[0]) for ins in self.items])
|
||||
def sort_and_filter_items(self, by_audio_len=False):
|
||||
r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
|
||||
range.
|
||||
|
||||
Args:
|
||||
by_audio_len (bool): if True, sort by audio length else by text length.
|
||||
"""
|
||||
# compute the target sequence length
|
||||
if by_audio_len:
|
||||
lengths = []
|
||||
for item in self.items:
|
||||
lengths.append(os.path.getsize(item[1]))
|
||||
lengths = np.array(lengths)
|
||||
else:
|
||||
lengths = np.array([len(ins[0]) for ins in self.items])
|
||||
|
||||
# sort items based on the sequence length in ascending order
|
||||
idxs = np.argsort(lengths)
|
||||
new_items = []
|
||||
ignored = []
|
||||
|
@ -315,7 +328,10 @@ class TTSDataset(Dataset):
|
|||
ignored.append(idx)
|
||||
else:
|
||||
new_items.append(self.items[idx])
|
||||
|
||||
# shuffle batch groups
|
||||
# create batches with similar length items
|
||||
# the larger the `batch_group_size`, the higher the length variety in a batch.
|
||||
if self.batch_group_size > 0:
|
||||
for i in range(len(new_items) // self.batch_group_size):
|
||||
offset = i * self.batch_group_size
|
||||
|
@ -325,6 +341,7 @@ class TTSDataset(Dataset):
|
|||
new_items[offset:end_offset] = temp_items
|
||||
self.items = new_items
|
||||
|
||||
# logging
|
||||
if self.verbose:
|
||||
print(" | > Max length sequence: {}".format(np.max(lengths)))
|
||||
print(" | > Min length sequence: {}".format(np.min(lengths)))
|
||||
|
|
|
@ -243,7 +243,7 @@ class BaseTTS(BaseModel):
|
|||
dist.barrier()
|
||||
|
||||
# sort input sequences from short to long
|
||||
dataset.sort_items()
|
||||
dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False))
|
||||
|
||||
# sampler for DDP
|
||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
|
|
|
@ -43,7 +43,7 @@ config = VitsConfig(
|
|||
print_step=25,
|
||||
print_eval=True,
|
||||
mixed_precision=True,
|
||||
max_seq_len=5000,
|
||||
max_seq_len=500000,
|
||||
output_path=output_path,
|
||||
datasets=[dataset_config],
|
||||
)
|
||||
|
|
|
@ -124,7 +124,7 @@ class TestTTSDataset(unittest.TestCase):
|
|||
|
||||
avg_length = mel_lengths.numpy().mean()
|
||||
assert avg_length >= last_length
|
||||
dataloader.dataset.sort_items()
|
||||
dataloader.dataset.sort_and_filter_items()
|
||||
is_items_reordered = False
|
||||
for idx, item in enumerate(dataloader.dataset.items):
|
||||
if item != frames[idx]:
|
||||
|
|
Loading…
Reference in New Issue