From c5c3f6094cfe2eadc16d2c970fbe95cb2b3bb917 Mon Sep 17 00:00:00 2001 From: erogol Date: Fri, 5 Jun 2020 13:25:53 +0200 Subject: [PATCH] bug fix for rwd discriminator --- vocoder/train.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vocoder/train.py b/vocoder/train.py index d1e6befe..82f3bee6 100644 --- a/vocoder/train.py +++ b/vocoder/train.py @@ -327,8 +327,12 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch): y_hat = model_G.pqmf_synthesis(y_hat) y_G_sub = model_G.pqmf_analysis(y_G) - D_out_fake = model_D(y_hat) + if len(signature(model_D.forward).parameters) == 2: + D_out_fake = model_D(y_hat, c_G) + else: + D_out_fake = model_D(y_hat) D_out_real = None + if c.use_feat_match_loss: with torch.no_grad(): D_out_real = model_D(y_G)