mirror of https://github.com/coqui-ai/TTS.git
Fix return outputs
parent
424d04e4f6
commit
14c117978d
|
@ -1078,7 +1078,7 @@ class Vits(BaseTTS):
|
||||||
scores_disc_real,
|
scores_disc_real,
|
||||||
scores_disc_fake,
|
scores_disc_fake,
|
||||||
)
|
)
|
||||||
return {}, loss_dict
|
return outputs, loss_dict
|
||||||
|
|
||||||
if optimizer_idx == 1:
|
if optimizer_idx == 1:
|
||||||
mel = batch["mel"]
|
mel = batch["mel"]
|
||||||
|
|
|
@ -410,7 +410,9 @@ class TestVits(unittest.TestCase):
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
batch = self._create_batch(config, 2)
|
batch = self._create_batch(config, 2)
|
||||||
for idx in [0, 1]:
|
for idx in [0, 1]:
|
||||||
_, loss_dict = model.train_step(batch, criterions, idx)
|
outputs, loss_dict = model.train_step(batch, criterions, idx)
|
||||||
|
self.assertFalse(not outputs)
|
||||||
|
self.assertFalse(not loss_dict)
|
||||||
loss_dict["loss"].backward()
|
loss_dict["loss"].backward()
|
||||||
optimizers[idx].step()
|
optimizers[idx].step()
|
||||||
optimizers[idx].zero_grad()
|
optimizers[idx].zero_grad()
|
||||||
|
|
Loading…
Reference in New Issue