mirror of https://github.com/coqui-ai/TTS.git
run D in eval
parent
2404f96cba
commit
8d307f2133
|
@ -56,8 +56,8 @@
|
|||
|
||||
"stft_loss_weight": 0.5,
|
||||
"subband_stft_loss_weight": 0.5,
|
||||
"mse_gan_loss_weight": 2.5,
|
||||
"hinge_gan_loss_weight": 2.5,
|
||||
"mse_G_loss_weight": 2.5,
|
||||
"hinge_G_loss_weight": 2.5,
|
||||
"feat_match_loss_weight": 25,
|
||||
|
||||
"stft_loss_params": {
|
||||
|
|
|
@ -181,7 +181,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
|||
##############################
|
||||
# DISCRIMINATOR
|
||||
##############################
|
||||
if global_step > c.steps_to_start_discriminator:
|
||||
if global_step >= c.steps_to_start_discriminator:
|
||||
# discriminator pass
|
||||
with torch.no_grad():
|
||||
y_hat = model_G(c_D)
|
||||
|
@ -295,7 +295,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
|||
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
|
||||
def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch):
|
||||
data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
|
||||
model_G.eval()
|
||||
model_D.eval()
|
||||
|
@ -355,13 +355,52 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
|
|||
for key, value in loss_G_dict.items():
|
||||
loss_dict[key] = value.item()
|
||||
|
||||
##############################
|
||||
# DISCRIMINATOR
|
||||
##############################
|
||||
|
||||
if global_step >= c.steps_to_start_discriminator:
|
||||
# discriminator pass
|
||||
with torch.no_grad():
|
||||
y_hat = model_G(c_G)
|
||||
|
||||
# PQMF formatting
|
||||
if y_hat.shape[1] > 1:
|
||||
y_hat = model_G.pqmf_synthesis(y_hat)
|
||||
|
||||
# run D with or without cond. features
|
||||
if len(signature(model_D.forward).parameters) == 2:
|
||||
D_out_fake = model_D(y_hat.detach(), c_G)
|
||||
D_out_real = model_D(y_G, c_G)
|
||||
else:
|
||||
D_out_fake = model_D(y_hat.detach())
|
||||
D_out_real = model_D(y_G)
|
||||
|
||||
# format D outputs
|
||||
if isinstance(D_out_fake, tuple):
|
||||
scores_fake, feats_fake = D_out_fake
|
||||
if D_out_real is None:
|
||||
scores_real, feats_real = None, None
|
||||
else:
|
||||
scores_real, feats_real = D_out_real
|
||||
else:
|
||||
scores_fake = D_out_fake
|
||||
scores_real = D_out_real
|
||||
|
||||
# compute losses
|
||||
loss_D_dict = criterion_D(scores_fake, scores_real)
|
||||
|
||||
for key, value in loss_D_dict.items():
|
||||
loss_dict[key] = value.item()
|
||||
|
||||
|
||||
step_time = time.time() - start_time
|
||||
epoch_time += step_time
|
||||
|
||||
# update avg stats
|
||||
update_eval_values = dict()
|
||||
for key, value in loss_G_dict.items():
|
||||
update_eval_values['avg_' + key] = value.item()
|
||||
for key, value in loss_dict.items():
|
||||
update_eval_values['avg_' + key] = value
|
||||
update_eval_values['avg_loader_time'] = loader_time
|
||||
update_eval_values['avg_step_time'] = step_time
|
||||
keep_avg.update_values(update_eval_values)
|
||||
|
|
Loading…
Reference in New Issue