diff --git a/tests/test_tacotron_model.py b/tests/test_tacotron_model.py index e3ed8ae2..72b47d23 100644 --- a/tests/test_tacotron_model.py +++ b/tests/test_tacotron_model.py @@ -37,6 +37,7 @@ class TacotronTrainTest(unittest.TestCase): mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) linear_spec = torch.rand(8, 30, c.audio["fft_size"]).to(device) mel_lengths = torch.randint(20, 30, (8,)).long().to(device) + mel_lengths[-1] = mel_spec.size(1) stop_targets = torch.zeros(8, 30, 1).float().to(device) speaker_ids = torch.randint(0, 5, (8,)).long().to(device) @@ -96,6 +97,7 @@ class MultiSpeakeTacotronTrainTest(unittest.TestCase): mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) linear_spec = torch.rand(8, 30, c.audio["fft_size"]).to(device) mel_lengths = torch.randint(20, 30, (8,)).long().to(device) + mel_lengths[-1] = mel_spec.size(1) stop_targets = torch.zeros(8, 30, 1).float().to(device) speaker_embeddings = torch.rand(8, 55).to(device)