Fix unit test

pull/1544/head
Edresson Casanova 2022-04-29 16:29:44 -03:00
parent 44456b0483
commit 212d330929
1 changed files with 2 additions and 1 deletions

View File

@ -155,6 +155,7 @@ class GAN(BaseVocoder):
if optimizer_idx == 1:
# GENERATOR loss
scores_fake, feats_fake, feats_real = None, None, None
if self.train_disc:
if len(signature(self.model_d.forward).parameters) == 2:
D_out_fake = self.model_d(self.y_hat_g, x)
@ -182,7 +183,6 @@ class GAN(BaseVocoder):
self.y_hat_g, y, scores_fake, feats_fake, feats_real, self.y_hat_sub, self.y_sub_g
)
outputs = {"model_outputs": self.y_hat_g}
return outputs, loss_dict
@staticmethod
@ -216,6 +216,7 @@ class GAN(BaseVocoder):
@torch.no_grad()
def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
"""Call `train_step()` with `no_grad()`"""
self.train_disc = True # Avoid a bug in the Training with the missing discriminator loss
return self.train_step(batch, criterion, optimizer_idx)
def eval_log(