diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index ec6c9e5b..02542f71 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1078,7 +1078,7 @@ class Vits(BaseTTS): scores_disc_real, scores_disc_fake, ) - return {}, loss_dict + return outputs, loss_dict if optimizer_idx == 1: mel = batch["mel"] diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py index 204ff2f7..384234e5 100644 --- a/tests/tts_tests/test_vits.py +++ b/tests/tts_tests/test_vits.py @@ -410,7 +410,9 @@ class TestVits(unittest.TestCase): for _ in range(5): batch = self._create_batch(config, 2) for idx in [0, 1]: - _, loss_dict = model.train_step(batch, criterions, idx) + outputs, loss_dict = model.train_step(batch, criterions, idx) + self.assertFalse(not outputs) + self.assertFalse(not loss_dict) loss_dict["loss"].backward() optimizers[idx].step() optimizers[idx].zero_grad()