Implement `start_by_longest` option for TTSDataset

pull/1324/head
Eren Gölge 2022-01-21 15:29:06 +00:00
parent c4c471d61d
commit ef63c99524
5 changed files with 35 additions and 12 deletions

View File

@ -172,6 +172,10 @@ class BaseTTSConfig(BaseTrainingConfig):
use_noise_augment (bool):
Augment the input audio with random noise.
start_by_longest (bool):
If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues.
Defaults to False.
add_blank (bool):
Add blank characters between each other two characters. It improves performance for some models at expense
of slower run-time due to the longer input sequence.
@ -231,6 +235,7 @@ class BaseTTSConfig(BaseTrainingConfig):
compute_linear_spec: bool = False
precompute_num_workers: int = 0
use_noise_augment: bool = False
start_by_longest: bool = False
# dataset
datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
# optimizer

View File

@ -67,15 +67,6 @@ class VitsConfig(BaseTTSConfig):
compute_linear_spec (bool):
If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
sort_by_audio_len (bool):
If true, dataloader sorts the data by audio length else sorts by the input text length. Defaults to `True`.
min_seq_len (int):
Minimum sequence length to be considered for training. Defaults to `0`.
max_seq_len (int):
Maximum sequence length to be considered for training. Defaults to `500000`.
r (int):
Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
@ -123,6 +114,7 @@ class VitsConfig(BaseTTSConfig):
feat_loss_alpha: float = 1.0
mel_loss_alpha: float = 45.0
dur_loss_alpha: float = 1.0
aligner_loss_alpha = 1.0
speaker_encoder_loss_alpha: float = 1.0
# data loader params
@ -130,9 +122,6 @@ class VitsConfig(BaseTTSConfig):
compute_linear_spec: bool = True
# overrides
sort_by_audio_len: bool = True
min_seq_len: int = 0
max_seq_len: int = 500000
r: int = 1 # DO NOT CHANGE
add_blank: bool = True

View File

@ -56,6 +56,7 @@ class TTSDataset(Dataset):
d_vector_mapping: Dict = None,
language_id_mapping: Dict = None,
use_noise_augment: bool = False,
start_by_longest: bool = False,
verbose: bool = False,
):
"""Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
@ -109,6 +110,8 @@ class TTSDataset(Dataset):
use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.
start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False.
verbose (bool): Print diagnostic information. Defaults to false.
"""
super().__init__()
@ -130,6 +133,7 @@ class TTSDataset(Dataset):
self.d_vector_mapping = d_vector_mapping
self.language_id_mapping = language_id_mapping
self.use_noise_augment = use_noise_augment
self.start_by_longest = start_by_longest
self.verbose = verbose
self.rescue_item_idx = 1
@ -315,6 +319,12 @@ class TTSDataset(Dataset):
samples, audio_lengths, _ = self.select_samples_by_idx(keep_idx)
sorted_idxs = self.sort_by_length(audio_lengths)
if self.start_by_longest:
longest_idxs = sorted_idxs[-1]
sorted_idxs[-1] = sorted_idxs[0]
sorted_idxs[0] = longest_idxs
samples, audio_lengths, text_lengtsh = self.select_samples_by_idx(sorted_idxs)
if len(samples) == 0:

View File

@ -290,6 +290,7 @@ class BaseTTS(BaseModel):
speaker_id_mapping=speaker_id_mapping,
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
tokenizer=self.tokenizer,
start_by_longest=config.start_by_longest,
language_id_mapping=language_id_mapping,
)

View File

@ -63,6 +63,7 @@ class TestTTSDataset(unittest.TestCase):
max_text_len=c.max_text_len,
min_audio_len=c.min_audio_len,
max_audio_len=c.max_audio_len,
start_by_longest=start_by_longest
)
dataloader = DataLoader(
dataset,
@ -142,6 +143,23 @@ class TestTTSDataset(unittest.TestCase):
self.assertGreaterEqual(avg_length, last_length)
self.assertTrue(is_items_reordered)
def test_start_by_longest(self):
    """Test the `start_by_longest` dataset option.

    The first item of the first batch must be at least as long as every
    other item seen, since `start_by_longest` swaps the longest sample
    to the front of the (otherwise length-sorted) sample list.
    """
    if ok_ljspeech:
        # batch_size=2, r from config, num_workers=0, start_by_longest=True
        dataloader, _ = self._create_dataloader(2, c.r, 0, True)
        dataloader.dataset.preprocess_samples()
        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            mel_lengths = data["mel_lengths"]
            if i == 0:
                # length of the very first item — must dominate all later items
                max_len = mel_lengths[0]
            self.assertTrue(all(max_len >= mel_lengths))
def test_padding_and_spectrograms(self):
def check_conditions(idx, linear_input, mel_input, stop_target, mel_lengths):
self.assertNotEqual(linear_input[idx, -1].sum(), 0) # check padding