mirror of https://github.com/coqui-ai/TTS.git
Fix return outputs
parent
424d04e4f6
commit
14c117978d
|
@ -1078,7 +1078,7 @@ class Vits(BaseTTS):
|
|||
scores_disc_real,
|
||||
scores_disc_fake,
|
||||
)
|
||||
return {}, loss_dict
|
||||
return outputs, loss_dict
|
||||
|
||||
if optimizer_idx == 1:
|
||||
mel = batch["mel"]
|
||||
|
|
|
@ -410,7 +410,9 @@ class TestVits(unittest.TestCase):
|
|||
for _ in range(5):
|
||||
batch = self._create_batch(config, 2)
|
||||
for idx in [0, 1]:
|
||||
_, loss_dict = model.train_step(batch, criterions, idx)
|
||||
outputs, loss_dict = model.train_step(batch, criterions, idx)
|
||||
self.assertFalse(not outputs)
|
||||
self.assertFalse(not loss_dict)
|
||||
loss_dict["loss"].backward()
|
||||
optimizers[idx].step()
|
||||
optimizers[idx].zero_grad()
|
||||
|
|
Loading…
Reference in New Issue