diff --git a/config.json b/config.json index a850c2eb..7eca8172 100644 --- a/config.json +++ b/config.json @@ -11,7 +11,7 @@ "embedding_size": 256, "text_cleaner": "english_cleaners", - "epochs": 200, + "epochs": 2000, "lr": 0.01, "lr_patience": 2, "lr_decay": 0.5, @@ -20,7 +20,7 @@ "power": 1.5, "r": 5, - "save_step": 1, + "save_step": 200, "data_path": "/data/shared/KeithIto/LJSpeech-1.0", "output_path": "result", "log_dir": "/home/erogol/projects/TTS/logs/" diff --git a/train.py b/train.py index 737fe43b..bb297233 100644 --- a/train.py +++ b/train.py @@ -42,7 +42,7 @@ def main(args): pickle.dump(c, open(tmp_path, "wb")) # setup tensorboard - LOG_DIR = c.log_dir + LOG_DIR = OUT_PATH tb = SummaryWriter(LOG_DIR) # Ctrl+C handler to remove empty experiment folder @@ -101,7 +101,7 @@ def main(args): lr_scheduler = ReduceLROnPlateau(optimizer, factor=c.lr_decay, patience=c.lr_patience, verbose=True) - + epoch_time = 0 for epoch in range(c.epochs): dataloader = DataLoader(dataset, batch_size=c.batch_size, @@ -166,14 +166,20 @@ def main(args): optimizer.step() - time_per_step = time.time() - start_time + step_time = time.time() - start_time + epoch_time += step_time + progbar.update(i+1, values=[('total_loss', loss.data[0]), ('linear_loss', linear_loss.data[0]), ('mel_loss', mel_loss.data[0])]) tb.add_scalar('Train/TotalLoss', loss.data[0], current_step) - tb.add_scalar('Train/LinearLoss', linear_loss.data[0], current_step) + tb.add_scalar('Train/LinearLoss', linear_loss.data[0], + current_step) tb.add_scalar('Train/MelLoss', mel_loss.data[0], current_step) + tb.add_scalar('LearningRate', optimizer.param_groups[0]['lr'], + current_step) + tb.add_scalar('Time/StepTime', step_time, current_step) if current_step % c.save_step == 0: checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step) @@ -188,6 +194,8 @@ def main(args): checkpoint_path) print("\n | > Checkpoint is saved : {}".format(checkpoint_path)) lr_scheduler.step(loss.data[0]) + tb.add_scalar('Time/EpochTime', epoch_time, epoch) + epoch_time = 0 if __name__ == '__main__':