From 65ffbae23d0e8f73479f4d4fe49b68a06ce503f8 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Tue, 12 Mar 2019 09:52:01 +0100 Subject: [PATCH] test bug fix --- tests/tacotron2_tests.py | 2 +- tests/tacotron_tests.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/tacotron2_tests.py b/tests/tacotron2_tests.py index 56c5a1a1..c2f212f9 100644 --- a/tests/tacotron2_tests.py +++ b/tests/tacotron2_tests.py @@ -66,4 +66,4 @@ class TacotronTrainTest(unittest.TestCase): assert (param != param_ref).any( ), "param {} with shape {} not updated!! \n{}\n{}".format( count, param.shape, param, param_ref) - count += 1 \ No newline at end of file + count += 1 diff --git a/tests/tacotron_tests.py b/tests/tacotron_tests.py index 2f76469a..77195594 100644 --- a/tests/tacotron_tests.py +++ b/tests/tacotron_tests.py @@ -22,6 +22,7 @@ class TacotronTrainTest(unittest.TestCase): def test_train_step(self): input = torch.randint(0, 24, (8, 128)).long().to(device) input_lengths = torch.randint(100, 129, (8, )).long().to(device) + input_lengths[-1] = 128 mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) linear_spec = torch.rand(8, 30, c.audio['num_freq']).to(device) mel_lengths = torch.randint(20, 30, (8, )).long().to(device)