import os
import unittest

import numpy as np
from torch.utils.data import DataLoader

from TTS.utils.generic_utils import load_config
from TTS.datasets.LJSpeech import LJSpeechDataset

file_path = os.path.dirname(os.path.realpath(__file__))
c = load_config(os.path.join(file_path, 'test_config.json'))


class TestLJSpeechDataset(unittest.TestCase):
    """Sanity checks for batches produced by ``LJSpeechDataset.collate_fn``.

    Data is read from ``c.data_path_LJSpeech`` (``metadata.csv`` and
    ``wavs``) as configured in ``test_config.json``. Each DataLoader loop
    stops after ``max_loader_iter`` batches to keep the run short.
    """

    def __init__(self, *args, **kwargs):
        super(TestLJSpeechDataset, self).__init__(*args, **kwargs)
        # Number of batches inspected per DataLoader before breaking out.
        self.max_loader_iter = 4

    @staticmethod
    def _make_dataset(r):
        """Build an ``LJSpeechDataset`` from the test config.

        Args:
            r: value passed in the third positional slot. ``test_loader``
                passes ``c.r`` and ``test_padding`` passes 1. Every other
                argument comes from the loaded test config.
        """
        return LJSpeechDataset(
            os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
            os.path.join(c.data_path_LJSpeech, 'wavs'),
            r,
            c.sample_rate,
            c.text_cleaner,
            c.num_mels,
            c.min_level_db,
            c.frame_shift_ms,
            c.frame_length_ms,
            c.preemphasis,
            c.ref_level_db,
            c.num_freq,
            c.power,
        )

    @staticmethod
    def _make_loader(dataset, batch_size, shuffle, drop_last):
        """Wrap ``dataset`` in a DataLoader that uses its own ``collate_fn``."""
        return DataLoader(dataset, batch_size=batch_size,
                          shuffle=shuffle, collate_fn=dataset.collate_fn,
                          drop_last=drop_last,
                          num_workers=c.num_loader_workers)

    def test_loader(self):
        """Check batch shapes and that encoded text has no negative ids."""
        batch_size = 2
        dataset = self._make_dataset(c.r)
        dataloader = self._make_loader(dataset, batch_size,
                                       shuffle=True, drop_last=True)
        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            text_input = data[0]
            linear_input = data[2]
            mel_input = data[3]

            neg_values = text_input[text_input < 0]
            check_count = len(neg_values)
            assert check_count == 0, \
                " !! Negative values in text_input: {}".format(check_count)
            # TODO: more assertion here
            # The loader was built with ``batch_size`` and drop_last=True, so
            # every batch has exactly that many items. Comparing against
            # ``c.batch_size`` instead would fail whenever the config differs.
            assert linear_input.shape[0] == batch_size
            assert mel_input.shape[0] == batch_size
            assert mel_input.shape[2] == c.num_mels

    def test_padding(self):
        """Check spectrogram zero padding and the matching stop targets."""
        dataset = self._make_dataset(1)

        # Test for batch size 1
        dataloader = self._make_loader(dataset, 1,
                                       shuffle=False, drop_last=True)
        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            linear_input = data[2]
            mel_input = data[3]
            mel_lengths = data[4]
            stop_target = data[5]

            # check the last time step to be zero padded
            assert mel_input[0, -1].sum() == 0
            assert mel_input[0, -2].sum() != 0
            assert linear_input[0, -1].sum() == 0
            assert linear_input[0, -2].sum() != 0
            assert stop_target[0, -1] == 1
            assert stop_target[0, -2] == 0
            assert stop_target.sum() == 1
            assert len(mel_lengths.shape) == 1
            assert mel_lengths[0] == mel_input[0].shape[0]

        # Test for batch size 2
        # drop_last=True: a trailing single-item batch would make
        # ``mel_lengths[1]`` below raise IndexError.
        dataloader = self._make_loader(dataset, 2,
                                       shuffle=False, drop_last=True)
        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            linear_input = data[2]
            mel_input = data[3]
            mel_lengths = data[4]
            stop_target = data[5]

            # Index of the longer item (ties pick index 1); its length
            # should equal the padded time dimension of the batch.
            idx = 0 if mel_lengths[0] > mel_lengths[1] else 1

            # check the longer item in the batch
            assert mel_input[idx, -1].sum() == 0
            assert mel_input[idx, -2].sum() != 0, mel_input
            assert linear_input[idx, -1].sum() == 0
            assert linear_input[idx, -2].sum() != 0
            assert stop_target[idx, -1] == 1
            assert stop_target[idx, -2] == 0
            assert stop_target[idx].sum() == 1
            assert len(mel_lengths.shape) == 1
            assert mel_lengths[idx] == mel_input[idx].shape[0]

            # check the other item in the batch
            assert mel_input[1 - idx, -1].sum() == 0
            assert linear_input[1 - idx, -1].sum() == 0
            assert stop_target[1 - idx, -1] == 1
            assert len(mel_lengths.shape) == 1

            # check batch conditions: frames flagged by stop_target are
            # expected to contribute nothing to the spectrogram sums
            assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
            assert (linear_input * stop_target.unsqueeze(2)).sum() == 0


# NOTE: the TWEB tests below are disabled; ``TWEBDataset`` is not imported
# in this module, so they cannot run as written.
# class TestTWEBDataset(unittest.TestCase):

#     def __init__(self, *args, **kwargs):
#         super(TestTWEBDataset, self).__init__(*args, **kwargs)
#         self.max_loader_iter = 4

#     def test_loader(self):
#         dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
#                               os.path.join(c.data_path_TWEB, 'wavs'),
#                               c.r,
#                               c.sample_rate,
#                               c.text_cleaner,
#                               c.num_mels,
#                               c.min_level_db,
#                               c.frame_shift_ms,
#                               c.frame_length_ms,
#                               c.preemphasis,
#                               c.ref_level_db,
#                               c.num_freq,
#                               c.power
#                               )

#         dataloader = DataLoader(dataset, batch_size=2,
#                                 shuffle=True, collate_fn=dataset.collate_fn,
#                                 drop_last=True, num_workers=c.num_loader_workers)

#         for i, data in enumerate(dataloader):
#             if i == self.max_loader_iter:
#                 break
#             text_input = data[0]
#             text_lengths = data[1]
#             linear_input = data[2]
#             mel_input = data[3]
#             mel_lengths = data[4]
#             stop_target = data[5]
#             item_idx = data[6]

#             neg_values = text_input[text_input < 0]
#             check_count = len(neg_values)
#             assert check_count == 0, \
#                 " !! Negative values in text_input: {}".format(check_count)
#             # TODO: more assertion here
#             assert linear_input.shape[0] == c.batch_size
#             assert mel_input.shape[0] == c.batch_size
#             assert mel_input.shape[2] == c.num_mels

#     def test_padding(self):
#         dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
#                               os.path.join(c.data_path_TWEB, 'wavs'),
#                               1,
#                               c.sample_rate,
#                               c.text_cleaner,
#                               c.num_mels,
#                               c.min_level_db,
#                               c.frame_shift_ms,
#                               c.frame_length_ms,
#                               c.preemphasis,
#                               c.ref_level_db,
#                               c.num_freq,
#                               c.power
#                               )

#         # Test for batch size 1
#         dataloader = DataLoader(dataset, batch_size=1,
#                                 shuffle=False, collate_fn=dataset.collate_fn,
#                                 drop_last=False, num_workers=c.num_loader_workers)

#         for i, data in enumerate(dataloader):
#             if i == self.max_loader_iter:
#                 break

#             text_input = data[0]
#             text_lengths = data[1]
#             linear_input = data[2]
#             mel_input = data[3]
#             mel_lengths = data[4]
#             stop_target = data[5]
#             item_idx = data[6]

#             # check the last time step to be zero padded
#             assert mel_input[0, -1].sum() == 0
#             assert mel_input[0, -2].sum() != 0, "{} -- {}".format(item_idx, i)
#             assert linear_input[0, -1].sum() == 0
#             assert linear_input[0, -2].sum() != 0
#             assert stop_target[0, -1] == 1
#             assert stop_target[0, -2] == 0
#             assert stop_target.sum() == 1
#             assert len(mel_lengths.shape) == 1
#             assert mel_lengths[0] == mel_input[0].shape[0]

#         # Test for batch size 2
#         dataloader = DataLoader(dataset, batch_size=2,
#                                 shuffle=False, collate_fn=dataset.collate_fn,
#                                 drop_last=False, num_workers=c.num_loader_workers)

#         for i, data in enumerate(dataloader):
#             if i == self.max_loader_iter:
#                 break
#             text_input = data[0]
#             text_lengths = data[1]
#             linear_input = data[2]
#             mel_input = data[3]
#             mel_lengths = data[4]
#             stop_target = data[5]
#             item_idx = data[6]

#             if mel_lengths[0] > mel_lengths[1]:
#                 idx = 0
#             else:
#                 idx = 1

#             # check the first item in the batch
#             assert mel_input[idx, -1].sum() == 0
#             assert mel_input[idx, -2].sum() != 0, mel_input
#             assert linear_input[idx, -1].sum() == 0
#             assert linear_input[idx, -2].sum() != 0
#             assert stop_target[idx, -1] == 1
#             assert stop_target[idx, -2] == 0
#             assert stop_target[idx].sum() == 1
#             assert len(mel_lengths.shape) == 1
#             assert mel_lengths[idx] == mel_input[idx].shape[0]

#             # check the second item in the batch
#             assert mel_input[1-idx, -1].sum() == 0
#             assert linear_input[1-idx, -1].sum() == 0
#             assert stop_target[1-idx, -1] == 1
#             assert len(mel_lengths.shape) == 1

#             # check batch conditions
#             assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
#             assert (linear_input * stop_target.unsqueeze(2)).sum() == 0