diff --git a/vocoder/configs/melgan_config.json b/vocoder/configs/melgan_config.json index bdfdf88a..0b6b16db 100644 --- a/vocoder/configs/melgan_config.json +++ b/vocoder/configs/melgan_config.json @@ -56,8 +56,8 @@ "stft_loss_weight": 0.5, "subband_stft_loss_weight": 0.5, - "mse_gan_loss_weight": 2.5, - "hinge_gan_loss_weight": 2.5, + "mse_G_loss_weight": 2.5, + "hinge_G_loss_weight": 2.5, "feat_match_loss_weight": 25, "stft_loss_params": { diff --git a/vocoder/train.py b/vocoder/train.py index 05805c68..54b36e6b 100644 --- a/vocoder/train.py +++ b/vocoder/train.py @@ -181,7 +181,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, ############################## # DISCRIMINATOR ############################## - if global_step > c.steps_to_start_discriminator: + if global_step >= c.steps_to_start_discriminator: # discriminator pass with torch.no_grad(): y_hat = model_G(c_D) @@ -295,7 +295,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, @torch.no_grad() -def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch): +def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch): data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0)) model_G.eval() model_D.eval() @@ -355,13 +355,52 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch): for key, value in loss_G_dict.items(): loss_dict[key] = value.item() + ############################## + # DISCRIMINATOR + ############################## + + if global_step >= c.steps_to_start_discriminator: + # discriminator pass + with torch.no_grad(): + y_hat = model_G(c_G) + + # PQMF formatting + if y_hat.shape[1] > 1: + y_hat = model_G.pqmf_synthesis(y_hat) + + # run D with or without cond. features + if len(signature(model_D.forward).parameters) == 2: + D_out_fake = model_D(y_hat.detach(), c_G) + D_out_real = model_D(y_G, c_G) + else: + D_out_fake = model_D(y_hat.detach()) + D_out_real = model_D(y_G) + + # format D outputs + if isinstance(D_out_fake, tuple): + scores_fake, feats_fake = D_out_fake + if D_out_real is None: + scores_real, feats_real = None, None + else: + scores_real, feats_real = D_out_real + else: + scores_fake = D_out_fake + scores_real = D_out_real + + # compute losses + loss_D_dict = criterion_D(scores_fake, scores_real) + + for key, value in loss_D_dict.items(): + loss_dict[key] = value.item() + + step_time = time.time() - start_time epoch_time += step_time # update avg stats update_eval_values = dict() - for key, value in loss_G_dict.items(): - update_eval_values['avg_' + key] = value.item() + for key, value in loss_dict.items(): + update_eval_values['avg_' + key] = value update_eval_values['avg_loader_time'] = loader_time update_eval_values['avg_step_time'] = step_time keep_avg.update_values(update_eval_values)