mirror of https://github.com/coqui-ai/TTS.git
296 lines
11 KiB
Python
296 lines
11 KiB
Python
import os
|
|
import unittest
|
|
import numpy as np
|
|
|
|
from torch.utils.data import DataLoader
|
|
from TTS.utils.generic_utils import load_config
|
|
from TTS.utils.audio import AudioProcessor
|
|
from TTS.datasets import LJSpeech, Kusal
|
|
|
|
# Locate the test configuration relative to this module so the tests work
# regardless of the current working directory.
file_path = os.path.dirname(os.path.realpath(__file__))
c = load_config(os.path.join(file_path, 'test_config.json'))

# Availability flags: each is True only when the corresponding path named in
# the config exists on disk. The test classes consult them before touching
# any data.
ok_kusal, ok_ljspeech = (
    os.path.exists(data_path)
    for data_path in (c.data_path_Kusal, c.data_path_LJSpeech)
)
|
|
|
|
|
|
class TestLJSpeechDataset(unittest.TestCase):
    """Loader and padding tests for ``LJSpeech.MyDataset``.

    Every test needs the data at ``c.data_path_LJSpeech``; when that path
    does not exist the tests are reported as skipped instead of silently
    passing.

    Each batch produced by ``dataset.collate_fn`` is indexed as
    ``(text_input, text_lengths, linear_input, mel_input, mel_lengths,
    stop_target, item_idx)``.
    """

    def __init__(self, *args, **kwargs):
        super(TestLJSpeechDataset, self).__init__(*args, **kwargs)
        # Upper bound on the number of batches inspected per loader.
        self.max_loader_iter = 4
        self.ap = AudioProcessor(
            sample_rate=c.sample_rate,
            num_mels=c.num_mels,
            min_level_db=c.min_level_db,
            frame_shift_ms=c.frame_shift_ms,
            frame_length_ms=c.frame_length_ms,
            ref_level_db=c.ref_level_db,
            num_freq=c.num_freq,
            power=c.power,
            preemphasis=c.preemphasis)

    def _require_data(self):
        """Skip the running test when the LJSpeech data is not available."""
        if not ok_ljspeech:
            self.skipTest(
                "LJSpeech data not found at {}".format(c.data_path_LJSpeech))

    def _build_dataset(self, r):
        """Create the dataset under test.

        Args:
            r: value forwarded as the third positional argument of
                ``LJSpeech.MyDataset`` (``c.r`` in the loader test, ``1`` in
                the padding test).
        """
        return LJSpeech.MyDataset(
            os.path.join(c.data_path_LJSpeech),
            os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
            r,
            c.text_cleaner,
            ap=self.ap,
            min_seq_len=c.min_seq_len)

    @staticmethod
    def _build_loader(dataset, batch_size, shuffle, drop_last):
        """Wrap ``dataset`` in a ``DataLoader`` using its own ``collate_fn``."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=dataset.collate_fn,
            drop_last=drop_last,
            num_workers=c.num_loader_workers)

    @staticmethod
    def _check_padded_item(mel_input, linear_input, stop_target, mel_lengths,
                           idx):
        """Assert the padding invariants for batch item ``idx``.

        The last time step of both spectrograms must sum to zero and the one
        before it must not; the stop target must be 1 only at the last step;
        ``mel_lengths`` must be 1-D and its entry must equal the item's
        time dimension.
        """
        assert mel_input[idx, -1].sum() == 0
        assert mel_input[idx, -2].sum() != 0, mel_input
        assert linear_input[idx, -1].sum() == 0
        assert linear_input[idx, -2].sum() != 0
        assert stop_target[idx, -1] == 1
        assert stop_target[idx, -2] == 0
        assert stop_target[idx].sum() == 1
        assert len(mel_lengths.shape) == 1
        assert mel_lengths[idx] == mel_input[idx].shape[0]

    def test_loader(self):
        """Batches contain no negative text ids and have the expected shape."""
        self._require_data()

        # The shape assertions below compare against the batch size the
        # loader was actually built with, not an unrelated config value.
        batch_size = 2
        dataset = self._build_dataset(c.r)
        dataloader = self._build_loader(
            dataset, batch_size, shuffle=True, drop_last=True)

        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            text_input = data[0]
            linear_input = data[2]
            mel_input = data[3]

            neg_values = text_input[text_input < 0]
            check_count = len(neg_values)
            assert check_count == 0, \
                " !! Negative values in text_input: {}".format(check_count)
            # TODO: more assertion here
            assert linear_input.shape[0] == batch_size
            assert mel_input.shape[0] == batch_size
            assert mel_input.shape[2] == c.num_mels

    def test_padding(self):
        """Collated batches are zero padded and carry a single stop flag."""
        self._require_data()

        dataset = self._build_dataset(1)

        # Test for batch size 1
        dataloader = self._build_loader(
            dataset, 1, shuffle=False, drop_last=True)
        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            linear_input, mel_input, mel_lengths, stop_target = data[2:6]
            # check the last time step to be zero padded
            self._check_padded_item(
                mel_input, linear_input, stop_target, mel_lengths, 0)

        # Test for batch size 2
        dataloader = self._build_loader(
            dataset, 2, shuffle=False, drop_last=False)
        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            linear_input, mel_input, mel_lengths, stop_target = data[2:6]

            # drop_last=False may yield a trailing one-item batch; the
            # pairwise checks below need two items, and the single-item case
            # is already covered by the batch-size-1 pass above.
            if mel_lengths.shape[0] < 2:
                continue

            # idx selects the item with the larger mel length (item 1 on a
            # tie); the full set of checks is applied to that item.
            if mel_lengths[0] > mel_lengths[1]:
                idx = 0
            else:
                idx = 1

            # check the first item in the batch
            self._check_padded_item(
                mel_input, linear_input, stop_target, mel_lengths, idx)

            # check the second item in the batch
            assert mel_input[1 - idx, -1].sum() == 0
            assert linear_input[1 - idx, -1].sum() == 0
            assert stop_target[1 - idx, -1] == 1
            assert len(mel_lengths.shape) == 1

            # check batch conditions: every frame flagged by stop_target
            # must be all zeros in both spectrograms
            assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
            assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
|
|
|
|
|
|
class TestKusalDataset(unittest.TestCase):
    """Loader and padding tests for ``Kusal.MyDataset``.

    Every test needs the data at ``c.data_path_Kusal``; when that path does
    not exist the tests are reported as skipped instead of silently passing.

    Each batch produced by ``dataset.collate_fn`` is indexed as
    ``(text_input, text_lengths, linear_input, mel_input, mel_lengths,
    stop_target, item_idx)``.
    """

    def __init__(self, *args, **kwargs):
        super(TestKusalDataset, self).__init__(*args, **kwargs)
        # Upper bound on the number of batches inspected per loader.
        self.max_loader_iter = 4
        self.ap = AudioProcessor(
            sample_rate=c.sample_rate,
            num_mels=c.num_mels,
            min_level_db=c.min_level_db,
            frame_shift_ms=c.frame_shift_ms,
            frame_length_ms=c.frame_length_ms,
            ref_level_db=c.ref_level_db,
            num_freq=c.num_freq,
            power=c.power,
            preemphasis=c.preemphasis)

    def _require_data(self):
        """Skip the running test when the Kusal data is not available."""
        if not ok_kusal:
            self.skipTest(
                "Kusal data not found at {}".format(c.data_path_Kusal))

    def _build_dataset(self, r):
        """Create the dataset under test.

        Args:
            r: value forwarded as the third positional argument of
                ``Kusal.MyDataset`` (``c.r`` in the loader test, ``1`` in
                the padding test).
        """
        return Kusal.MyDataset(
            os.path.join(c.data_path_Kusal),
            os.path.join(c.data_path_Kusal, 'prompts.txt'),
            r,
            c.text_cleaner,
            ap=self.ap,
            min_seq_len=c.min_seq_len)

    @staticmethod
    def _build_loader(dataset, batch_size, shuffle, drop_last):
        """Wrap ``dataset`` in a ``DataLoader`` using its own ``collate_fn``."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=dataset.collate_fn,
            drop_last=drop_last,
            num_workers=c.num_loader_workers)

    @staticmethod
    def _check_padded_item(mel_input, linear_input, stop_target, mel_lengths,
                           idx, check_penultimate=True):
        """Assert the padding invariants for batch item ``idx``.

        The last time step of both spectrograms must sum to zero; the stop
        target must be 1 only at the last step; ``mel_lengths`` must be 1-D
        and its entry must equal the item's time dimension.

        Args:
            check_penultimate: when True, additionally require the
                second-to-last spectrogram frames to be non-zero.
        """
        assert mel_input[idx, -1].sum() == 0
        assert linear_input[idx, -1].sum() == 0
        if check_penultimate:
            assert mel_input[idx, -2].sum() != 0, mel_input
            assert linear_input[idx, -2].sum() != 0
        assert stop_target[idx, -1] == 1
        assert stop_target[idx, -2] == 0
        assert stop_target[idx].sum() == 1
        assert len(mel_lengths.shape) == 1
        assert mel_lengths[idx] == mel_input[idx].shape[0]

    def test_loader(self):
        """Batches contain no negative text ids and have the expected shape."""
        self._require_data()

        # The shape assertions below compare against the batch size the
        # loader was actually built with, not an unrelated config value.
        batch_size = 2
        dataset = self._build_dataset(c.r)
        dataloader = self._build_loader(
            dataset, batch_size, shuffle=True, drop_last=True)

        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            text_input = data[0]
            linear_input = data[2]
            mel_input = data[3]

            neg_values = text_input[text_input < 0]
            check_count = len(neg_values)
            assert check_count == 0, \
                " !! Negative values in text_input: {}".format(check_count)
            # TODO: more assertion here
            assert linear_input.shape[0] == batch_size
            assert mel_input.shape[0] == batch_size
            assert mel_input.shape[2] == c.num_mels

    def test_padding(self):
        """Collated batches are zero padded and carry a single stop flag."""
        self._require_data()

        dataset = self._build_dataset(1)

        # Test for batch size 1
        dataloader = self._build_loader(
            dataset, 1, shuffle=False, drop_last=True)
        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            linear_input, mel_input, mel_lengths, stop_target = data[2:6]
            # check the last time step to be zero padded
            # NOTE(review): the non-zero checks on the second-to-last frame
            # were disabled for this dataset in the single-item pass; the
            # reason is not recorded here, so they stay disabled.
            self._check_padded_item(
                mel_input, linear_input, stop_target, mel_lengths, 0,
                check_penultimate=False)

        # Test for batch size 2
        dataloader = self._build_loader(
            dataset, 2, shuffle=False, drop_last=False)
        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            linear_input, mel_input, mel_lengths, stop_target = data[2:6]

            # drop_last=False may yield a trailing one-item batch; the
            # pairwise checks below need two items, and the single-item case
            # is already covered by the batch-size-1 pass above.
            if mel_lengths.shape[0] < 2:
                continue

            # idx selects the item with the larger mel length (item 1 on a
            # tie); the full set of checks is applied to that item.
            if mel_lengths[0] > mel_lengths[1]:
                idx = 0
            else:
                idx = 1

            # check the first item in the batch
            self._check_padded_item(
                mel_input, linear_input, stop_target, mel_lengths, idx)

            # check the second item in the batch
            assert mel_input[1 - idx, -1].sum() == 0
            assert linear_input[1 - idx, -1].sum() == 0
            assert stop_target[1 - idx, -1] == 1
            assert len(mel_lengths.shape) == 1

            # check batch conditions: every frame flagged by stop_target
            # must be all zeros in both spectrograms
            assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
            assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
|