mirror of https://github.com/coqui-ai/TTS.git
pytorch 0.4.1update
parent
90b96f9bed
commit
a15b3ec9a1
|
@ -24,8 +24,7 @@ class BahdanauAttention(nn.Module):
|
|||
processed_query = self.query_layer(query)
|
||||
processed_annots = self.annot_layer(annots)
|
||||
# (batch, max_time, 1)
|
||||
alignment = self.v(
|
||||
nn.functional.tanh(processed_query + processed_annots))
|
||||
alignment = self.v(torch.tanh(processed_query + processed_annots))
|
||||
# (batch, max_time)
|
||||
return alignment.squeeze(-1)
|
||||
|
||||
|
@ -72,8 +71,7 @@ class LocationSensitiveAttention(nn.Module):
|
|||
processed_query = self.query_layer(query)
|
||||
processed_annots = self.annot_layer(annot)
|
||||
alignment = self.v(
|
||||
nn.functional.tanh(processed_query + processed_annots +
|
||||
processed_loc))
|
||||
torch.tanh(processed_query + processed_annots + processed_loc))
|
||||
# (batch, max_time)
|
||||
return alignment.squeeze(-1)
|
||||
|
||||
|
|
|
@ -22,24 +22,13 @@ class L1LossMasked(nn.Module):
|
|||
Returns:
|
||||
loss: An average loss value masked by the length.
|
||||
"""
|
||||
input = input.contiguous()
|
||||
target = target.contiguous()
|
||||
|
||||
# logits_flat: (batch * max_len, dim)
|
||||
input = input.view(-1, input.shape[-1])
|
||||
# target_flat: (batch * max_len, dim)
|
||||
target_flat = target.view(-1, target.shape[-1])
|
||||
# losses_flat: (batch * max_len, dim)
|
||||
losses_flat = functional.l1_loss(
|
||||
input, target_flat, size_average=False, reduce=False)
|
||||
# losses: (batch, max_len, dim)
|
||||
losses = losses_flat.view(*target.size())
|
||||
|
||||
# mask: (batch, max_len, 1)
|
||||
mask = sequence_mask(
|
||||
sequence_length=length, max_len=target.size(1)).unsqueeze(2)
|
||||
losses = losses * mask.float()
|
||||
loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
|
||||
sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
|
||||
mask = mask.expand_as(input)
|
||||
loss = functional.l1_loss(
|
||||
input * mask, target * mask, reduction="sum")
|
||||
loss = loss / mask.sum()
|
||||
return loss
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
numpy==1.14.3
|
||||
lws
|
||||
torch>=0.4.0
|
||||
torch>=0.4.1
|
||||
librosa==0.5.1
|
||||
Unidecode==0.4.20
|
||||
tensorboard
|
||||
|
|
2
setup.py
2
setup.py
|
@ -73,7 +73,7 @@ setup(
|
|||
setup_requires=["numpy==1.14.3"],
|
||||
install_requires=[
|
||||
"scipy==0.19.0",
|
||||
"torch == 0.4.0",
|
||||
"torch >= 0.4.1",
|
||||
"librosa==0.5.1",
|
||||
"unidecode==0.4.20",
|
||||
"tensorboardX",
|
||||
|
|
|
@ -29,127 +29,129 @@ class TestLJSpeechDataset(unittest.TestCase):
|
|||
max_mel_freq=c.max_mel_freq)
|
||||
|
||||
def test_loader(self):
|
||||
dataset = LJSpeech.MyDataset(
|
||||
os.path.join(c.data_path_LJSpeech),
|
||||
os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
|
||||
c.r,
|
||||
c.text_cleaner,
|
||||
ap=self.ap,
|
||||
min_seq_len=c.min_seq_len)
|
||||
if ok_ljspeech:
|
||||
dataset = LJSpeech.MyDataset(
|
||||
os.path.join(c.data_path_LJSpeech),
|
||||
os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
|
||||
c.r,
|
||||
c.text_cleaner,
|
||||
ap=self.ap,
|
||||
min_seq_len=c.min_seq_len)
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=2,
|
||||
shuffle=True,
|
||||
collate_fn=dataset.collate_fn,
|
||||
drop_last=True,
|
||||
num_workers=c.num_loader_workers)
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=2,
|
||||
shuffle=True,
|
||||
collate_fn=dataset.collate_fn,
|
||||
drop_last=True,
|
||||
num_workers=c.num_loader_workers)
|
||||
|
||||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
linear_input = data[2]
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
stop_target = data[5]
|
||||
item_idx = data[6]
|
||||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
linear_input = data[2]
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
stop_target = data[5]
|
||||
item_idx = data[6]
|
||||
|
||||
neg_values = text_input[text_input < 0]
|
||||
check_count = len(neg_values)
|
||||
assert check_count == 0, \
|
||||
" !! Negative values in text_input: {}".format(check_count)
|
||||
# TODO: more assertion here
|
||||
assert linear_input.shape[0] == c.batch_size
|
||||
assert mel_input.shape[0] == c.batch_size
|
||||
assert mel_input.shape[2] == c.num_mels
|
||||
neg_values = text_input[text_input < 0]
|
||||
check_count = len(neg_values)
|
||||
assert check_count == 0, \
|
||||
" !! Negative values in text_input: {}".format(check_count)
|
||||
# TODO: more assertion here
|
||||
assert linear_input.shape[0] == c.batch_size
|
||||
assert mel_input.shape[0] == c.batch_size
|
||||
assert mel_input.shape[2] == c.num_mels
|
||||
|
||||
def test_padding(self):
|
||||
dataset = LJSpeech.MyDataset(
|
||||
os.path.join(c.data_path_LJSpeech),
|
||||
os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
|
||||
1,
|
||||
c.text_cleaner,
|
||||
ap=self.ap,
|
||||
min_seq_len=c.min_seq_len)
|
||||
if ok_ljspeech:
|
||||
dataset = LJSpeech.MyDataset(
|
||||
os.path.join(c.data_path_LJSpeech),
|
||||
os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
|
||||
1,
|
||||
c.text_cleaner,
|
||||
ap=self.ap,
|
||||
min_seq_len=c.min_seq_len)
|
||||
|
||||
# Test for batch size 1
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=dataset.collate_fn,
|
||||
drop_last=True,
|
||||
num_workers=c.num_loader_workers)
|
||||
# Test for batch size 1
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=dataset.collate_fn,
|
||||
drop_last=True,
|
||||
num_workers=c.num_loader_workers)
|
||||
|
||||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
linear_input = data[2]
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
stop_target = data[5]
|
||||
item_idx = data[6]
|
||||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
linear_input = data[2]
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
stop_target = data[5]
|
||||
item_idx = data[6]
|
||||
|
||||
# check the last time step to be zero padded
|
||||
assert mel_input[0, -1].sum() == 0
|
||||
assert mel_input[0, -2].sum() != 0
|
||||
assert linear_input[0, -1].sum() == 0
|
||||
assert linear_input[0, -2].sum() != 0
|
||||
assert stop_target[0, -1] == 1
|
||||
assert stop_target[0, -2] == 0
|
||||
assert stop_target.sum() == 1
|
||||
assert len(mel_lengths.shape) == 1
|
||||
assert mel_lengths[0] == mel_input[0].shape[0]
|
||||
# check the last time step to be zero padded
|
||||
assert mel_input[0, -1].sum() == 0
|
||||
assert mel_input[0, -2].sum() != 0
|
||||
assert linear_input[0, -1].sum() == 0
|
||||
assert linear_input[0, -2].sum() != 0
|
||||
assert stop_target[0, -1] == 1
|
||||
assert stop_target[0, -2] == 0
|
||||
assert stop_target.sum() == 1
|
||||
assert len(mel_lengths.shape) == 1
|
||||
assert mel_lengths[0] == mel_input[0].shape[0]
|
||||
|
||||
# Test for batch size 2
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=2,
|
||||
shuffle=False,
|
||||
collate_fn=dataset.collate_fn,
|
||||
drop_last=False,
|
||||
num_workers=c.num_loader_workers)
|
||||
# Test for batch size 2
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=2,
|
||||
shuffle=False,
|
||||
collate_fn=dataset.collate_fn,
|
||||
drop_last=False,
|
||||
num_workers=c.num_loader_workers)
|
||||
|
||||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
linear_input = data[2]
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
stop_target = data[5]
|
||||
item_idx = data[6]
|
||||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
linear_input = data[2]
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
stop_target = data[5]
|
||||
item_idx = data[6]
|
||||
|
||||
if mel_lengths[0] > mel_lengths[1]:
|
||||
idx = 0
|
||||
else:
|
||||
idx = 1
|
||||
if mel_lengths[0] > mel_lengths[1]:
|
||||
idx = 0
|
||||
else:
|
||||
idx = 1
|
||||
|
||||
# check the first item in the batch
|
||||
assert mel_input[idx, -1].sum() == 0
|
||||
assert mel_input[idx, -2].sum() != 0, mel_input
|
||||
assert linear_input[idx, -1].sum() == 0
|
||||
assert linear_input[idx, -2].sum() != 0
|
||||
assert stop_target[idx, -1] == 1
|
||||
assert stop_target[idx, -2] == 0
|
||||
assert stop_target[idx].sum() == 1
|
||||
assert len(mel_lengths.shape) == 1
|
||||
assert mel_lengths[idx] == mel_input[idx].shape[0]
|
||||
# check the first item in the batch
|
||||
assert mel_input[idx, -1].sum() == 0
|
||||
assert mel_input[idx, -2].sum() != 0, mel_input
|
||||
assert linear_input[idx, -1].sum() == 0
|
||||
assert linear_input[idx, -2].sum() != 0
|
||||
assert stop_target[idx, -1] == 1
|
||||
assert stop_target[idx, -2] == 0
|
||||
assert stop_target[idx].sum() == 1
|
||||
assert len(mel_lengths.shape) == 1
|
||||
assert mel_lengths[idx] == mel_input[idx].shape[0]
|
||||
|
||||
# check the second itme in the batch
|
||||
assert mel_input[1 - idx, -1].sum() == 0
|
||||
assert linear_input[1 - idx, -1].sum() == 0
|
||||
assert stop_target[1 - idx, -1] == 1
|
||||
assert len(mel_lengths.shape) == 1
|
||||
# check the second itme in the batch
|
||||
assert mel_input[1 - idx, -1].sum() == 0
|
||||
assert linear_input[1 - idx, -1].sum() == 0
|
||||
assert stop_target[1 - idx, -1] == 1
|
||||
assert len(mel_lengths.shape) == 1
|
||||
|
||||
# check batch conditions
|
||||
assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
|
||||
assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
|
||||
# check batch conditions
|
||||
assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
|
||||
assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
|
||||
|
||||
|
||||
class TestKusalDataset(unittest.TestCase):
|
||||
|
@ -170,257 +172,126 @@ class TestKusalDataset(unittest.TestCase):
|
|||
max_mel_freq=c.max_mel_freq)
|
||||
|
||||
def test_loader(self):
|
||||
dataset = Kusal.MyDataset(
|
||||
os.path.join(c.data_path_Kusal),
|
||||
os.path.join(c.data_path_Kusal, 'prompts.txt'),
|
||||
c.r,
|
||||
c.text_cleaner,
|
||||
ap=self.ap,
|
||||
min_seq_len=c.min_seq_len)
|
||||
if ok_kusal:
|
||||
dataset = Kusal.MyDataset(
|
||||
os.path.join(c.data_path_Kusal),
|
||||
os.path.join(c.data_path_Kusal, 'prompts.txt'),
|
||||
c.r,
|
||||
c.text_cleaner,
|
||||
ap=self.ap,
|
||||
min_seq_len=c.min_seq_len)
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=2,
|
||||
shuffle=True,
|
||||
collate_fn=dataset.collate_fn,
|
||||
drop_last=True,
|
||||
num_workers=c.num_loader_workers)
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=2,
|
||||
shuffle=True,
|
||||
collate_fn=dataset.collate_fn,
|
||||
drop_last=True,
|
||||
num_workers=c.num_loader_workers)
|
||||
|
||||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
linear_input = data[2]
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
stop_target = data[5]
|
||||
item_idx = data[6]
|
||||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
linear_input = data[2]
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
stop_target = data[5]
|
||||
item_idx = data[6]
|
||||
|
||||
neg_values = text_input[text_input < 0]
|
||||
check_count = len(neg_values)
|
||||
assert check_count == 0, \
|
||||
" !! Negative values in text_input: {}".format(check_count)
|
||||
# TODO: more assertion here
|
||||
assert linear_input.shape[0] == c.batch_size
|
||||
assert mel_input.shape[0] == c.batch_size
|
||||
assert mel_input.shape[2] == c.num_mels
|
||||
neg_values = text_input[text_input < 0]
|
||||
check_count = len(neg_values)
|
||||
assert check_count == 0, \
|
||||
" !! Negative values in text_input: {}".format(check_count)
|
||||
# TODO: more assertion here
|
||||
assert linear_input.shape[0] == c.batch_size
|
||||
assert mel_input.shape[0] == c.batch_size
|
||||
assert mel_input.shape[2] == c.num_mels
|
||||
|
||||
def test_padding(self):
|
||||
dataset = Kusal.MyDataset(
|
||||
os.path.join(c.data_path_Kusal),
|
||||
os.path.join(c.data_path_Kusal, 'prompts.txt'),
|
||||
1,
|
||||
c.text_cleaner,
|
||||
ap=self.ap,
|
||||
min_seq_len=c.min_seq_len)
|
||||
if ok_kusal:
|
||||
dataset = Kusal.MyDataset(
|
||||
os.path.join(c.data_path_Kusal),
|
||||
os.path.join(c.data_path_Kusal, 'prompts.txt'),
|
||||
1,
|
||||
c.text_cleaner,
|
||||
ap=self.ap,
|
||||
min_seq_len=c.min_seq_len)
|
||||
|
||||
# Test for batch size 1
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=dataset.collate_fn,
|
||||
drop_last=True,
|
||||
num_workers=c.num_loader_workers)
|
||||
# Test for batch size 1
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=dataset.collate_fn,
|
||||
drop_last=True,
|
||||
num_workers=c.num_loader_workers)
|
||||
|
||||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
linear_input = data[2]
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
stop_target = data[5]
|
||||
item_idx = data[6]
|
||||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
linear_input = data[2]
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
stop_target = data[5]
|
||||
item_idx = data[6]
|
||||
|
||||
# check the last time step to be zero padded
|
||||
assert mel_input[0, -1].sum() == 0
|
||||
# assert mel_input[0, -2].sum() != 0
|
||||
assert linear_input[0, -1].sum() == 0
|
||||
# assert linear_input[0, -2].sum() != 0
|
||||
assert stop_target[0, -1] == 1
|
||||
assert stop_target[0, -2] == 0
|
||||
assert stop_target.sum() == 1
|
||||
assert len(mel_lengths.shape) == 1
|
||||
assert mel_lengths[0] == mel_input[0].shape[0]
|
||||
# check the last time step to be zero padded
|
||||
assert mel_input[0, -1].sum() == 0
|
||||
# assert mel_input[0, -2].sum() != 0
|
||||
assert linear_input[0, -1].sum() == 0
|
||||
# assert linear_input[0, -2].sum() != 0
|
||||
assert stop_target[0, -1] == 1
|
||||
assert stop_target[0, -2] == 0
|
||||
assert stop_target.sum() == 1
|
||||
assert len(mel_lengths.shape) == 1
|
||||
assert mel_lengths[0] == mel_input[0].shape[0]
|
||||
|
||||
# Test for batch size 2
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=2,
|
||||
shuffle=False,
|
||||
collate_fn=dataset.collate_fn,
|
||||
drop_last=False,
|
||||
num_workers=c.num_loader_workers)
|
||||
# Test for batch size 2
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=2,
|
||||
shuffle=False,
|
||||
collate_fn=dataset.collate_fn,
|
||||
drop_last=False,
|
||||
num_workers=c.num_loader_workers)
|
||||
|
||||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
linear_input = data[2]
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
stop_target = data[5]
|
||||
item_idx = data[6]
|
||||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
linear_input = data[2]
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
stop_target = data[5]
|
||||
item_idx = data[6]
|
||||
|
||||
if mel_lengths[0] > mel_lengths[1]:
|
||||
idx = 0
|
||||
else:
|
||||
idx = 1
|
||||
if mel_lengths[0] > mel_lengths[1]:
|
||||
idx = 0
|
||||
else:
|
||||
idx = 1
|
||||
|
||||
# check the first item in the batch
|
||||
assert mel_input[idx, -1].sum() == 0
|
||||
assert mel_input[idx, -2].sum() != 0, mel_input
|
||||
assert linear_input[idx, -1].sum() == 0
|
||||
assert linear_input[idx, -2].sum() != 0
|
||||
assert stop_target[idx, -1] == 1
|
||||
assert stop_target[idx, -2] == 0
|
||||
assert stop_target[idx].sum() == 1
|
||||
assert len(mel_lengths.shape) == 1
|
||||
assert mel_lengths[idx] == mel_input[idx].shape[0]
|
||||
# check the first item in the batch
|
||||
assert mel_input[idx, -1].sum() == 0
|
||||
assert mel_input[idx, -2].sum() != 0, mel_input
|
||||
assert linear_input[idx, -1].sum() == 0
|
||||
assert linear_input[idx, -2].sum() != 0
|
||||
assert stop_target[idx, -1] == 1
|
||||
assert stop_target[idx, -2] == 0
|
||||
assert stop_target[idx].sum() == 1
|
||||
assert len(mel_lengths.shape) == 1
|
||||
assert mel_lengths[idx] == mel_input[idx].shape[0]
|
||||
|
||||
# check the second itme in the batch
|
||||
assert mel_input[1 - idx, -1].sum() == 0
|
||||
assert linear_input[1 - idx, -1].sum() == 0
|
||||
assert stop_target[1 - idx, -1] == 1
|
||||
assert len(mel_lengths.shape) == 1
|
||||
# check the second itme in the batch
|
||||
assert mel_input[1 - idx, -1].sum() == 0
|
||||
assert linear_input[1 - idx, -1].sum() == 0
|
||||
assert stop_target[1 - idx, -1] == 1
|
||||
assert len(mel_lengths.shape) == 1
|
||||
|
||||
# check batch conditions
|
||||
assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
|
||||
assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
|
||||
|
||||
|
||||
# class TestTWEBDataset(unittest.TestCase):
|
||||
|
||||
# def __init__(self, *args, **kwargs):
|
||||
# super(TestTWEBDataset, self).__init__(*args, **kwargs)
|
||||
# self.max_loader_iter = 4
|
||||
|
||||
# def test_loader(self):
|
||||
# dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
|
||||
# os.path.join(c.data_path_TWEB, 'wavs'),
|
||||
# c.r,
|
||||
# c.sample_rate,
|
||||
# c.text_cleaner,
|
||||
# c.num_mels,
|
||||
# c.min_level_db,
|
||||
# c.frame_shift_ms,
|
||||
# c.frame_length_ms,
|
||||
# c.preemphasis,
|
||||
# c.ref_level_db,
|
||||
# c.num_freq,
|
||||
# c.power
|
||||
# )
|
||||
|
||||
# dataloader = DataLoader(dataset, batch_size=2,
|
||||
# shuffle=True, collate_fn=dataset.collate_fn,
|
||||
# drop_last=True, num_workers=c.num_loader_workers)
|
||||
|
||||
# for i, data in enumerate(dataloader):
|
||||
# if i == self.max_loader_iter:
|
||||
# break
|
||||
# text_input = data[0]
|
||||
# text_lengths = data[1]
|
||||
# linear_input = data[2]
|
||||
# mel_input = data[3]
|
||||
# mel_lengths = data[4]
|
||||
# stop_target = data[5]
|
||||
# item_idx = data[6]
|
||||
|
||||
# neg_values = text_input[text_input < 0]
|
||||
# check_count = len(neg_values)
|
||||
# assert check_count == 0, \
|
||||
# " !! Negative values in text_input: {}".format(check_count)
|
||||
# # TODO: more assertion here
|
||||
# assert linear_input.shape[0] == c.batch_size
|
||||
# assert mel_input.shape[0] == c.batch_size
|
||||
# assert mel_input.shape[2] == c.num_mels
|
||||
|
||||
# def test_padding(self):
|
||||
# dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
|
||||
# os.path.join(c.data_path_TWEB, 'wavs'),
|
||||
# 1,
|
||||
# c.sample_rate,
|
||||
# c.text_cleaner,
|
||||
# c.num_mels,
|
||||
# c.min_level_db,
|
||||
# c.frame_shift_ms,
|
||||
# c.frame_length_ms,
|
||||
# c.preemphasis,
|
||||
# c.ref_level_db,
|
||||
# c.num_freq,
|
||||
# c.power
|
||||
# )
|
||||
|
||||
# # Test for batch size 1
|
||||
# dataloader = DataLoader(dataset, batch_size=1,
|
||||
# shuffle=False, collate_fn=dataset.collate_fn,
|
||||
# drop_last=False, num_workers=c.num_loader_workers)
|
||||
|
||||
# for i, data in enumerate(dataloader):
|
||||
# if i == self.max_loader_iter:
|
||||
# break
|
||||
|
||||
# text_input = data[0]
|
||||
# text_lengths = data[1]
|
||||
# linear_input = data[2]
|
||||
# mel_input = data[3]
|
||||
# mel_lengths = data[4]
|
||||
# stop_target = data[5]
|
||||
# item_idx = data[6]
|
||||
|
||||
# # check the last time step to be zero padded
|
||||
# assert mel_input[0, -1].sum() == 0
|
||||
# assert mel_input[0, -2].sum() != 0, "{} -- {}".format(item_idx, i)
|
||||
# assert linear_input[0, -1].sum() == 0
|
||||
# assert linear_input[0, -2].sum() != 0
|
||||
# assert stop_target[0, -1] == 1
|
||||
# assert stop_target[0, -2] == 0
|
||||
# assert stop_target.sum() == 1
|
||||
# assert len(mel_lengths.shape) == 1
|
||||
# assert mel_lengths[0] == mel_input[0].shape[0]
|
||||
|
||||
# # Test for batch size 2
|
||||
# dataloader = DataLoader(dataset, batch_size=2,
|
||||
# shuffle=False, collate_fn=dataset.collate_fn,
|
||||
# drop_last=False, num_workers=c.num_loader_workers)
|
||||
|
||||
# for i, data in enumerate(dataloader):
|
||||
# if i == self.max_loader_iter:
|
||||
# break
|
||||
# text_input = data[0]
|
||||
# text_lengths = data[1]
|
||||
# linear_input = data[2]
|
||||
# mel_input = data[3]
|
||||
# mel_lengths = data[4]
|
||||
# stop_target = data[5]
|
||||
# item_idx = data[6]
|
||||
|
||||
# if mel_lengths[0] > mel_lengths[1]:
|
||||
# idx = 0
|
||||
# else:
|
||||
# idx = 1
|
||||
|
||||
# # check the first item in the batch
|
||||
# assert mel_input[idx, -1].sum() == 0
|
||||
# assert mel_input[idx, -2].sum() != 0, mel_input
|
||||
# assert linear_input[idx, -1].sum() == 0
|
||||
# assert linear_input[idx, -2].sum() != 0
|
||||
# assert stop_target[idx, -1] == 1
|
||||
# assert stop_target[idx, -2] == 0
|
||||
# assert stop_target[idx].sum() == 1
|
||||
# assert len(mel_lengths.shape) == 1
|
||||
# assert mel_lengths[idx] == mel_input[idx].shape[0]
|
||||
|
||||
# # check the second itme in the batch
|
||||
# assert mel_input[1-idx, -1].sum() == 0
|
||||
# assert linear_input[1-idx, -1].sum() == 0
|
||||
# assert stop_target[1-idx, -1] == 1
|
||||
# assert len(mel_lengths.shape) == 1
|
||||
|
||||
# # check batch conditions
|
||||
# assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
|
||||
# assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
|
||||
# check batch conditions
|
||||
assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
|
||||
assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
|
||||
|
|
Loading…
Reference in New Issue