bug fix for rwd discriminator

pull/1/head
erogol 2020-06-05 13:25:53 +02:00
parent 189227c741
commit c5c3f6094c
1 changed files with 5 additions and 1 deletions

View File

@ -327,8 +327,12 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
y_hat = model_G.pqmf_synthesis(y_hat)
y_G_sub = model_G.pqmf_analysis(y_G)
D_out_fake = model_D(y_hat)
if len(signature(model_D.forward).parameters) == 2:
D_out_fake = model_D(y_hat, c_G)
else:
D_out_fake = model_D(y_hat)
D_out_real = None
if c.use_feat_match_loss:
with torch.no_grad():
D_out_real = model_D(y_G)