test bug fix

pull/10/head
Eren Golge 2019-03-12 09:52:01 +01:00
parent b9b79fcf0f
commit 65ffbae23d
2 changed files with 2 additions and 1 deletions

View File

@ -66,4 +66,4 @@ class TacotronTrainTest(unittest.TestCase):
assert (param != param_ref).any(
), "param {} with shape {} not updated!! \n{}\n{}".format(
count, param.shape, param, param_ref)
count += 1
count += 1

View File

@ -22,6 +22,7 @@ class TacotronTrainTest(unittest.TestCase):
def test_train_step(self):
input = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 129, (8, )).long().to(device)
input_lengths[-1] = 128
mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
linear_spec = torch.rand(8, 30, c.audio['num_freq']).to(device)
mel_lengths = torch.randint(20, 30, (8, )).long().to(device)