update loader_tests.py

pull/10/head
Eren Golge 2018-12-17 16:35:52 +01:00
parent be6e46798b
commit ae5e8b2b18
1 changed files with 15 additions and 4 deletions

View File

@ -23,6 +23,9 @@ if not os.path.exists(c.data_path_cache):
if not os.path.exists(c.data_path):
DATA_EXIST = False
print(" > Dynamic data loader test: {}".format(DATA_EXIST))
print(" > Cache data loader test: {}".format(CACHE_EXIST))
class TestTTSDataset(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestTTSDataset, self).__init__(*args, **kwargs)
@ -199,7 +202,7 @@ class TestTTSDatasetCached(unittest.TestCase):
def _create_dataloader(self, batch_size, r, bgs):
dataset = TTSDatasetCached.MyDataset(
dataset = TTSDataset.MyDataset(
c.data_path_cache,
'tts_metadata.csv',
r,
@ -207,7 +210,9 @@ class TestTTSDatasetCached(unittest.TestCase):
preprocessor=tts_cache,
ap=self.ap,
batch_group_size=bgs,
min_seq_len=c.min_seq_len)
min_seq_len=c.min_seq_len,
max_seq_len=c.max_seq_len,
cached=True)
dataloader = DataLoader(
dataset,
@ -299,11 +304,17 @@ class TestTTSDatasetCached(unittest.TestCase):
abs(mel.T).astype("float32") - abs(mel_dl[:-1])).sum()
# check mel-spec correctness
mel_spec = mel_input[0].cpu().numpy()
mel_spec = mel_input[-1].cpu().numpy()
wav = self.ap.inv_mel_spectrogram(mel_spec.T)
self.ap.save_wav(wav,
OUTPATH + '/mel_inv_dataloader_cache.wav')
shutil.copy(item_idx[0], OUTPATH + '/mel_target_dataloader_cache.wav')
shutil.copy(item_idx[-1], OUTPATH + '/mel_target_dataloader_cache.wav')
# check linear-spec
linear_spec = linear_input[-1].cpu().numpy()
wav = self.ap.inv_spectrogram(linear_spec.T)
self.ap.save_wav(wav, OUTPATH + '/linear_inv_dataloader_cache.wav')
shutil.copy(item_idx[-1], OUTPATH + '/linear_target_dataloader_cache.wav')
# check the last time step to be zero padded
assert mel_input[0, -1].sum() == 0