pull/1/head
erogol 2020-05-31 03:31:42 +02:00
parent 0bb0ba182e
commit a361df3186
3 changed files with 9 additions and 10 deletions

View File

@ -109,14 +109,13 @@
"test_sentences_file": null, // set a file to load sentences to be used for testing. If it is null then we use default english sentences.
// OPTIMIZER
"noam_schedule": true, // use noam warmup and lr schedule.
"grad_clip": 1.0, // upper limit for gradients for clipping.
"noam_schedule": false, // use noam warmup and lr schedule.
"warmup_steps_gen": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
"warmup_steps_disc": 4000,
"epochs": 1000, // total number of epochs to train.
"wd": 0.000001, // Weight decay weight.
"lr_gen": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
"lr_disc": 0.0001,
"warmup_steps_gen": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
"warmup_steps_disc": 4000,
"gen_clip_grad": 10.0,
"disc_clip_grad": 10.0,

View File

@ -372,11 +372,11 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
# compute spectrograms
figures = plot_results(y_hat, y_G, ap, global_step, 'eval')
figures = plot_results(in_fake_D, in_real_D, ap, global_step, 'eval')
tb_logger.tb_eval_figures(global_step, figures)
# Sample audio
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
sample_voice = in_fake_D[0].squeeze(0).detach().cpu().numpy()
tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
c.audio["sample_rate"])

View File

@ -13,8 +13,8 @@ def plot_results(y_hat, y, ap, global_step, name_prefix):
y_hat = y_hat[0].squeeze(0).detach().cpu().numpy()
y = y[0].squeeze(0).detach().cpu().numpy()
spec_fake = ap.spectrogram(y_hat).T
spec_real = ap.spectrogram(y).T
spec_fake = ap.melspectrogram(y_hat).T
spec_real = ap.melspectrogram(y).T
spec_diff = np.abs(spec_fake - spec_real)
# plot figure and save it
@ -98,5 +98,5 @@ def setup_discriminator(c):
return model
# def check_config(c):
# pass
def check_config(c):
pass