pytorch 0.4.1 update

pull/10/head
Eren 2018-08-13 15:02:17 +02:00
parent 90b96f9bed
commit a15b3ec9a1
5 changed files with 227 additions and 369 deletions

View File

@@ -24,8 +24,7 @@ class BahdanauAttention(nn.Module):
         processed_query = self.query_layer(query)
         processed_annots = self.annot_layer(annots)
         # (batch, max_time, 1)
-        alignment = self.v(
-            nn.functional.tanh(processed_query + processed_annots))
+        alignment = self.v(torch.tanh(processed_query + processed_annots))
         # (batch, max_time)
         return alignment.squeeze(-1)
@@ -72,8 +71,7 @@ class LocationSensitiveAttention(nn.Module):
         processed_query = self.query_layer(query)
         processed_annots = self.annot_layer(annot)
         alignment = self.v(
-            nn.functional.tanh(processed_query + processed_annots +
-                               processed_loc))
+            torch.tanh(processed_query + processed_annots + processed_loc))
         # (batch, max_time)
         return alignment.squeeze(-1)
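
Both hunks above make the same change: torch.nn.functional.tanh is deprecated in PyTorch 0.4.x in favor of torch.tanh. A minimal sketch of the updated scoring step; the shapes (batch 4, max_time 50, attention_dim 128) are illustrative assumptions, not values from the repo:

import torch
import torch.nn as nn

v = nn.Linear(128, 1, bias=False)           # scoring layer, as in self.v
processed_query = torch.randn(4, 1, 128)    # (batch, 1, attention_dim)
processed_annots = torch.randn(4, 50, 128)  # (batch, max_time, attention_dim)

# torch.tanh replaces the deprecated nn.functional.tanh
alignment = v(torch.tanh(processed_query + processed_annots)).squeeze(-1)
print(alignment.shape)                      # torch.Size([4, 50])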

View File

@@ -22,24 +22,13 @@ class L1LossMasked(nn.Module):
         Returns:
             loss: An average loss value masked by the length.
         """
-        input = input.contiguous()
-        target = target.contiguous()
-
-        # logits_flat: (batch * max_len, dim)
-        input = input.view(-1, input.shape[-1])
-        # target_flat: (batch * max_len, dim)
-        target_flat = target.view(-1, target.shape[-1])
-        # losses_flat: (batch * max_len, dim)
-        losses_flat = functional.l1_loss(
-            input, target_flat, size_average=False, reduce=False)
-        # losses: (batch, max_len, dim)
-        losses = losses_flat.view(*target.size())
-
         # mask: (batch, max_len, 1)
         mask = sequence_mask(
-            sequence_length=length, max_len=target.size(1)).unsqueeze(2)
-        losses = losses * mask.float()
-        loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
+            sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
+        mask = mask.expand_as(input)
+        loss = functional.l1_loss(
+            input * mask, target * mask, reduction="sum")
+        loss = loss / mask.sum()
         return loss
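
The rewritten loss no longer flattens input and target; it zeroes out padded positions on both tensors, sums the element-wise L1, and divides by the number of unmasked elements, which matches the old denominator of length.sum() * dim. It also swaps the size_average=/reduce= flags, deprecated in PyTorch 0.4.1, for the reduction= keyword introduced there. A self-contained sketch, where sequence_mask is an assumed stand-in for the helper the repo imports:

import torch
from torch.nn import functional


def sequence_mask(sequence_length, max_len=None):
    # (batch,) lengths -> (batch, max_len) mask of valid positions
    if max_len is None:
        max_len = sequence_length.max().item()
    steps = torch.arange(max_len, device=sequence_length.device)
    return steps.unsqueeze(0) < sequence_length.unsqueeze(1)


def l1_loss_masked(input, target, length):
    # mask: (batch, max_len, 1), broadcast over the feature dim
    mask = sequence_mask(
        sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
    mask = mask.expand_as(input)
    # reduction="sum" replaces the deprecated size_average/reduce flags
    loss = functional.l1_loss(input * mask, target * mask, reduction="sum")
    return loss / mask.sum()


# loss is averaged over unpadded positions only
x, y = torch.randn(2, 5, 80), torch.randn(2, 5, 80)
print(l1_loss_masked(x, y, torch.tensor([5, 3])))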

View File

@@ -1,6 +1,6 @@
 numpy==1.14.3
 lws
-torch>=0.4.0
+torch>=0.4.1
 librosa==0.5.1
 Unidecode==0.4.20
 tensorboard

View File

@@ -73,7 +73,7 @@ setup(
     setup_requires=["numpy==1.14.3"],
     install_requires=[
         "scipy==0.19.0",
-        "torch == 0.4.0",
+        "torch >= 0.4.1",
         "librosa==0.5.1",
         "unidecode==0.4.20",
         "tensorboardX",

View File

@@ -29,127 +29,129 @@ class TestLJSpeechDataset(unittest.TestCase):
             max_mel_freq=c.max_mel_freq)
 
     def test_loader(self):
-        dataset = LJSpeech.MyDataset(
-            os.path.join(c.data_path_LJSpeech),
-            os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
-            c.r,
-            c.text_cleaner,
-            ap=self.ap,
-            min_seq_len=c.min_seq_len)
+        if ok_ljspeech:
+            dataset = LJSpeech.MyDataset(
+                os.path.join(c.data_path_LJSpeech),
+                os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
+                c.r,
+                c.text_cleaner,
+                ap=self.ap,
+                min_seq_len=c.min_seq_len)
 
-        dataloader = DataLoader(
-            dataset,
-            batch_size=2,
-            shuffle=True,
-            collate_fn=dataset.collate_fn,
-            drop_last=True,
-            num_workers=c.num_loader_workers)
+            dataloader = DataLoader(
+                dataset,
+                batch_size=2,
+                shuffle=True,
+                collate_fn=dataset.collate_fn,
+                drop_last=True,
+                num_workers=c.num_loader_workers)
 
-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
 
-            neg_values = text_input[text_input < 0]
-            check_count = len(neg_values)
-            assert check_count == 0, \
-                " !! Negative values in text_input: {}".format(check_count)
-            # TODO: more assertion here
-            assert linear_input.shape[0] == c.batch_size
-            assert mel_input.shape[0] == c.batch_size
-            assert mel_input.shape[2] == c.num_mels
+                neg_values = text_input[text_input < 0]
+                check_count = len(neg_values)
+                assert check_count == 0, \
+                    " !! Negative values in text_input: {}".format(check_count)
+                # TODO: more assertion here
+                assert linear_input.shape[0] == c.batch_size
+                assert mel_input.shape[0] == c.batch_size
+                assert mel_input.shape[2] == c.num_mels
 
     def test_padding(self):
-        dataset = LJSpeech.MyDataset(
-            os.path.join(c.data_path_LJSpeech),
-            os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
-            1,
-            c.text_cleaner,
-            ap=self.ap,
-            min_seq_len=c.min_seq_len)
+        if ok_ljspeech:
+            dataset = LJSpeech.MyDataset(
+                os.path.join(c.data_path_LJSpeech),
+                os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
+                1,
+                c.text_cleaner,
+                ap=self.ap,
+                min_seq_len=c.min_seq_len)
 
-        # Test for batch size 1
-        dataloader = DataLoader(
-            dataset,
-            batch_size=1,
-            shuffle=False,
-            collate_fn=dataset.collate_fn,
-            drop_last=True,
-            num_workers=c.num_loader_workers)
+            # Test for batch size 1
+            dataloader = DataLoader(
+                dataset,
+                batch_size=1,
+                shuffle=False,
+                collate_fn=dataset.collate_fn,
+                drop_last=True,
+                num_workers=c.num_loader_workers)
 
-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
 
-            # check the last time step to be zero padded
-            assert mel_input[0, -1].sum() == 0
-            assert mel_input[0, -2].sum() != 0
-            assert linear_input[0, -1].sum() == 0
-            assert linear_input[0, -2].sum() != 0
-            assert stop_target[0, -1] == 1
-            assert stop_target[0, -2] == 0
-            assert stop_target.sum() == 1
-            assert len(mel_lengths.shape) == 1
-            assert mel_lengths[0] == mel_input[0].shape[0]
+                # check the last time step to be zero padded
+                assert mel_input[0, -1].sum() == 0
+                assert mel_input[0, -2].sum() != 0
+                assert linear_input[0, -1].sum() == 0
+                assert linear_input[0, -2].sum() != 0
+                assert stop_target[0, -1] == 1
+                assert stop_target[0, -2] == 0
+                assert stop_target.sum() == 1
+                assert len(mel_lengths.shape) == 1
+                assert mel_lengths[0] == mel_input[0].shape[0]
 
-        # Test for batch size 2
-        dataloader = DataLoader(
-            dataset,
-            batch_size=2,
-            shuffle=False,
-            collate_fn=dataset.collate_fn,
-            drop_last=False,
-            num_workers=c.num_loader_workers)
+            # Test for batch size 2
+            dataloader = DataLoader(
+                dataset,
+                batch_size=2,
+                shuffle=False,
+                collate_fn=dataset.collate_fn,
+                drop_last=False,
+                num_workers=c.num_loader_workers)
 
-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
 
-            if mel_lengths[0] > mel_lengths[1]:
-                idx = 0
-            else:
-                idx = 1
+                if mel_lengths[0] > mel_lengths[1]:
+                    idx = 0
+                else:
+                    idx = 1
 
-            # check the first item in the batch
-            assert mel_input[idx, -1].sum() == 0
-            assert mel_input[idx, -2].sum() != 0, mel_input
-            assert linear_input[idx, -1].sum() == 0
-            assert linear_input[idx, -2].sum() != 0
-            assert stop_target[idx, -1] == 1
-            assert stop_target[idx, -2] == 0
-            assert stop_target[idx].sum() == 1
-            assert len(mel_lengths.shape) == 1
-            assert mel_lengths[idx] == mel_input[idx].shape[0]
+                # check the first item in the batch
+                assert mel_input[idx, -1].sum() == 0
+                assert mel_input[idx, -2].sum() != 0, mel_input
+                assert linear_input[idx, -1].sum() == 0
+                assert linear_input[idx, -2].sum() != 0
+                assert stop_target[idx, -1] == 1
+                assert stop_target[idx, -2] == 0
+                assert stop_target[idx].sum() == 1
+                assert len(mel_lengths.shape) == 1
+                assert mel_lengths[idx] == mel_input[idx].shape[0]
 
-            # check the second item in the batch
-            assert mel_input[1 - idx, -1].sum() == 0
-            assert linear_input[1 - idx, -1].sum() == 0
-            assert stop_target[1 - idx, -1] == 1
-            assert len(mel_lengths.shape) == 1
+                # check the second item in the batch
+                assert mel_input[1 - idx, -1].sum() == 0
+                assert linear_input[1 - idx, -1].sum() == 0
+                assert stop_target[1 - idx, -1] == 1
+                assert len(mel_lengths.shape) == 1
 
-            # check batch conditions
-            assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
-            assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
+                # check batch conditions
+                assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
+                assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
 
 
 class TestKusalDataset(unittest.TestCase):
@@ -170,257 +172,126 @@ class TestKusalDataset(unittest.TestCase):
             max_mel_freq=c.max_mel_freq)
 
     def test_loader(self):
-        dataset = Kusal.MyDataset(
-            os.path.join(c.data_path_Kusal),
-            os.path.join(c.data_path_Kusal, 'prompts.txt'),
-            c.r,
-            c.text_cleaner,
-            ap=self.ap,
-            min_seq_len=c.min_seq_len)
+        if ok_kusal:
+            dataset = Kusal.MyDataset(
+                os.path.join(c.data_path_Kusal),
+                os.path.join(c.data_path_Kusal, 'prompts.txt'),
+                c.r,
+                c.text_cleaner,
+                ap=self.ap,
+                min_seq_len=c.min_seq_len)
 
-        dataloader = DataLoader(
-            dataset,
-            batch_size=2,
-            shuffle=True,
-            collate_fn=dataset.collate_fn,
-            drop_last=True,
-            num_workers=c.num_loader_workers)
+            dataloader = DataLoader(
+                dataset,
+                batch_size=2,
+                shuffle=True,
+                collate_fn=dataset.collate_fn,
+                drop_last=True,
+                num_workers=c.num_loader_workers)
 
-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
 
-            neg_values = text_input[text_input < 0]
-            check_count = len(neg_values)
-            assert check_count == 0, \
-                " !! Negative values in text_input: {}".format(check_count)
-            # TODO: more assertion here
-            assert linear_input.shape[0] == c.batch_size
-            assert mel_input.shape[0] == c.batch_size
-            assert mel_input.shape[2] == c.num_mels
+                neg_values = text_input[text_input < 0]
+                check_count = len(neg_values)
+                assert check_count == 0, \
+                    " !! Negative values in text_input: {}".format(check_count)
+                # TODO: more assertion here
+                assert linear_input.shape[0] == c.batch_size
+                assert mel_input.shape[0] == c.batch_size
+                assert mel_input.shape[2] == c.num_mels
 
     def test_padding(self):
-        dataset = Kusal.MyDataset(
-            os.path.join(c.data_path_Kusal),
-            os.path.join(c.data_path_Kusal, 'prompts.txt'),
-            1,
-            c.text_cleaner,
-            ap=self.ap,
-            min_seq_len=c.min_seq_len)
+        if ok_kusal:
+            dataset = Kusal.MyDataset(
+                os.path.join(c.data_path_Kusal),
+                os.path.join(c.data_path_Kusal, 'prompts.txt'),
+                1,
+                c.text_cleaner,
+                ap=self.ap,
+                min_seq_len=c.min_seq_len)
 
-        # Test for batch size 1
-        dataloader = DataLoader(
-            dataset,
-            batch_size=1,
-            shuffle=False,
-            collate_fn=dataset.collate_fn,
-            drop_last=True,
-            num_workers=c.num_loader_workers)
+            # Test for batch size 1
+            dataloader = DataLoader(
+                dataset,
+                batch_size=1,
+                shuffle=False,
+                collate_fn=dataset.collate_fn,
+                drop_last=True,
+                num_workers=c.num_loader_workers)
 
-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
 
-            # check the last time step to be zero padded
-            assert mel_input[0, -1].sum() == 0
-            # assert mel_input[0, -2].sum() != 0
-            assert linear_input[0, -1].sum() == 0
-            # assert linear_input[0, -2].sum() != 0
-            assert stop_target[0, -1] == 1
-            assert stop_target[0, -2] == 0
-            assert stop_target.sum() == 1
-            assert len(mel_lengths.shape) == 1
-            assert mel_lengths[0] == mel_input[0].shape[0]
+                # check the last time step to be zero padded
+                assert mel_input[0, -1].sum() == 0
+                # assert mel_input[0, -2].sum() != 0
+                assert linear_input[0, -1].sum() == 0
+                # assert linear_input[0, -2].sum() != 0
+                assert stop_target[0, -1] == 1
+                assert stop_target[0, -2] == 0
+                assert stop_target.sum() == 1
+                assert len(mel_lengths.shape) == 1
+                assert mel_lengths[0] == mel_input[0].shape[0]
 
-        # Test for batch size 2
-        dataloader = DataLoader(
-            dataset,
-            batch_size=2,
-            shuffle=False,
-            collate_fn=dataset.collate_fn,
-            drop_last=False,
-            num_workers=c.num_loader_workers)
+            # Test for batch size 2
+            dataloader = DataLoader(
+                dataset,
+                batch_size=2,
+                shuffle=False,
+                collate_fn=dataset.collate_fn,
+                drop_last=False,
+                num_workers=c.num_loader_workers)
 
-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
 
-            if mel_lengths[0] > mel_lengths[1]:
-                idx = 0
-            else:
-                idx = 1
+                if mel_lengths[0] > mel_lengths[1]:
+                    idx = 0
+                else:
+                    idx = 1
 
-            # check the first item in the batch
-            assert mel_input[idx, -1].sum() == 0
-            assert mel_input[idx, -2].sum() != 0, mel_input
-            assert linear_input[idx, -1].sum() == 0
-            assert linear_input[idx, -2].sum() != 0
-            assert stop_target[idx, -1] == 1
-            assert stop_target[idx, -2] == 0
-            assert stop_target[idx].sum() == 1
-            assert len(mel_lengths.shape) == 1
-            assert mel_lengths[idx] == mel_input[idx].shape[0]
+                # check the first item in the batch
+                assert mel_input[idx, -1].sum() == 0
+                assert mel_input[idx, -2].sum() != 0, mel_input
+                assert linear_input[idx, -1].sum() == 0
+                assert linear_input[idx, -2].sum() != 0
+                assert stop_target[idx, -1] == 1
+                assert stop_target[idx, -2] == 0
+                assert stop_target[idx].sum() == 1
+                assert len(mel_lengths.shape) == 1
+                assert mel_lengths[idx] == mel_input[idx].shape[0]
 
-            # check the second item in the batch
-            assert mel_input[1 - idx, -1].sum() == 0
-            assert linear_input[1 - idx, -1].sum() == 0
-            assert stop_target[1 - idx, -1] == 1
-            assert len(mel_lengths.shape) == 1
+                # check the second item in the batch
+                assert mel_input[1 - idx, -1].sum() == 0
+                assert linear_input[1 - idx, -1].sum() == 0
+                assert stop_target[1 - idx, -1] == 1
+                assert len(mel_lengths.shape) == 1
 
-            # check batch conditions
-            assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
-            assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
-
-
-# class TestTWEBDataset(unittest.TestCase):
-#     def __init__(self, *args, **kwargs):
-#         super(TestTWEBDataset, self).__init__(*args, **kwargs)
-#         self.max_loader_iter = 4
-
-#     def test_loader(self):
-#         dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
-#                               os.path.join(c.data_path_TWEB, 'wavs'),
-#                               c.r,
-#                               c.sample_rate,
-#                               c.text_cleaner,
-#                               c.num_mels,
-#                               c.min_level_db,
-#                               c.frame_shift_ms,
-#                               c.frame_length_ms,
-#                               c.preemphasis,
-#                               c.ref_level_db,
-#                               c.num_freq,
-#                               c.power
-#                               )
-#         dataloader = DataLoader(dataset, batch_size=2,
-#                                 shuffle=True, collate_fn=dataset.collate_fn,
-#                                 drop_last=True, num_workers=c.num_loader_workers)
-
-#         for i, data in enumerate(dataloader):
-#             if i == self.max_loader_iter:
-#                 break
-#             text_input = data[0]
-#             text_lengths = data[1]
-#             linear_input = data[2]
-#             mel_input = data[3]
-#             mel_lengths = data[4]
-#             stop_target = data[5]
-#             item_idx = data[6]
-
-#             neg_values = text_input[text_input < 0]
-#             check_count = len(neg_values)
-#             assert check_count == 0, \
-#                 " !! Negative values in text_input: {}".format(check_count)
-#             # TODO: more assertion here
-#             assert linear_input.shape[0] == c.batch_size
-#             assert mel_input.shape[0] == c.batch_size
-#             assert mel_input.shape[2] == c.num_mels
-
-#     def test_padding(self):
-#         dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
-#                               os.path.join(c.data_path_TWEB, 'wavs'),
-#                               1,
-#                               c.sample_rate,
-#                               c.text_cleaner,
-#                               c.num_mels,
-#                               c.min_level_db,
-#                               c.frame_shift_ms,
-#                               c.frame_length_ms,
-#                               c.preemphasis,
-#                               c.ref_level_db,
-#                               c.num_freq,
-#                               c.power
-#                               )
-
-#         # Test for batch size 1
-#         dataloader = DataLoader(dataset, batch_size=1,
-#                                 shuffle=False, collate_fn=dataset.collate_fn,
-#                                 drop_last=False, num_workers=c.num_loader_workers)
-
-#         for i, data in enumerate(dataloader):
-#             if i == self.max_loader_iter:
-#                 break
-#             text_input = data[0]
-#             text_lengths = data[1]
-#             linear_input = data[2]
-#             mel_input = data[3]
-#             mel_lengths = data[4]
-#             stop_target = data[5]
-#             item_idx = data[6]
-
-#             # check the last time step to be zero padded
-#             assert mel_input[0, -1].sum() == 0
-#             assert mel_input[0, -2].sum() != 0, "{} -- {}".format(item_idx, i)
-#             assert linear_input[0, -1].sum() == 0
-#             assert linear_input[0, -2].sum() != 0
-#             assert stop_target[0, -1] == 1
-#             assert stop_target[0, -2] == 0
-#             assert stop_target.sum() == 1
-#             assert len(mel_lengths.shape) == 1
-#             assert mel_lengths[0] == mel_input[0].shape[0]
-
-#         # Test for batch size 2
-#         dataloader = DataLoader(dataset, batch_size=2,
-#                                 shuffle=False, collate_fn=dataset.collate_fn,
-#                                 drop_last=False, num_workers=c.num_loader_workers)
-
-#         for i, data in enumerate(dataloader):
-#             if i == self.max_loader_iter:
-#                 break
-#             text_input = data[0]
-#             text_lengths = data[1]
-#             linear_input = data[2]
-#             mel_input = data[3]
-#             mel_lengths = data[4]
-#             stop_target = data[5]
-#             item_idx = data[6]
-
-#             if mel_lengths[0] > mel_lengths[1]:
-#                 idx = 0
-#             else:
-#                 idx = 1
-
-#             # check the first item in the batch
-#             assert mel_input[idx, -1].sum() == 0
-#             assert mel_input[idx, -2].sum() != 0, mel_input
-#             assert linear_input[idx, -1].sum() == 0
-#             assert linear_input[idx, -2].sum() != 0
-#             assert stop_target[idx, -1] == 1
-#             assert stop_target[idx, -2] == 0
-#             assert stop_target[idx].sum() == 1
-#             assert len(mel_lengths.shape) == 1
-#             assert mel_lengths[idx] == mel_input[idx].shape[0]
-
-#             # check the second item in the batch
-#             assert mel_input[1-idx, -1].sum() == 0
-#             assert linear_input[1-idx, -1].sum() == 0
-#             assert stop_target[1-idx, -1] == 1
-#             assert len(mel_lengths.shape) == 1
-
-#             # check batch conditions
-#             assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
-#             assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
+                # check batch conditions
+                assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
+                assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
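
The padding assertions in both test classes encode the collate convention: mel and linear frames are zero-padded on the right, and the stop target is 1 exactly on padded frames, so masking either spectrogram with the stop target must sum to zero. A hypothetical illustration of that convention (pad_batch and all shapes are assumptions, not the repo's collate_fn):

import torch


def pad_batch(mels):
    # mels: list of (T_i, num_mels) tensors with varying T_i
    lengths = torch.tensor([m.size(0) for m in mels])
    max_len = int(lengths.max()) + 1           # keep one all-zero frame at the end
    mel_input = torch.zeros(len(mels), max_len, mels[0].size(1))
    stop_target = torch.zeros(len(mels), max_len)
    for i, m in enumerate(mels):
        mel_input[i, :m.size(0)] = m
        stop_target[i, m.size(0):] = 1         # flag every padded frame
    return mel_input, stop_target, lengths


mel_input, stop_target, _ = pad_batch(
    [torch.rand(4, 80) + 0.1, torch.rand(6, 80) + 0.1])
# the batch condition the tests check: stop == 1 only where frames are zero
assert (mel_input * stop_target.unsqueeze(2)).sum() == 0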