fix wavegrad test

pull/10/head
erogol 2020-11-17 14:15:14 +01:00
parent a2a142dc39
commit 79ed5debcd
1 changed files with 3 additions and 1 deletions

View File

@ -1,5 +1,6 @@
import unittest
import numpy as np
import torch
from torch import optim
from TTS.vocoder.models.wavegrad import Wavegrad
@ -33,7 +34,8 @@ class WavegradTrainTest(unittest.TestCase):
[1, 2, 4, 8]])
model.train()
model.to(device)
model.compute_noise_level(1000, 1e-6, 1e-2)
betas = np.linspace(1e-6, 1e-2, 1000)
model.compute_noise_level(betas)
model_ref.load_state_dict(model.state_dict())
model_ref.to(device)
count = 0