visualization fix

pull/1/head
erogol 2020-06-04 09:51:20 +02:00
parent 8dc1d6a208
commit 1b05bfce2e
1 changed files with 4 additions and 2 deletions

View File

@ -120,11 +120,13 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
y_hat = model_G(c_G)
y_hat_sub = None
y_G_sub = None
y_hat_vis = y_hat # for visualization
# PQMF formatting
if y_hat.shape[1] > 1:
y_hat_sub = y_hat
y_hat = model_G.pqmf_synthesis(y_hat)
y_hat_vis = y_hat
y_G_sub = model_G.pqmf_analysis(y_G)
if global_step > c.steps_to_start_discriminator:
@ -265,12 +267,12 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
model_losses=loss_dict)
# compute spectrograms
figures = plot_results(y_hat, y_G, ap, global_step,
figures = plot_results(y_hat_vis, y_G, ap, global_step,
'train')
tb_logger.tb_train_figures(global_step, figures)
# Sample audio
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy()
tb_logger.tb_train_audios(global_step,
{'train/audio': sample_voice},
c.audio["sample_rate"])