mirror of https://github.com/coqui-ai/TTS.git
Implement `start_by_longest` option for TTSDataset
parent
c4c471d61d
commit
ef63c99524
|
@ -172,6 +172,10 @@ class BaseTTSConfig(BaseTrainingConfig):
|
|||
use_noise_augment (bool):
|
||||
Augment the input audio with random noise.
|
||||
|
||||
start_by_longest (bool):
|
||||
If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues.
|
||||
Defaults to False.
|
||||
|
||||
add_blank (bool):
|
||||
Add blank characters between each other two characters. It improves performance for some models at expense
|
||||
of slower run-time due to the longer input sequence.
|
||||
|
@ -231,6 +235,7 @@ class BaseTTSConfig(BaseTrainingConfig):
|
|||
compute_linear_spec: bool = False
|
||||
precompute_num_workers: int = 0
|
||||
use_noise_augment: bool = False
|
||||
start_by_longest: bool = False
|
||||
# dataset
|
||||
datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
|
||||
# optimizer
|
||||
|
|
|
@ -67,15 +67,6 @@ class VitsConfig(BaseTTSConfig):
|
|||
compute_linear_spec (bool):
|
||||
If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
|
||||
|
||||
sort_by_audio_len (bool):
|
||||
If true, dataloader sorts the data by audio length else sorts by the input text length. Defaults to `True`.
|
||||
|
||||
min_seq_len (int):
|
||||
Minimum sequence length to be considered for training. Defaults to `0`.
|
||||
|
||||
max_seq_len (int):
|
||||
Maximum sequence length to be considered for training. Defaults to `500000`.
|
||||
|
||||
r (int):
|
||||
Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
|
||||
|
||||
|
@ -123,6 +114,7 @@ class VitsConfig(BaseTTSConfig):
|
|||
feat_loss_alpha: float = 1.0
|
||||
mel_loss_alpha: float = 45.0
|
||||
dur_loss_alpha: float = 1.0
|
||||
aligner_loss_alpha = 1.0
|
||||
speaker_encoder_loss_alpha: float = 1.0
|
||||
|
||||
# data loader params
|
||||
|
@ -130,9 +122,6 @@ class VitsConfig(BaseTTSConfig):
|
|||
compute_linear_spec: bool = True
|
||||
|
||||
# overrides
|
||||
sort_by_audio_len: bool = True
|
||||
min_seq_len: int = 0
|
||||
max_seq_len: int = 500000
|
||||
r: int = 1 # DO NOT CHANGE
|
||||
add_blank: bool = True
|
||||
|
||||
|
|
|
@ -56,6 +56,7 @@ class TTSDataset(Dataset):
|
|||
d_vector_mapping: Dict = None,
|
||||
language_id_mapping: Dict = None,
|
||||
use_noise_augment: bool = False,
|
||||
start_by_longest: bool = False,
|
||||
verbose: bool = False,
|
||||
):
|
||||
"""Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
|
||||
|
@ -109,6 +110,8 @@ class TTSDataset(Dataset):
|
|||
|
||||
use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.
|
||||
|
||||
start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False.
|
||||
|
||||
verbose (bool): Print diagnostic information. Defaults to false.
|
||||
"""
|
||||
super().__init__()
|
||||
|
@ -130,6 +133,7 @@ class TTSDataset(Dataset):
|
|||
self.d_vector_mapping = d_vector_mapping
|
||||
self.language_id_mapping = language_id_mapping
|
||||
self.use_noise_augment = use_noise_augment
|
||||
self.start_by_longest = start_by_longest
|
||||
|
||||
self.verbose = verbose
|
||||
self.rescue_item_idx = 1
|
||||
|
@ -315,6 +319,12 @@ class TTSDataset(Dataset):
|
|||
samples, audio_lengths, _ = self.select_samples_by_idx(keep_idx)
|
||||
|
||||
sorted_idxs = self.sort_by_length(audio_lengths)
|
||||
|
||||
if self.start_by_longest:
|
||||
longest_idxs = sorted_idxs[-1]
|
||||
sorted_idxs[-1] = sorted_idxs[0]
|
||||
sorted_idxs[0] = longest_idxs
|
||||
|
||||
samples, audio_lengths, text_lengtsh = self.select_samples_by_idx(sorted_idxs)
|
||||
|
||||
if len(samples) == 0:
|
||||
|
|
|
@ -290,6 +290,7 @@ class BaseTTS(BaseModel):
|
|||
speaker_id_mapping=speaker_id_mapping,
|
||||
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
|
||||
tokenizer=self.tokenizer,
|
||||
start_by_longest=config.start_by_longest,
|
||||
language_id_mapping=language_id_mapping,
|
||||
)
|
||||
|
||||
|
|
|
@ -63,6 +63,7 @@ class TestTTSDataset(unittest.TestCase):
|
|||
max_text_len=c.max_text_len,
|
||||
min_audio_len=c.min_audio_len,
|
||||
max_audio_len=c.max_audio_len,
|
||||
start_by_longest=start_by_longest
|
||||
)
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
|
@ -142,6 +143,23 @@ class TestTTSDataset(unittest.TestCase):
|
|||
self.assertGreaterEqual(avg_length, last_length)
|
||||
self.assertTrue(is_items_reordered)
|
||||
|
||||
def test_start_by_longest(self):
|
||||
"""Test start_by_longest option.
|
||||
|
||||
The first item of the first batch must be longer than all the other items.
|
||||
"""
|
||||
if ok_ljspeech:
|
||||
dataloader, _ = self._create_dataloader(2, c.r, 0, True)
|
||||
dataloader.dataset.preprocess_samples()
|
||||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
mel_lengths = data["mel_lengths"]
|
||||
if i == 0:
|
||||
max_len = mel_lengths[0]
|
||||
print(mel_lengths)
|
||||
self.assertTrue(all(max_len >= mel_lengths))
|
||||
|
||||
def test_padding_and_spectrograms(self):
|
||||
def check_conditions(idx, linear_input, mel_input, stop_target, mel_lengths):
|
||||
self.assertNotEqual(linear_input[idx, -1].sum(), 0) # check padding
|
||||
|
|
Loading…
Reference in New Issue