From 89dded8964492d1d392c0378955557c914ae2f0e Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Tue, 17 Apr 2018 09:56:31 -0700 Subject: [PATCH] Add TWEB data loader tests --- tests/loader_tests.py | 146 +++++++++++++++++++++++++++++++++++++++-- tests/test_config.json | 5 +- 2 files changed, 143 insertions(+), 8 deletions(-) diff --git a/tests/loader_tests.py b/tests/loader_tests.py index 678b243a..76d82557 100644 --- a/tests/loader_tests.py +++ b/tests/loader_tests.py @@ -5,21 +5,22 @@ import numpy as np from torch.utils.data import DataLoader from TTS.utils.generic_utils import load_config from TTS.datasets.LJSpeech import LJSpeechDataset +from TTS.datasets.TWEB import TWEBDataset file_path = os.path.dirname(os.path.realpath(__file__)) c = load_config(os.path.join(file_path, 'test_config.json')) -class TestDataset(unittest.TestCase): +class TestLJSpeechDataset(unittest.TestCase): def __init__(self, *args, **kwargs): - super(TestDataset, self).__init__(*args, **kwargs) + super(TestLJSpeechDataset, self).__init__(*args, **kwargs) self.max_loader_iter = 4 def test_loader(self): - dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'), - os.path.join(c.data_path, 'wavs'), + dataset = LJSpeechDataset(os.path.join(c.data_path_LJSpeech, 'metadata.csv'), + os.path.join(c.data_path_LJSpeech, 'wavs'), c.r, c.sample_rate, c.text_cleaner, @@ -58,8 +59,8 @@ class TestDataset(unittest.TestCase): assert mel_input.shape[2] == c.num_mels def test_padding(self): - dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'), - os.path.join(c.data_path, 'wavs'), + dataset = LJSpeechDataset(os.path.join(c.data_path_LJSpeech, 'metadata.csv'), + os.path.join(c.data_path_LJSpeech, 'wavs'), 1, c.sample_rate, c.text_cleaner, @@ -141,3 +142,136 @@ class TestDataset(unittest.TestCase): # check batch conditions assert (mel_input * stop_target.unsqueeze(2)).sum() == 0 assert (linear_input * stop_target.unsqueeze(2)).sum() == 0 + + +class TestTWEBDataset(unittest.TestCase): + + def __init__(self, *args, **kwargs): + super(TestTWEBDataset, self).__init__(*args, **kwargs) + self.max_loader_iter = 4 + + def test_loader(self): + dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'), + os.path.join(c.data_path_TWEB, 'wavs'), + c.r, + c.sample_rate, + c.text_cleaner, + c.num_mels, + c.min_level_db, + c.frame_shift_ms, + c.frame_length_ms, + c.preemphasis, + c.ref_level_db, + c.num_freq, + c.power + ) + + dataloader = DataLoader(dataset, batch_size=2, + shuffle=True, collate_fn=dataset.collate_fn, + drop_last=True, num_workers=c.num_loader_workers) + + for i, data in enumerate(dataloader): + if i == self.max_loader_iter: + break + text_input = data[0] + text_lengths = data[1] + linear_input = data[2] + mel_input = data[3] + mel_lengths = data[4] + stop_target = data[5] + item_idx = data[6] + + neg_values = text_input[text_input < 0] + check_count = len(neg_values) + assert check_count == 0, \ + " !! Negative values in text_input: {}".format(check_count) + # TODO: more assertion here + assert linear_input.shape[0] == c.batch_size + assert mel_input.shape[0] == c.batch_size + assert mel_input.shape[2] == c.num_mels + + def test_padding(self): + dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'), + os.path.join(c.data_path_TWEB, 'wavs'), + 1, + c.sample_rate, + c.text_cleaner, + c.num_mels, + c.min_level_db, + c.frame_shift_ms, + c.frame_length_ms, + c.preemphasis, + c.ref_level_db, + c.num_freq, + c.power + ) + + # Test for batch size 1 + dataloader = DataLoader(dataset, batch_size=1, + shuffle=False, collate_fn=dataset.collate_fn, + drop_last=False, num_workers=c.num_loader_workers) + + for i, data in enumerate(dataloader): + if i == self.max_loader_iter: + break + + text_input = data[0] + text_lengths = data[1] + linear_input = data[2] + mel_input = data[3] + mel_lengths = data[4] + stop_target = data[5] + item_idx = data[6] + + # check the last time step to be zero padded + assert mel_input[0, -1].sum() == 0 + assert mel_input[0, -2].sum() != 0, "{} -- {}".format(item_idx, i) + assert linear_input[0, -1].sum() == 0 + assert linear_input[0, -2].sum() != 0 + assert stop_target[0, -1] == 1 + assert stop_target[0, -2] == 0 + assert stop_target.sum() == 1 + assert len(mel_lengths.shape) == 1 + assert mel_lengths[0] == mel_input[0].shape[0] + + # Test for batch size 2 + dataloader = DataLoader(dataset, batch_size=2, + shuffle=False, collate_fn=dataset.collate_fn, + drop_last=False, num_workers=c.num_loader_workers) + + for i, data in enumerate(dataloader): + if i == self.max_loader_iter: + break + text_input = data[0] + text_lengths = data[1] + linear_input = data[2] + mel_input = data[3] + mel_lengths = data[4] + stop_target = data[5] + item_idx = data[6] + + if mel_lengths[0] > mel_lengths[1]: + idx = 0 + else: + idx = 1 + + # check the first item in the batch + assert mel_input[idx, -1].sum() == 0 + assert mel_input[idx, -2].sum() != 0, mel_input + assert linear_input[idx, -1].sum() == 0 + assert linear_input[idx, -2].sum() != 0 + assert stop_target[idx, -1] == 1 + assert stop_target[idx, -2] == 0 + assert stop_target[idx].sum() == 1 + assert len(mel_lengths.shape) == 1 + assert mel_lengths[idx] == mel_input[idx].shape[0] + + # check the second itme in the batch + assert mel_input[1-idx, -1].sum() == 0 + assert linear_input[1-idx, -1].sum() == 0 + assert stop_target[1-idx, -1] == 1 + assert len(mel_lengths.shape) == 1 + + # check batch conditions + assert (mel_input * stop_target.unsqueeze(2)).sum() == 0 + assert (linear_input * stop_target.unsqueeze(2)).sum() == 0 diff --git a/tests/test_config.json b/tests/test_config.json index 420f16d3..2c2be17e 100644 --- a/tests/test_config.json +++ b/tests/test_config.json @@ -1,7 +1,7 @@ { "num_mels": 80, "num_freq": 1025, - "sample_rate": 20000, + "sample_rate": 22050, "frame_length_ms": 50, "frame_shift_ms": 12.5, "preemphasis": 0.97, @@ -24,7 +24,8 @@ "num_loader_workers": 4, "save_step": 200, - "data_path": "/data/shared/KeithIto/LJSpeech-1.0", + "data_path_LJSpeech": "/data/shared/KeithIto/LJSpeech-1.0", + "data_path_TWEB": "/data/shared/BibleSpeech", "output_path": "result", "log_dir": "/home/erogol/projects/TTS/logs/" }