mirror of https://github.com/coqui-ai/TTS.git
run D in eval
parent
2404f96cba
commit
8d307f2133
|
@ -56,8 +56,8 @@
|
||||||
|
|
||||||
"stft_loss_weight": 0.5,
|
"stft_loss_weight": 0.5,
|
||||||
"subband_stft_loss_weight": 0.5,
|
"subband_stft_loss_weight": 0.5,
|
||||||
"mse_gan_loss_weight": 2.5,
|
"mse_G_loss_weight": 2.5,
|
||||||
"hinge_gan_loss_weight": 2.5,
|
"hinge_G_loss_weight": 2.5,
|
||||||
"feat_match_loss_weight": 25,
|
"feat_match_loss_weight": 25,
|
||||||
|
|
||||||
"stft_loss_params": {
|
"stft_loss_params": {
|
||||||
|
|
|
@ -181,7 +181,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
||||||
##############################
|
##############################
|
||||||
# DISCRIMINATOR
|
# DISCRIMINATOR
|
||||||
##############################
|
##############################
|
||||||
if global_step > c.steps_to_start_discriminator:
|
if global_step >= c.steps_to_start_discriminator:
|
||||||
# discriminator pass
|
# discriminator pass
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
y_hat = model_G(c_D)
|
y_hat = model_G(c_D)
|
||||||
|
@ -295,7 +295,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
|
def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch):
|
||||||
data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
|
data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
|
||||||
model_G.eval()
|
model_G.eval()
|
||||||
model_D.eval()
|
model_D.eval()
|
||||||
|
@ -355,13 +355,52 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
|
||||||
for key, value in loss_G_dict.items():
|
for key, value in loss_G_dict.items():
|
||||||
loss_dict[key] = value.item()
|
loss_dict[key] = value.item()
|
||||||
|
|
||||||
|
##############################
|
||||||
|
# DISCRIMINATOR
|
||||||
|
##############################
|
||||||
|
|
||||||
|
if global_step >= c.steps_to_start_discriminator:
|
||||||
|
# discriminator pass
|
||||||
|
with torch.no_grad():
|
||||||
|
y_hat = model_G(c_G)
|
||||||
|
|
||||||
|
# PQMF formatting
|
||||||
|
if y_hat.shape[1] > 1:
|
||||||
|
y_hat = model_G.pqmf_synthesis(y_hat)
|
||||||
|
|
||||||
|
# run D with or without cond. features
|
||||||
|
if len(signature(model_D.forward).parameters) == 2:
|
||||||
|
D_out_fake = model_D(y_hat.detach(), c_G)
|
||||||
|
D_out_real = model_D(y_G, c_G)
|
||||||
|
else:
|
||||||
|
D_out_fake = model_D(y_hat.detach())
|
||||||
|
D_out_real = model_D(y_G)
|
||||||
|
|
||||||
|
# format D outputs
|
||||||
|
if isinstance(D_out_fake, tuple):
|
||||||
|
scores_fake, feats_fake = D_out_fake
|
||||||
|
if D_out_real is None:
|
||||||
|
scores_real, feats_real = None, None
|
||||||
|
else:
|
||||||
|
scores_real, feats_real = D_out_real
|
||||||
|
else:
|
||||||
|
scores_fake = D_out_fake
|
||||||
|
scores_real = D_out_real
|
||||||
|
|
||||||
|
# compute losses
|
||||||
|
loss_D_dict = criterion_D(scores_fake, scores_real)
|
||||||
|
|
||||||
|
for key, value in loss_D_dict.items():
|
||||||
|
loss_dict[key] = value.item()
|
||||||
|
|
||||||
|
|
||||||
step_time = time.time() - start_time
|
step_time = time.time() - start_time
|
||||||
epoch_time += step_time
|
epoch_time += step_time
|
||||||
|
|
||||||
# update avg stats
|
# update avg stats
|
||||||
update_eval_values = dict()
|
update_eval_values = dict()
|
||||||
for key, value in loss_G_dict.items():
|
for key, value in loss_dict.items():
|
||||||
update_eval_values['avg_' + key] = value.item()
|
update_eval_values['avg_' + key] = value
|
||||||
update_eval_values['avg_loader_time'] = loader_time
|
update_eval_values['avg_loader_time'] = loader_time
|
||||||
update_eval_values['avg_step_time'] = step_time
|
update_eval_values['avg_step_time'] = step_time
|
||||||
keep_avg.update_values(update_eval_values)
|
keep_avg.update_values(update_eval_values)
|
||||||
|
|
Loading…
Reference in New Issue