Fix return outputs

pull/1324/head
Eren Gölge 2022-02-21 09:57:57 +01:00
parent 424d04e4f6
commit 14c117978d
2 changed files with 4 additions and 2 deletions

View File

@ -1078,7 +1078,7 @@ class Vits(BaseTTS):
scores_disc_real,
scores_disc_fake,
)
return {}, loss_dict
return outputs, loss_dict
if optimizer_idx == 1:
mel = batch["mel"]

View File

@ -410,7 +410,9 @@ class TestVits(unittest.TestCase):
for _ in range(5):
batch = self._create_batch(config, 2)
for idx in [0, 1]:
_, loss_dict = model.train_step(batch, criterions, idx)
outputs, loss_dict = model.train_step(batch, criterions, idx)
self.assertFalse(not outputs)
self.assertFalse(not loss_dict)
loss_dict["loss"].backward()
optimizers[idx].step()
optimizers[idx].zero_grad()