Tensorboard plotting

pull/10/head
Eren Golge 2018-01-25 07:07:46 -08:00
parent 2f92246c8a
commit 1fa791f83e
3 changed files with 13 additions and 2 deletions

View File

@ -25,5 +25,6 @@
"text_cleaner": "english_cleaners", "text_cleaner": "english_cleaners",
"data_path": "/data/shared/KeithIto/LJSpeech-1.0", "data_path": "/data/shared/KeithIto/LJSpeech-1.0",
"output_path": "result" "output_path": "result",
"log_dir": "/home/erogol/projects/TTS/logs/"
} }

View File

@ -13,6 +13,7 @@ import torch.nn as nn
from torch import optim from torch import optim
from torch.autograd import Variable from torch.autograd import Variable
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from utils.generic_utils import (Progbar, remove_experiment_folder, from utils.generic_utils import (Progbar, remove_experiment_folder,
create_experiment_folder, save_checkpoint, create_experiment_folder, save_checkpoint,
@ -38,6 +39,10 @@ def main(args):
tmp_path = os.path.join("/tmp/", file_name+'_tts') tmp_path = os.path.join("/tmp/", file_name+'_tts')
pickle.dump(c, open(tmp_path, "wb")) pickle.dump(c, open(tmp_path, "wb"))
# setup tensorboard
LOG_DIR = c.log_dir
tb = SummaryWriter(LOG_DIR)
# Ctrl+C handler to remove empty experiment folder # Ctrl+C handler to remove empty experiment folder
def signal_handler(signal, frame): def signal_handler(signal, frame):
print(" !! Pressed Ctrl+C !!") print(" !! Pressed Ctrl+C !!")
@ -78,7 +83,7 @@ def main(args):
print("\n > Model restored from step %d\n" % args.restore_step) print("\n > Model restored from step %d\n" % args.restore_step)
except: except:
print("\n > Starting a new training\n") print("\n > Starting a new training")
model = model.train() model = model.train()
@ -97,6 +102,7 @@ def main(args):
dataloader = DataLoader(dataset, batch_size=c.batch_size, dataloader = DataLoader(dataset, batch_size=c.batch_size,
shuffle=True, collate_fn=dataset.collate_fn, shuffle=True, collate_fn=dataset.collate_fn,
drop_last=True, num_workers=32) drop_last=True, num_workers=32)
print("\n | > Epoch {}".format(epoch))
progbar = Progbar(len(dataset) / c.batch_size) progbar = Progbar(len(dataset) / c.batch_size)
for i, data in enumerate(dataloader): for i, data in enumerate(dataloader):
@ -160,6 +166,10 @@ def main(args):
('linear_loss', linear_loss.data[0]), ('linear_loss', linear_loss.data[0]),
('mel_loss', mel_loss.data[0])]) ('mel_loss', mel_loss.data[0])])
tb.add_scalar('Train/TotalLoss', loss.data[0], current_step)
tb.add_scalar('Train/LinearLoss', linear_loss.data[0], current_step)
tb.add_scalar('Train/MelLoss', mel_loss.data[0], current_step)
if current_step % c.save_step == 0: if current_step % c.save_step == 0:
checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step) checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
checkpoint_path = os.path.join(OUT_PATH, checkpoint_path) checkpoint_path = os.path.join(OUT_PATH, checkpoint_path)

Binary file not shown.