mirror of https://github.com/coqui-ai/TTS.git
Tensorboard plotting
parent
2f92246c8a
commit
1fa791f83e
|
@ -25,5 +25,6 @@
|
||||||
"text_cleaner": "english_cleaners",
|
"text_cleaner": "english_cleaners",
|
||||||
|
|
||||||
"data_path": "/data/shared/KeithIto/LJSpeech-1.0",
|
"data_path": "/data/shared/KeithIto/LJSpeech-1.0",
|
||||||
"output_path": "result"
|
"output_path": "result",
|
||||||
|
"log_dir": "/home/erogol/projects/TTS/logs/"
|
||||||
}
|
}
|
||||||
|
|
12
train.py
12
train.py
|
@ -13,6 +13,7 @@ import torch.nn as nn
|
||||||
from torch import optim
|
from torch import optim
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
from utils.generic_utils import (Progbar, remove_experiment_folder,
|
from utils.generic_utils import (Progbar, remove_experiment_folder,
|
||||||
create_experiment_folder, save_checkpoint,
|
create_experiment_folder, save_checkpoint,
|
||||||
|
@ -38,6 +39,10 @@ def main(args):
|
||||||
tmp_path = os.path.join("/tmp/", file_name+'_tts')
|
tmp_path = os.path.join("/tmp/", file_name+'_tts')
|
||||||
pickle.dump(c, open(tmp_path, "wb"))
|
pickle.dump(c, open(tmp_path, "wb"))
|
||||||
|
|
||||||
|
# setup tensorboard
|
||||||
|
LOG_DIR = c.log_dir
|
||||||
|
tb = SummaryWriter(LOG_DIR)
|
||||||
|
|
||||||
# Ctrl+C handler to remove empty experiment folder
|
# Ctrl+C handler to remove empty experiment folder
|
||||||
def signal_handler(signal, frame):
|
def signal_handler(signal, frame):
|
||||||
print(" !! Pressed Ctrl+C !!")
|
print(" !! Pressed Ctrl+C !!")
|
||||||
|
@ -78,7 +83,7 @@ def main(args):
|
||||||
print("\n > Model restored from step %d\n" % args.restore_step)
|
print("\n > Model restored from step %d\n" % args.restore_step)
|
||||||
|
|
||||||
except:
|
except:
|
||||||
print("\n > Starting a new training\n")
|
print("\n > Starting a new training")
|
||||||
|
|
||||||
model = model.train()
|
model = model.train()
|
||||||
|
|
||||||
|
@ -97,6 +102,7 @@ def main(args):
|
||||||
dataloader = DataLoader(dataset, batch_size=c.batch_size,
|
dataloader = DataLoader(dataset, batch_size=c.batch_size,
|
||||||
shuffle=True, collate_fn=dataset.collate_fn,
|
shuffle=True, collate_fn=dataset.collate_fn,
|
||||||
drop_last=True, num_workers=32)
|
drop_last=True, num_workers=32)
|
||||||
|
print("\n | > Epoch {}".format(epoch))
|
||||||
progbar = Progbar(len(dataset) / c.batch_size)
|
progbar = Progbar(len(dataset) / c.batch_size)
|
||||||
|
|
||||||
for i, data in enumerate(dataloader):
|
for i, data in enumerate(dataloader):
|
||||||
|
@ -160,6 +166,10 @@ def main(args):
|
||||||
('linear_loss', linear_loss.data[0]),
|
('linear_loss', linear_loss.data[0]),
|
||||||
('mel_loss', mel_loss.data[0])])
|
('mel_loss', mel_loss.data[0])])
|
||||||
|
|
||||||
|
tb.add_scalar('Train/TotalLoss', loss.data[0], current_step)
|
||||||
|
tb.add_scalar('Train/LinearLoss', linear_loss.data[0], current_step)
|
||||||
|
tb.add_scalar('Train/MelLoss', mel_loss.data[0], current_step)
|
||||||
|
|
||||||
if current_step % c.save_step == 0:
|
if current_step % c.save_step == 0:
|
||||||
checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
|
checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
|
||||||
checkpoint_path = os.path.join(OUT_PATH, checkpoint_path)
|
checkpoint_path = os.path.join(OUT_PATH, checkpoint_path)
|
||||||
|
|
Binary file not shown.
Loading…
Reference in New Issue