run D in eval

pull/1/head
erogol 2020-06-08 10:43:09 +02:00
parent 2404f96cba
commit 8d307f2133
2 changed files with 45 additions and 6 deletions


@@ -56,8 +56,8 @@
     "stft_loss_weight": 0.5,
     "subband_stft_loss_weight": 0.5,
-    "mse_gan_loss_weight": 2.5,
-    "hinge_gan_loss_weight": 2.5,
+    "mse_G_loss_weight": 2.5,
+    "hinge_G_loss_weight": 2.5,
     "feat_match_loss_weight": 25,
     "stft_loss_params": {

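The hunk above renames the GAN loss-weight keys from "mse_gan_loss_weight" / "hinge_gan_loss_weight" to "mse_G_loss_weight" / "hinge_G_loss_weight", so configs written against the old names will no longer match. A minimal sketch of a one-off migration, assuming the config is plain JSON (the migrate_config helper is hypothetical, not part of this commit):

import json

# hypothetical one-off helper, not part of this commit
RENAMED_KEYS = {
    "mse_gan_loss_weight": "mse_G_loss_weight",
    "hinge_gan_loss_weight": "hinge_G_loss_weight",
}

def migrate_config(path):
    """Rename the old GAN loss-weight keys in an existing JSON config."""
    with open(path) as f:
        cfg = json.load(f)
    for old, new in RENAMED_KEYS.items():
        if old in cfg:
            cfg[new] = cfg.pop(old)
    with open(path, "w") as f:
        json.dump(cfg, f, indent=4)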

@@ -181,7 +181,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
         ##############################
         # DISCRIMINATOR
         ##############################
-        if global_step > c.steps_to_start_discriminator:
+        if global_step >= c.steps_to_start_discriminator:
             # discriminator pass
             with torch.no_grad():
                 y_hat = model_G(c_D)
@@ -295,7 +295,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
 @torch.no_grad()
-def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
+def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch):
     data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
     model_G.eval()
     model_D.eval()
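Because evaluate now also takes criterion_D, any caller of the old signature has to be updated as well. The training loop that calls it is not shown in this diff, but the updated call would presumably look like:

# hypothetical call site; the actual caller is outside this diff
evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)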
@@ -355,13 +355,52 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
         for key, value in loss_G_dict.items():
             loss_dict[key] = value.item()

+        ##############################
+        # DISCRIMINATOR
+        ##############################
+        if global_step >= c.steps_to_start_discriminator:
+            # discriminator pass
+            with torch.no_grad():
+                y_hat = model_G(c_G)
+
+            # PQMF formatting
+            if y_hat.shape[1] > 1:
+                y_hat = model_G.pqmf_synthesis(y_hat)
+
+            # run D with or without cond. features
+            if len(signature(model_D.forward).parameters) == 2:
+                D_out_fake = model_D(y_hat.detach(), c_G)
+                D_out_real = model_D(y_G, c_G)
+            else:
+                D_out_fake = model_D(y_hat.detach())
+                D_out_real = model_D(y_G)
+
+            # format D outputs
+            if isinstance(D_out_fake, tuple):
+                scores_fake, feats_fake = D_out_fake
+                if D_out_real is None:
+                    scores_real, feats_real = None, None
+                else:
+                    scores_real, feats_real = D_out_real
+            else:
+                scores_fake = D_out_fake
+                scores_real = D_out_real
+
+            # compute losses
+            loss_D_dict = criterion_D(scores_fake, scores_real)
+            for key, value in loss_D_dict.items():
+                loss_dict[key] = value.item()
+
         step_time = time.time() - start_time
         epoch_time += step_time

         # update avg stats
         update_eval_values = dict()
-        for key, value in loss_G_dict.items():
-            update_eval_values['avg_' + key] = value.item()
+        for key, value in loss_dict.items():
+            update_eval_values['avg_' + key] = value
         update_eval_values['avg_loader_time'] = loader_time
         update_eval_values['avg_step_time'] = step_time
         keep_avg.update_values(update_eval_values)
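
The added evaluation block mirrors the training-side discriminator pass: it dispatches on the arity of model_D.forward (via inspect.signature) to decide whether the discriminator also receives conditioning features, and unpacks (scores, feats) tuples when the model returns them. A self-contained sketch of that dispatch, using hypothetical stand-in discriminators rather than the ones in this repo:

from inspect import signature

import torch


class UnconditionalD(torch.nn.Module):
    """Stand-in discriminator that scores the waveform alone."""

    def forward(self, x):
        return x.mean(dim=-1)


class ConditionalD(torch.nn.Module):
    """Stand-in discriminator that also receives conditioning features."""

    def forward(self, x, c):
        return (x * c).mean(dim=-1)


def run_discriminator(model_D, y_hat, y, c):
    # bound methods drop `self`, so a conditional forward reports 2 parameters
    if len(signature(model_D.forward).parameters) == 2:
        return model_D(y_hat.detach(), c), model_D(y, c)
    return model_D(y_hat.detach()), model_D(y)


y_hat = torch.randn(4, 1, 256)  # fake waveform from the generator
y = torch.randn(4, 1, 256)      # real waveform
c = torch.randn(4, 1, 256)      # conditioning features (toy shape for the example)
for D in (UnconditionalD(), ConditionalD()):
    D_out_fake, D_out_real = run_discriminator(D, y_hat, y, c)
    print(type(D).__name__, D_out_fake.shape, D_out_real.shape)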