mirror of https://github.com/coqui-ai/TTS.git
Fix unit test
parent
44456b0483
commit
212d330929
|
@ -155,6 +155,7 @@ class GAN(BaseVocoder):
|
|||
|
||||
if optimizer_idx == 1:
|
||||
# GENERATOR loss
|
||||
scores_fake, feats_fake, feats_real = None, None, None
|
||||
if self.train_disc:
|
||||
if len(signature(self.model_d.forward).parameters) == 2:
|
||||
D_out_fake = self.model_d(self.y_hat_g, x)
|
||||
|
@ -182,7 +183,6 @@ class GAN(BaseVocoder):
|
|||
self.y_hat_g, y, scores_fake, feats_fake, feats_real, self.y_hat_sub, self.y_sub_g
|
||||
)
|
||||
outputs = {"model_outputs": self.y_hat_g}
|
||||
|
||||
return outputs, loss_dict
|
||||
|
||||
@staticmethod
|
||||
|
@ -216,6 +216,7 @@ class GAN(BaseVocoder):
|
|||
@torch.no_grad()
|
||||
def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
|
||||
"""Call `train_step()` with `no_grad()`"""
|
||||
self.train_disc = True # Avoid a bug in the Training with the missing discriminator loss
|
||||
return self.train_step(batch, criterion, optimizer_idx)
|
||||
|
||||
def eval_log(
|
||||
|
|
Loading…
Reference in New Issue