TTS/tests/data_tests/test_loader.py

216 lines
8.4 KiB
Python
Raw Normal View History

import os
2018-11-02 15:13:51 +00:00
import shutil
2020-08-04 12:07:47 +00:00
import unittest
2020-08-04 12:07:47 +00:00
import numpy as np
import torch
from torch.utils.data import DataLoader
2020-08-04 12:07:47 +00:00
2021-05-10 21:13:52 +00:00
from tests import get_tests_output_path
2021-05-10 21:03:21 +00:00
from TTS.tts.configs import BaseTTSConfig
2020-09-09 10:27:23 +00:00
from TTS.tts.datasets import TTSDataset
2021-05-31 08:07:12 +00:00
from TTS.tts.datasets.formatters import ljspeech
2020-09-09 10:27:23 +00:00
from TTS.utils.audio import AudioProcessor
2021-04-08 23:17:15 +00:00
# pylint: disable=unused-variable
2019-07-19 09:48:12 +00:00
2020-07-16 13:05:36 +00:00
# Directory that receives the audio files written by the loader tests.
OUTPATH = os.path.join(get_tests_output_path(), "loader_tests/")
os.makedirs(OUTPATH, exist_ok=True)

# create a dummy config for testing data loaders.
c = BaseTTSConfig(text_cleaner="english_cleaners", num_loader_workers=0, batch_size=2)
c.r = 5
c.data_path = "tests/data/ljspeech/"

# Both flags record whether the sample data directory is present on disk.
ok_ljspeech = os.path.exists(c.data_path)
DATA_EXIST = os.path.exists(c.data_path)

print(" > Dynamic data loader test: {}".format(DATA_EXIST))
2019-04-29 09:07:04 +00:00
2018-12-17 15:35:52 +00:00
2018-11-02 15:13:51 +00:00
class TestTTSDataset(unittest.TestCase):
    """Exercise ``TTSDataset`` and its ``collate_fn`` through a torch ``DataLoader``.

    Every test needs the sample data at ``c.data_path``. When the module-level
    ``ok_ljspeech`` flag is false the test is reported as skipped instead of
    silently passing.

    Batches are read by position: ``data[0]`` text input, ``data[2]`` speaker
    names, ``data[3]`` linear spectrograms, ``data[4]`` mel spectrograms,
    ``data[5]`` mel lengths, ``data[6]`` stop targets and ``data[7]`` the wav
    paths of the batch items.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Upper bound on the number of batches each test pulls from a loader.
        self.max_loader_iter = 4
        self.ap = AudioProcessor(**c.audio)

    def _skip_without_data(self):
        """Skip the running test when the sample data directory is missing."""
        if not ok_ljspeech:
            self.skipTest("test data not found at {}".format(c.data_path))

    def _create_dataloader(self, batch_size, r, bgs):
        """Build a dataset over the LJSpeech sample data and wrap it in a loader.

        Args:
            batch_size: Number of items per batch handed to ``DataLoader``.
            r: First positional argument of ``TTSDataset``
                (presumably the decoder reduction factor -- confirm against the dataset).
            bgs: Value passed as ``batch_group_size``.

        Returns:
            Tuple ``(dataloader, dataset)``.
        """
        items = ljspeech(c.data_path, "metadata.csv")
        dataset = TTSDataset.TTSDataset(
            r,
            c.text_cleaner,
            compute_linear_spec=True,
            ap=self.ap,
            meta_data=items,
            tp=c.characters,
            batch_group_size=bgs,
            min_seq_len=c.min_seq_len,
            max_seq_len=float("inf"),
            use_phonemes=False,
        )
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=dataset.collate_fn,
            drop_last=True,
            num_workers=c.num_loader_workers,
        )
        return dataloader, dataset

    def test_loader(self):
        """Check batch shapes, value ranges and types of the collated batches."""
        self._skip_without_data()
        dataloader, _ = self._create_dataloader(2, c.r, 0)

        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            text_input = data[0]
            speaker_name = data[2]
            linear_input = data[3]
            mel_input = data[4]

            neg_values = text_input[text_input < 0]
            check_count = len(neg_values)
            assert check_count == 0, " !! Negative values in text_input: {}".format(check_count)
            # TODO: more assertion here
            assert isinstance(speaker_name[0], str)
            assert linear_input.shape[0] == c.batch_size
            assert linear_input.shape[2] == self.ap.fft_size // 2 + 1
            assert mel_input.shape[0] == c.batch_size
            assert mel_input.shape[2] == c.audio["num_mels"]
            # check normalization ranges
            if self.ap.symmetric_norm:
                assert mel_input.max() <= self.ap.max_norm
                assert mel_input.min() >= -self.ap.max_norm  # pylint: disable=invalid-unary-operand-type
                assert mel_input.min() < 0
            else:
                assert mel_input.max() <= self.ap.max_norm
                assert mel_input.min() >= 0

    def test_batch_group_shuffle(self):
        """Check that ``sort_items`` changes the item order when batch groups are enabled."""
        self._skip_without_data()
        dataloader, dataset = self._create_dataloader(2, c.r, 16)
        # NOTE: ``last_length`` is never updated, so the check inside the loop
        # only requires a non-negative mean length per batch.
        last_length = 0
        frames = dataset.items

        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            mel_lengths = data[5]

            avg_length = mel_lengths.numpy().mean()
            assert avg_length >= last_length

        dataloader.dataset.sort_items()
        is_items_reordered = any(
            item != frames[idx] for idx, item in enumerate(dataloader.dataset.items)
        )
        assert is_items_reordered

    def test_padding_and_spec(self):
        """Check spectrogram consistency and the padding of collated batches."""
        self._skip_without_data()

        # Test for batch size 1
        dataloader, _ = self._create_dataloader(1, 1, 0)
        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            linear_input = data[3]
            mel_input = data[4]
            mel_lengths = data[5]
            stop_target = data[6]
            item_idx = data[7]

            # check mel_spec consistency
            wav = np.asarray(self.ap.load_wav(item_idx[0]), dtype=np.float32)
            mel = self.ap.melspectrogram(wav).astype("float32")
            mel = torch.FloatTensor(mel).contiguous()
            mel_dl = mel_input[0]
            # NOTE: Below needs to check == 0 but due to an unknown reason
            # there is a slight difference between two matrices.
            # TODO: Check this assert cond more in detail.
            max_diff = abs(mel.T - mel_dl).max()
            assert max_diff < 1e-5, max_diff

            # check mel-spec correctness
            mel_spec = mel_input[0].cpu().numpy()
            wav = self.ap.inv_melspectrogram(mel_spec.T)
            self.ap.save_wav(wav, os.path.join(OUTPATH, "mel_inv_dataloader.wav"))
            shutil.copy(item_idx[0], os.path.join(OUTPATH, "mel_target_dataloader.wav"))

            # check linear-spec
            linear_spec = linear_input[0].cpu().numpy()
            wav = self.ap.inv_spectrogram(linear_spec.T)
            self.ap.save_wav(wav, os.path.join(OUTPATH, "linear_inv_dataloader.wav"))
            shutil.copy(item_idx[0], os.path.join(OUTPATH, "linear_target_dataloader.wav"))

            # with a single item per batch the trailing frames are expected to
            # hold real (non-zero) data and the stop flag is set only on the
            # very last step
            assert linear_input[0, -1].sum() != 0
            assert linear_input[0, -2].sum() != 0
            assert mel_input[0, -1].sum() != 0
            assert mel_input[0, -2].sum() != 0
            assert stop_target[0, -1] == 1
            assert stop_target[0, -2] == 0
            assert stop_target.sum() == 1
            assert len(mel_lengths.shape) == 1
            assert mel_lengths[0] == linear_input[0].shape[0]
            assert mel_lengths[0] == mel_input[0].shape[0]

        # Test for batch size 2
        dataloader, _ = self._create_dataloader(2, 1, 0)
        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            linear_input = data[3]
            mel_input = data[4]
            mel_lengths = data[5]
            stop_target = data[6]

            # ``idx`` is the batch position of the longer item, ``other`` the
            # position of the item that had to be padded.
            idx = 0 if mel_lengths[0] > mel_lengths[1] else 1
            other = 1 - idx

            # check the longer item in the batch (no padding expected)
            assert linear_input[idx, -1].sum() != 0
            assert linear_input[idx, -2].sum() != 0, linear_input
            assert mel_input[idx, -1].sum() != 0
            assert mel_input[idx, -2].sum() != 0, mel_input
            assert stop_target[idx, -1] == 1
            assert stop_target[idx, -2] == 0
            assert stop_target[idx].sum() == 1
            assert len(mel_lengths.shape) == 1
            assert mel_lengths[idx] == mel_input[idx].shape[0]
            assert mel_lengths[idx] == linear_input[idx].shape[0]

            # check the other (padded) item in the batch; index it through
            # ``other`` so the checks hit the padded item regardless of its
            # position in the batch
            assert linear_input[other, -1].sum() == 0
            assert mel_input[other, -1].sum() == 0
            assert stop_target[other, mel_lengths[other] - 1] == 1
            assert stop_target[other, mel_lengths[other] :].sum() == 0
            assert len(mel_lengths.shape) == 1

            # check batch zero-frame conditions (zero-frame disabled)
            # assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
            # assert (mel_input * stop_target.unsqueeze(2)).sum() == 0