# Mirror of https://github.com/coqui-ai/TTS.git
# 110 lines, 4.3 KiB, Python
import os
|
|
|
|
import numpy as np
|
|
from torch.utils.data import DataLoader
|
|
|
|
from tests import get_tests_output_path, get_tests_path
|
|
from TTS.utils.audio import AudioProcessor
|
|
from TTS.vocoder.configs import BaseGANVocoderConfig
|
|
from TTS.vocoder.datasets.gan_dataset import GANDataset
|
|
from TTS.vocoder.datasets.preprocess import load_wav_data
|
|
|
|
# Absolute directory that contains this test module.
file_path = os.path.dirname(os.path.realpath(__file__))

# Output directory for loader test artefacts; created at import time if missing.
OUTPATH = os.path.join(get_tests_output_path(), "loader_tests/")
os.makedirs(OUTPATH, exist_ok=True)

# Default GAN vocoder configuration; the tests below read audio settings from `C.audio`.
C = BaseGANVocoderConfig()

# Directory holding the LJSpeech test data, and whether it exists on disk.
test_data_path = os.path.join(get_tests_path(), "data/ljspeech/")
ok_ljspeech = os.path.exists(test_data_path)
|
|
|
|
|
|
def gan_dataset_case(
    batch_size, seq_len, hop_len, conv_pad, return_pairs, return_segments, use_noise_augment, use_cache, num_workers
):
    """Run the GAN dataloader with the given parameters and check its outputs.

    Builds a ``GANDataset`` over the training split returned by
    ``load_wav_data(test_data_path, 10)``, wraps it in a shuffling
    ``DataLoader`` with ``drop_last=True`` and checks at most ``max_iter``
    batches.

    Args:
        batch_size: Loader batch size.
        seq_len: Segment length in samples; used for the expected feature
            length when ``return_segments`` is set.
        hop_len: Number of audio samples per feature frame.
        conv_pad: Feature length is expected to exceed the audio-derived frame
            count by ``2 * conv_pad`` frames.
        return_pairs: Forwarded to the dataset; together with
            ``return_segments`` the loader yields ``(item1, item2)`` pairs.
        return_segments: If set, fixed-length segments are checked against
            ``seq_len``; otherwise whole clips are checked against their own
            waveform length.
        use_noise_augment: Forwarded to the dataset; when set, the
            feature-vs-recomputed-mel comparison is skipped.
        use_cache: Forwarded to the dataset.
        num_workers: Number of loader worker processes.
    """
    ap = AudioProcessor(**C.audio)
    _, train_items = load_wav_data(test_data_path, 10)
    dataset = GANDataset(
        ap,
        train_items,
        seq_len=seq_len,
        hop_len=hop_len,
        pad_short=2000,
        conv_pad=conv_pad,
        return_pairs=return_pairs,
        return_segments=return_segments,
        use_noise_augment=use_noise_augment,
        use_cache=use_cache,
    )
    loader = DataLoader(
        dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True, drop_last=True
    )

    # Upper bound on the number of loader batches inspected, in every mode.
    max_iter = 10

    def check_item(feat, wav):
        """Check one batch of segment features against its waveforms."""
        feat = feat.numpy()
        wav = wav.numpy()
        expected_feat_shape = (batch_size, ap.num_mels, seq_len // hop_len + conv_pad * 2)

        # check shapes
        assert feat.shape == expected_feat_shape, f" [!] {feat.shape} vs {expected_feat_shape}"
        assert (feat.shape[2] - conv_pad * 2) * hop_len == wav.shape[2]

        # check feature vs audio match (skipped when noise augmentation is on,
        # presumably because the augmented audio no longer matches the features)
        if not use_noise_augment:
            for idx in range(batch_size):
                audio = wav[idx].squeeze()
                # Separate name on purpose: rebinding `feat` here would make the
                # next iteration index into this sample's features, not the batch.
                item_feat = feat[idx]
                mel = ap.melspectrogram(audio)
                # the first 2 and the last 2 frames are skipped due to the padding
                # differences in stft
                max_diff = np.abs((item_feat - mel[:, : item_feat.shape[-1]])[:, 2:-2]).max()
                assert max_diff <= 1e-6, f" [!] {max_diff}"

    # return random segments or return the whole audio
    count_iter = 0
    for batch in loader:
        if return_segments:
            # With return_pairs the loader yields (item1, item2); otherwise one item.
            items = batch if return_pairs else (batch,)
            for feat, wav in items:
                check_item(feat, wav)
        else:
            feat, wav = batch
            expected_feat_shape = (batch_size, ap.num_mels, (wav.shape[-1] // hop_len) + (conv_pad * 2))
            assert feat.shape == expected_feat_shape, f" [!] {feat.shape} vs {expected_feat_shape}"
            assert (feat.shape[2] - conv_pad * 2) * hop_len == wav.shape[2]
        count_iter += 1
        if count_iter == max_iter:
            break
|
|
|
|
|
|
def test_parametrized_gan_dataset():
    """Test the GAN dataloader over a grid of parameter combinations.

    Each row is passed positionally to ``gan_dataset_case`` as
    ``(batch_size, seq_len, hop_len, conv_pad, return_pairs, return_segments,
    use_noise_augment, use_cache, num_workers)``. Every row is distinct, so no
    case is run twice.
    """
    hop_length = C.audio["hop_length"]
    params = [
        [32, hop_length * 10, hop_length, 0, True, True, False, True, 0],
        [32, hop_length * 10, hop_length, 0, True, True, False, True, 4],
        [1, hop_length * 10, hop_length, 0, True, True, True, True, 0],
        [1, hop_length, hop_length, 0, True, True, True, True, 0],
        [1, hop_length * 10, hop_length, 2, True, True, True, True, 0],
        [1, hop_length * 10, hop_length, 0, True, False, True, True, 0],
        [1, hop_length * 10, hop_length, 0, True, True, False, True, 0],
        [1, hop_length * 10, hop_length, 0, False, True, True, False, 0],
        [1, hop_length * 10, hop_length, 0, True, False, False, False, 0],
    ]
    for param in params:
        print(param)
        gan_dataset_case(*param)
|