TTS/tests/test_vocoder_gan_datasets.py

109 lines
4.6 KiB
Python
Raw Normal View History

import os
2020-07-16 13:05:36 +00:00
import numpy as np
2020-07-16 13:05:36 +00:00
from tests import get_tests_path, get_tests_input_path, get_tests_output_path
from torch.utils.data import DataLoader
2020-09-09 10:27:23 +00:00
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.datasets.preprocess import load_wav_data
# Absolute directory of this test module.
file_path = os.path.dirname(os.path.realpath(__file__))

# Output directory for loader test artifacts, created eagerly at import time.
OUTPATH = os.path.join(get_tests_output_path(), "loader_tests/")
os.makedirs(OUTPATH, exist_ok=True)

# Shared test configuration; its ``audio`` section configures the AudioProcessor.
C = load_config(os.path.join(get_tests_input_path(), "test_config.json"))

# LJSpeech sample data read by the dataloader cases.
test_data_path = os.path.join(get_tests_path(), "data/ljspeech/")
# NOTE(review): this availability flag is computed but not consulted by the
# tests visible in this module.
ok_ljspeech = os.path.exists(test_data_path)
def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_pairs, return_segments, use_noise_augment, use_cache, num_workers):
    """Run the GAN vocoder dataloader with the given parameters and check its outputs.

    Builds a ``GANDataset`` over the LJSpeech test wavs, wraps it in a
    ``DataLoader`` and validates at most ``max_iter`` batches.

    Args:
        batch_size: items per batch. The loader uses ``drop_last=True``, so an
            incomplete final batch is discarded.
        seq_len: segment length in samples when ``return_segments`` is True.
        hop_len: hop length mapping feature frames to audio samples.
        conv_pad: extra feature frames expected on each side of the features.
        return_pairs: with segments, each batch holds two (feat, wav) items.
        return_segments: check random segments (True) or whole clips (False).
        use_noise_augment: when False, segment features are also compared
            against a mel spectrogram recomputed from the returned audio.
        use_cache: forwarded to ``GANDataset``.
        num_workers: number of ``DataLoader`` worker processes.
    """
    ap = AudioProcessor(**C.audio)
    _, train_items = load_wav_data(test_data_path, 10)
    dataset = GANDataset(ap,
                         train_items,
                         seq_len=seq_len,
                         hop_len=hop_len,
                         pad_short=2000,
                         conv_pad=conv_pad,
                         return_pairs=return_pairs,
                         return_segments=return_segments,
                         use_noise_augment=use_noise_augment,
                         use_cache=use_cache)
    loader = DataLoader(dataset=dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=num_workers,
                        pin_memory=True,
                        drop_last=True)
    max_iter = 10

    def check_item(feat, wav):
        """Check one batch of segment features against its waveform batch."""
        # Batches come from the DataLoader (torch tensors with the default
        # collate); convert once so the mel recomputation and the comparison
        # below operate on numpy arrays.
        feat = np.asarray(feat)
        wav = np.asarray(wav)
        expected_feat_shape = (batch_size, ap.num_mels, seq_len // hop_len + conv_pad * 2)
        # check shapes
        assert feat.shape == expected_feat_shape, f" [!] {feat.shape} vs {expected_feat_shape}"
        assert (feat.shape[2] - conv_pad * 2) * hop_len == wav.shape[2]
        # check feature vs audio match; skipped with noise augmentation
        # (presumably the augmented audio no longer matches the features)
        if not use_noise_augment:
            for idx in range(batch_size):
                audio = wav[idx].squeeze()
                # Use a per-item name: rebinding ``feat`` here would make every
                # index after 0 slice an already-sliced array.
                item_feat = feat[idx]
                mel = ap.melspectrogram(audio)
                # the first 2 and the last 2 frames are skipped due to the padding
                # differences in stft
                max_diff = abs((item_feat - mel[:, :item_feat.shape[-1]])[:, 2:-2]).max()
                assert max_diff <= 0, f' [!] {max_diff}'

    for count_iter, batch in enumerate(loader, start=1):
        if return_segments:
            # random segments: exact segment shape (and content, see check_item)
            if return_pairs:
                (feat1, wav1), (feat2, wav2) = batch
                check_item(feat1, wav1)
                check_item(feat2, wav2)
            else:
                feat1, wav1 = batch
                check_item(feat1, wav1)
        else:
            # whole clips: the expected feature length follows the clip length
            feat, wav = batch
            expected_feat_shape = (batch_size, ap.num_mels, (wav.shape[-1] // hop_len) + (conv_pad * 2))
            assert np.all(feat.shape == expected_feat_shape), f" [!] {feat.shape} vs {expected_feat_shape}"
            assert (feat.shape[2] - conv_pad * 2) * hop_len == wav.shape[2]
        # bound the number of checked batches in every mode, not only whole clips
        if count_iter == max_iter:
            break
def test_parametrized_gan_dataset():
    """Run ``gan_dataset_case`` over a fixed grid of dataloader settings.

    Each row lists, in order: batch_size, seq_len, hop_len, conv_pad,
    return_pairs, return_segments, use_noise_augment, use_cache, num_workers.
    """
    hop = C.audio['hop_length']
    cases = [
        [32, hop * 10, hop, 0, True, True, False, True, 0],
        [32, hop * 10, hop, 0, True, True, False, True, 4],
        [1, hop * 10, hop, 0, True, True, True, True, 0],
        [1, hop, hop, 0, True, True, True, True, 0],
        [1, hop * 10, hop, 2, True, True, True, True, 0],
        [1, hop * 10, hop, 0, True, False, True, True, 0],
        [1, hop * 10, hop, 0, True, True, False, True, 0],
        [1, hop * 10, hop, 0, False, True, True, False, 0],
        [1, hop * 10, hop, 0, True, False, False, False, 0],
        [1, hop * 10, hop, 0, True, False, False, False, 0],
    ]
    for case in cases:
        # echo the settings so a failing case is identifiable in the log
        print(case)
        gan_dataset_case(*case)