mirror of https://github.com/coqui-ai/TTS.git
test bug fix
parent
b9b79fcf0f
commit
65ffbae23d
|
@ -66,4 +66,4 @@ class TacotronTrainTest(unittest.TestCase):
|
|||
assert (param != param_ref).any(
|
||||
), "param {} with shape {} not updated!! \n{}\n{}".format(
|
||||
count, param.shape, param, param_ref)
|
||||
count += 1
|
||||
count += 1
|
||||
|
|
|
@ -22,6 +22,7 @@ class TacotronTrainTest(unittest.TestCase):
|
|||
def test_train_step(self):
|
||||
input = torch.randint(0, 24, (8, 128)).long().to(device)
|
||||
input_lengths = torch.randint(100, 129, (8, )).long().to(device)
|
||||
input_lengths[-1] = 128
|
||||
mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
|
||||
linear_spec = torch.rand(8, 30, c.audio['num_freq']).to(device)
|
||||
mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
|
||||
|
|
Loading…
Reference in New Issue