diff --git a/layers/attention.py b/layers/attention.py
index 4f8af178..26e6c5d2 100644
--- a/layers/attention.py
+++ b/layers/attention.py
@@ -24,8 +24,7 @@ class BahdanauAttention(nn.Module):
         processed_query = self.query_layer(query)
         processed_annots = self.annot_layer(annots)
         # (batch, max_time, 1)
-        alignment = self.v(
-            nn.functional.tanh(processed_query + processed_annots))
+        alignment = self.v(torch.tanh(processed_query + processed_annots))
         # (batch, max_time)
         return alignment.squeeze(-1)
@@ -72,8 +71,7 @@ class LocationSensitiveAttention(nn.Module):
         processed_query = self.query_layer(query)
         processed_annots = self.annot_layer(annot)
         alignment = self.v(
-            nn.functional.tanh(processed_query + processed_annots +
-                               processed_loc))
+            torch.tanh(processed_query + processed_annots + processed_loc))
         # (batch, max_time)
         return alignment.squeeze(-1)
diff --git a/layers/losses.py b/layers/losses.py
index acf4e789..4be86424 100644
--- a/layers/losses.py
+++ b/layers/losses.py
@@ -22,24 +22,13 @@ class L1LossMasked(nn.Module):
         Returns:
             loss: An average loss value masked by the length.
         """
-        input = input.contiguous()
-        target = target.contiguous()
-
-        # logits_flat: (batch * max_len, dim)
-        input = input.view(-1, input.shape[-1])
-        # target_flat: (batch * max_len, dim)
-        target_flat = target.view(-1, target.shape[-1])
-        # losses_flat: (batch * max_len, dim)
-        losses_flat = functional.l1_loss(
-            input, target_flat, size_average=False, reduce=False)
-        # losses: (batch, max_len, dim)
-        losses = losses_flat.view(*target.size())
-        # mask: (batch, max_len, 1)
         mask = sequence_mask(
-            sequence_length=length, max_len=target.size(1)).unsqueeze(2)
-        losses = losses * mask.float()
-        loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
+            sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
+        mask = mask.expand_as(input)
+        loss = functional.l1_loss(
+            input * mask, target * mask, reduction="sum")
+        loss = loss / mask.sum()
         return loss
diff --git a/requirements.txt b/requirements.txt
index 5f9a9169..73e5dae7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 numpy==1.14.3
 lws
-torch>=0.4.0
+torch>=0.4.1
 librosa==0.5.1
 Unidecode==0.4.20
 tensorboard
diff --git a/setup.py b/setup.py
index b2b4b1ed..29d4c632 100644
--- a/setup.py
+++ b/setup.py
@@ -73,7 +73,7 @@ setup(
     setup_requires=["numpy==1.14.3"],
     install_requires=[
         "scipy==0.19.0",
-        "torch == 0.4.0",
+        "torch >= 0.4.1",
         "librosa==0.5.1",
         "unidecode==0.4.20",
         "tensorboardX",
diff --git a/tests/loader_tests.py b/tests/loader_tests.py
index 5d5cbe52..695861bb 100644
--- a/tests/loader_tests.py
+++ b/tests/loader_tests.py
@@ -29,127 +29,129 @@ class TestLJSpeechDataset(unittest.TestCase):
             max_mel_freq=c.max_mel_freq)

     def test_loader(self):
-        dataset = LJSpeech.MyDataset(
-            os.path.join(c.data_path_LJSpeech),
-            os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
-            c.r,
-            c.text_cleaner,
-            ap=self.ap,
-            min_seq_len=c.min_seq_len)
+        if ok_ljspeech:
+            dataset = LJSpeech.MyDataset(
+                os.path.join(c.data_path_LJSpeech),
+                os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
+                c.r,
+                c.text_cleaner,
+                ap=self.ap,
+                min_seq_len=c.min_seq_len)

-        dataloader = DataLoader(
-            dataset,
-            batch_size=2,
-            shuffle=True,
-            collate_fn=dataset.collate_fn,
-            drop_last=True,
-            num_workers=c.num_loader_workers)
+            dataloader = DataLoader(
+                dataset,
+                batch_size=2,
+                shuffle=True,
+                collate_fn=dataset.collate_fn,
+                drop_last=True,
+                num_workers=c.num_loader_workers)

-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]

-            neg_values = text_input[text_input < 0]
-            check_count = len(neg_values)
-            assert check_count == 0, \
-                " !! Negative values in text_input: {}".format(check_count)
-            # TODO: more assertion here
-            assert linear_input.shape[0] == c.batch_size
-            assert mel_input.shape[0] == c.batch_size
-            assert mel_input.shape[2] == c.num_mels
+                neg_values = text_input[text_input < 0]
+                check_count = len(neg_values)
+                assert check_count == 0, \
+                    " !! Negative values in text_input: {}".format(check_count)
+                # TODO: more assertion here
+                assert linear_input.shape[0] == c.batch_size
+                assert mel_input.shape[0] == c.batch_size
+                assert mel_input.shape[2] == c.num_mels

     def test_padding(self):
-        dataset = LJSpeech.MyDataset(
-            os.path.join(c.data_path_LJSpeech),
-            os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
-            1,
-            c.text_cleaner,
-            ap=self.ap,
-            min_seq_len=c.min_seq_len)
+        if ok_ljspeech:
+            dataset = LJSpeech.MyDataset(
+                os.path.join(c.data_path_LJSpeech),
+                os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
+                1,
+                c.text_cleaner,
+                ap=self.ap,
+                min_seq_len=c.min_seq_len)

-        # Test for batch size 1
-        dataloader = DataLoader(
-            dataset,
-            batch_size=1,
-            shuffle=False,
-            collate_fn=dataset.collate_fn,
-            drop_last=True,
-            num_workers=c.num_loader_workers)
+            # Test for batch size 1
+            dataloader = DataLoader(
+                dataset,
+                batch_size=1,
+                shuffle=False,
+                collate_fn=dataset.collate_fn,
+                drop_last=True,
+                num_workers=c.num_loader_workers)

-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]

-            # check the last time step to be zero padded
-            assert mel_input[0, -1].sum() == 0
-            assert mel_input[0, -2].sum() != 0
-            assert linear_input[0, -1].sum() == 0
-            assert linear_input[0, -2].sum() != 0
-            assert stop_target[0, -1] == 1
-            assert stop_target[0, -2] == 0
-            assert stop_target.sum() == 1
-            assert len(mel_lengths.shape) == 1
-            assert mel_lengths[0] == mel_input[0].shape[0]
+                # check the last time step to be zero padded
+                assert mel_input[0, -1].sum() == 0
+                assert mel_input[0, -2].sum() != 0
+                assert linear_input[0, -1].sum() == 0
+                assert linear_input[0, -2].sum() != 0
+                assert stop_target[0, -1] == 1
+                assert stop_target[0, -2] == 0
+                assert stop_target.sum() == 1
+                assert len(mel_lengths.shape) == 1
+                assert mel_lengths[0] == mel_input[0].shape[0]

-        # Test for batch size 2
-        dataloader = DataLoader(
-            dataset,
-            batch_size=2,
-            shuffle=False,
-            collate_fn=dataset.collate_fn,
-            drop_last=False,
-            num_workers=c.num_loader_workers)
+            # Test for batch size 2
+            dataloader = DataLoader(
+                dataset,
+                batch_size=2,
+                shuffle=False,
+                collate_fn=dataset.collate_fn,
+                drop_last=False,
+                num_workers=c.num_loader_workers)

-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]

-            if mel_lengths[0] > mel_lengths[1]:
-                idx = 0
-            else:
-                idx = 1
+                if mel_lengths[0] > mel_lengths[1]:
+                    idx = 0
+                else:
+                    idx = 1

-            # check the first item in the batch
-            assert mel_input[idx, -1].sum() == 0
-            assert mel_input[idx, -2].sum() != 0, mel_input
-            assert linear_input[idx, -1].sum() == 0
-            assert linear_input[idx, -2].sum() != 0
-            assert stop_target[idx, -1] == 1
-            assert stop_target[idx, -2] == 0
-            assert stop_target[idx].sum() == 1
-            assert len(mel_lengths.shape) == 1
-            assert mel_lengths[idx] == mel_input[idx].shape[0]
+                # check the first item in the batch
+                assert mel_input[idx, -1].sum() == 0
+                assert mel_input[idx, -2].sum() != 0, mel_input
+                assert linear_input[idx, -1].sum() == 0
+                assert linear_input[idx, -2].sum() != 0
+                assert stop_target[idx, -1] == 1
+                assert stop_target[idx, -2] == 0
+                assert stop_target[idx].sum() == 1
+                assert len(mel_lengths.shape) == 1
+                assert mel_lengths[idx] == mel_input[idx].shape[0]

-            # check the second itme in the batch
-            assert mel_input[1 - idx, -1].sum() == 0
-            assert linear_input[1 - idx, -1].sum() == 0
-            assert stop_target[1 - idx, -1] == 1
-            assert len(mel_lengths.shape) == 1
+                # check the second item in the batch
+                assert mel_input[1 - idx, -1].sum() == 0
+                assert linear_input[1 - idx, -1].sum() == 0
+                assert stop_target[1 - idx, -1] == 1
+                assert len(mel_lengths.shape) == 1

-            # check batch conditions
-            assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
-            assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
+                # check batch conditions
+                assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
+                assert (linear_input * stop_target.unsqueeze(2)).sum() == 0


 class TestKusalDataset(unittest.TestCase):
@@ -170,257 +172,126 @@ class TestKusalDataset(unittest.TestCase):
             max_mel_freq=c.max_mel_freq)

     def test_loader(self):
-        dataset = Kusal.MyDataset(
-            os.path.join(c.data_path_Kusal),
-            os.path.join(c.data_path_Kusal, 'prompts.txt'),
-            c.r,
-            c.text_cleaner,
-            ap=self.ap,
-            min_seq_len=c.min_seq_len)
+        if ok_kusal:
+            dataset = Kusal.MyDataset(
+                os.path.join(c.data_path_Kusal),
+                os.path.join(c.data_path_Kusal, 'prompts.txt'),
+                c.r,
+                c.text_cleaner,
+                ap=self.ap,
+                min_seq_len=c.min_seq_len)

-        dataloader = DataLoader(
-            dataset,
-            batch_size=2,
-            shuffle=True,
-            collate_fn=dataset.collate_fn,
-            drop_last=True,
-            num_workers=c.num_loader_workers)
+            dataloader = DataLoader(
+                dataset,
+                batch_size=2,
+                shuffle=True,
+                collate_fn=dataset.collate_fn,
+                drop_last=True,
+                num_workers=c.num_loader_workers)

-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]

-            neg_values = text_input[text_input < 0]
-            check_count = len(neg_values)
-            assert check_count == 0, \
-                " !! Negative values in text_input: {}".format(check_count)
-            # TODO: more assertion here
-            assert linear_input.shape[0] == c.batch_size
-            assert mel_input.shape[0] == c.batch_size
-            assert mel_input.shape[2] == c.num_mels
+                neg_values = text_input[text_input < 0]
+                check_count = len(neg_values)
+                assert check_count == 0, \
+                    " !! Negative values in text_input: {}".format(check_count)
+                # TODO: more assertion here
+                assert linear_input.shape[0] == c.batch_size
+                assert mel_input.shape[0] == c.batch_size
+                assert mel_input.shape[2] == c.num_mels

     def test_padding(self):
-        dataset = Kusal.MyDataset(
-            os.path.join(c.data_path_Kusal),
-            os.path.join(c.data_path_Kusal, 'prompts.txt'),
-            1,
-            c.text_cleaner,
-            ap=self.ap,
-            min_seq_len=c.min_seq_len)
+        if ok_kusal:
+            dataset = Kusal.MyDataset(
+                os.path.join(c.data_path_Kusal),
+                os.path.join(c.data_path_Kusal, 'prompts.txt'),
+                1,
+                c.text_cleaner,
+                ap=self.ap,
+                min_seq_len=c.min_seq_len)

-        # Test for batch size 1
-        dataloader = DataLoader(
-            dataset,
-            batch_size=1,
-            shuffle=False,
-            collate_fn=dataset.collate_fn,
-            drop_last=True,
-            num_workers=c.num_loader_workers)
+            # Test for batch size 1
+            dataloader = DataLoader(
+                dataset,
+                batch_size=1,
+                shuffle=False,
+                collate_fn=dataset.collate_fn,
+                drop_last=True,
+                num_workers=c.num_loader_workers)

-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]

-            # check the last time step to be zero padded
-            assert mel_input[0, -1].sum() == 0
-            # assert mel_input[0, -2].sum() != 0
-            assert linear_input[0, -1].sum() == 0
-            # assert linear_input[0, -2].sum() != 0
-            assert stop_target[0, -1] == 1
-            assert stop_target[0, -2] == 0
-            assert stop_target.sum() == 1
-            assert len(mel_lengths.shape) == 1
-            assert mel_lengths[0] == mel_input[0].shape[0]
+                # check the last time step to be zero padded
+                assert mel_input[0, -1].sum() == 0
+                # assert mel_input[0, -2].sum() != 0
+                assert linear_input[0, -1].sum() == 0
+                # assert linear_input[0, -2].sum() != 0
+                assert stop_target[0, -1] == 1
+                assert stop_target[0, -2] == 0
+                assert stop_target.sum() == 1
+                assert len(mel_lengths.shape) == 1
+                assert mel_lengths[0] == mel_input[0].shape[0]

-        # Test for batch size 2
-        dataloader = DataLoader(
-            dataset,
-            batch_size=2,
-            shuffle=False,
-            collate_fn=dataset.collate_fn,
-            drop_last=False,
-            num_workers=c.num_loader_workers)
+            # Test for batch size 2
+            dataloader = DataLoader(
+                dataset,
+                batch_size=2,
+                shuffle=False,
+                collate_fn=dataset.collate_fn,
+                drop_last=False,
+                num_workers=c.num_loader_workers)

-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]

-            if mel_lengths[0] > mel_lengths[1]:
-                idx = 0
-            else:
-                idx = 1
+                if mel_lengths[0] > mel_lengths[1]:
+                    idx = 0
+                else:
+                    idx = 1

-            # check the first item in the batch
-            assert mel_input[idx, -1].sum() == 0
-            assert mel_input[idx, -2].sum() != 0, mel_input
-            assert linear_input[idx, -1].sum() == 0
-            assert linear_input[idx, -2].sum() != 0
-            assert stop_target[idx, -1] == 1
-            assert stop_target[idx, -2] == 0
-            assert stop_target[idx].sum() == 1
-            assert len(mel_lengths.shape) == 1
-            assert mel_lengths[idx] == mel_input[idx].shape[0]
+                # check the first item in the batch
+                assert mel_input[idx, -1].sum() == 0
+                assert mel_input[idx, -2].sum() != 0, mel_input
+                assert linear_input[idx, -1].sum() == 0
+                assert linear_input[idx, -2].sum() != 0
+                assert stop_target[idx, -1] == 1
+                assert stop_target[idx, -2] == 0
+                assert stop_target[idx].sum() == 1
+                assert len(mel_lengths.shape) == 1
+                assert mel_lengths[idx] == mel_input[idx].shape[0]

-            # check the second itme in the batch
-            assert mel_input[1 - idx, -1].sum() == 0
-            assert linear_input[1 - idx, -1].sum() == 0
-            assert stop_target[1 - idx, -1] == 1
-            assert len(mel_lengths.shape) == 1
+                # check the second item in the batch
+                assert mel_input[1 - idx, -1].sum() == 0
+                assert linear_input[1 - idx, -1].sum() == 0
+                assert stop_target[1 - idx, -1] == 1
+                assert len(mel_lengths.shape) == 1

-            # check batch conditions
-            assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
-            assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
-
-
-# class TestTWEBDataset(unittest.TestCase):
-
-#     def __init__(self, *args, **kwargs):
-#         super(TestTWEBDataset, self).__init__(*args, **kwargs)
-#         self.max_loader_iter = 4
-
-#     def test_loader(self):
-#         dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
-#                               os.path.join(c.data_path_TWEB, 'wavs'),
-#                               c.r,
-#                               c.sample_rate,
-#                               c.text_cleaner,
-#                               c.num_mels,
-#                               c.min_level_db,
-#                               c.frame_shift_ms,
-#                               c.frame_length_ms,
-#                               c.preemphasis,
-#                               c.ref_level_db,
-#                               c.num_freq,
-#                               c.power
-#                               )
-
-#         dataloader = DataLoader(dataset, batch_size=2,
-#                                 shuffle=True, collate_fn=dataset.collate_fn,
-#                                 drop_last=True, num_workers=c.num_loader_workers)
-
-#         for i, data in enumerate(dataloader):
-#             if i == self.max_loader_iter:
-#                 break
-#             text_input = data[0]
-#             text_lengths = data[1]
-#             linear_input = data[2]
-#             mel_input = data[3]
-#             mel_lengths = data[4]
-#             stop_target = data[5]
-#             item_idx = data[6]
-
-#             neg_values = text_input[text_input < 0]
-#             check_count = len(neg_values)
-#             assert check_count == 0, \
-#                 " !! Negative values in text_input: {}".format(check_count)
-#             # TODO: more assertion here
-#             assert linear_input.shape[0] == c.batch_size
-#             assert mel_input.shape[0] == c.batch_size
-#             assert mel_input.shape[2] == c.num_mels
-
-#     def test_padding(self):
-#         dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
-#                               os.path.join(c.data_path_TWEB, 'wavs'),
-#                               1,
-#                               c.sample_rate,
-#                               c.text_cleaner,
-#                               c.num_mels,
-#                               c.min_level_db,
-#                               c.frame_shift_ms,
-#                               c.frame_length_ms,
-#                               c.preemphasis,
-#                               c.ref_level_db,
-#                               c.num_freq,
-#                               c.power
-#                               )
-
-#         # Test for batch size 1
-#         dataloader = DataLoader(dataset, batch_size=1,
-#                                 shuffle=False, collate_fn=dataset.collate_fn,
-#                                 drop_last=False, num_workers=c.num_loader_workers)
-
-#         for i, data in enumerate(dataloader):
-#             if i == self.max_loader_iter:
-#                 break
-
-#             text_input = data[0]
-#             text_lengths = data[1]
-#             linear_input = data[2]
-#             mel_input = data[3]
-#             mel_lengths = data[4]
-#             stop_target = data[5]
-#             item_idx = data[6]
-
-#             # check the last time step to be zero padded
-#             assert mel_input[0, -1].sum() == 0
-#             assert mel_input[0, -2].sum() != 0, "{} -- {}".format(item_idx, i)
-#             assert linear_input[0, -1].sum() == 0
-#             assert linear_input[0, -2].sum() != 0
-#             assert stop_target[0, -1] == 1
-#             assert stop_target[0, -2] == 0
-#             assert stop_target.sum() == 1
-#             assert len(mel_lengths.shape) == 1
-#             assert mel_lengths[0] == mel_input[0].shape[0]
-
-#         # Test for batch size 2
-#         dataloader = DataLoader(dataset, batch_size=2,
-#                                 shuffle=False, collate_fn=dataset.collate_fn,
-#                                 drop_last=False, num_workers=c.num_loader_workers)
-
-#         for i, data in enumerate(dataloader):
-#             if i == self.max_loader_iter:
-#                 break
-#             text_input = data[0]
-#             text_lengths = data[1]
-#             linear_input = data[2]
-#             mel_input = data[3]
-#             mel_lengths = data[4]
-#             stop_target = data[5]
-#             item_idx = data[6]
-
-#             if mel_lengths[0] > mel_lengths[1]:
-#                 idx = 0
-#             else:
-#                 idx = 1
-
-#             # check the first item in the batch
-#             assert mel_input[idx, -1].sum() == 0
-#             assert mel_input[idx, -2].sum() != 0, mel_input
-#             assert linear_input[idx, -1].sum() == 0
-#             assert linear_input[idx, -2].sum() != 0
-#             assert stop_target[idx, -1] == 1
-#             assert stop_target[idx, -2] == 0
-#             assert stop_target[idx].sum() == 1
-#             assert len(mel_lengths.shape) == 1
-#             assert mel_lengths[idx] == mel_input[idx].shape[0]
-
-#             # check the second itme in the batch
-#             assert mel_input[1-idx, -1].sum() == 0
-#             assert linear_input[1-idx, -1].sum() == 0
-#             assert stop_target[1-idx, -1] == 1
-#             assert len(mel_lengths.shape) == 1
-
-#             # check batch conditions
-#             assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
-#             assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
+                # check batch conditions
+                assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
+                assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
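
A note on the L1LossMasked refactor above: the old and new code compute the same masked average, but the new version does it with a single masked l1_loss call using the reduction="sum" argument that PyTorch gained in 0.4.1, which is why requirements.txt and setup.py bump the torch version. Below is a minimal standalone sketch of the equivalence; the sequence_mask helper is a simplified stand-in for the project's own utility, and the shapes and seed are made up for illustration.

import torch
import torch.nn.functional as F

def sequence_mask(sequence_length, max_len):
    # (batch, max_len) boolean mask, True at valid (non-padded) timesteps;
    # simplified stand-in for the project's sequence_mask utility
    seq_range = torch.arange(max_len).unsqueeze(0)
    return seq_range < sequence_length.unsqueeze(1)

batch, max_len, dim = 2, 5, 3  # toy shapes for illustration only
torch.manual_seed(0)
x = torch.rand(batch, max_len, dim)  # model output
y = torch.rand(batch, max_len, dim)  # padded target
lengths = torch.tensor([5, 3])       # valid frames per sample

mask = sequence_mask(lengths, max_len).unsqueeze(2).float().expand_as(x)

# new formulation (this patch): zero the padded positions on both tensors,
# take the summed L1 loss, then divide by the number of unmasked entries
new_loss = F.l1_loss(x * mask, y * mask, reduction="sum") / mask.sum()

# old formulation: elementwise losses, masked, then averaged over
# (total valid frames * feature dim); mask.sum() equals that same product
old_loss = ((x - y).abs() * mask).sum() / (lengths.float().sum() * dim)

assert torch.allclose(new_loss, old_loss)

Masking both input and target before the loss call is what makes the one-liner safe: for a 0/1 mask, |x * m - y * m| reduces to m * |x - y|, so padded frames contribute nothing to the sum.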