logger for tensorboard plotting

pull/10/head
Eren Golge 2018-12-13 18:18:37 +01:00
parent 268ca36295
commit 062e8a0880
4 changed files with 118 additions and 83 deletions

View File

@ -40,6 +40,7 @@
"checkpoint": true,
"save_step": 5000,
"print_step": 10,
"tb_model_param_stats": true, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
"run_eval": true,
"data_path": "../../Data/LJSpeech-1.1/", // can overwritten from command argument

View File

@ -8,5 +8,4 @@ tensorboardX
matplotlib==2.0.2
Pillow
flask
scipy==0.19.0
lws
scipy==0.19.0

122
train.py
View File

@ -22,6 +22,7 @@ from models.tacotron import Tacotron
from layers.losses import L1LossMasked
from utils.audio import AudioProcessor
from utils.synthesis import synthesis
from utils.logger import Logger
torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
@ -169,15 +170,13 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st,
avg_step_time += step_time
# Plot Training Iter Stats
tb.add_scalar('TrainIterLoss/TotalLoss', loss.item(), current_step)
tb.add_scalar('TrainIterLoss/LinearLoss', linear_loss.item(),
current_step)
tb.add_scalar('TrainIterLoss/MelLoss', mel_loss.item(), current_step)
tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'],
current_step)
tb.add_scalar('Params/GradNorm', grad_norm, current_step)
tb.add_scalar('Params/GradNormSt', grad_norm_st, current_step)
tb.add_scalar('Time/StepTime', step_time, current_step)
iter_stats = {"loss_posnet": linear_loss.item(),
"loss_decoder": mel_loss.item(),
"lr": current_lr,
"grad_norm": grad_norm,
"grad_norm_st": grad_norm_st,
"step_time": step_time}
tb_logger.tb_train_iter_stats(current_step, iter_stats)
if current_step % c.save_step == 0:
if c.checkpoint:
@ -189,28 +188,17 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st,
# Diagnostic visualizations
const_spec = linear_output[0].data.cpu().numpy()
gt_spec = linear_input[0].data.cpu().numpy()
const_spec = plot_spectrogram(const_spec, ap)
gt_spec = plot_spectrogram(gt_spec, ap)
tb.add_figure('Visual/Reconstruction', const_spec, current_step)
tb.add_figure('Visual/GroundTruth', gt_spec, current_step)
align_img = alignments[0].data.cpu().numpy()
align_img = plot_alignment(align_img)
tb.add_figure('Visual/Alignment', align_img, current_step)
figures = {"prediction": plot_spectrogram(const_spec, ap),
"ground_truth": plot_spectrogram(gt_spec, ap),
"alignment": plot_alignment(align_img)}
tb_logger.tb_train_figures(figures, current_step)
# Sample audio
audio_signal = linear_output[0].data.cpu().numpy()
ap.griffin_lim_iters = 60
audio_signal = ap.inv_spectrogram(audio_signal.T)
try:
tb.add_audio(
'SampleAudio',
audio_signal,
current_step,
sample_rate=c.sample_rate)
except:
pass
tb_logger.tb_train_audios(current_step,
{'TrainAudio': ap.inv_spectrogram(const_spec.T)},
c.sample_rate)
avg_linear_loss /= (num_iter + 1)
avg_mel_loss /= (num_iter + 1)
@ -229,12 +217,13 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st,
flush=True)
# Plot Training Epoch Stats
tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
tb.add_scalar('TrainEpochLoss/LinearLoss', avg_linear_loss, current_step)
tb.add_scalar('TrainEpochLoss/MelLoss', avg_mel_loss, current_step)
tb.add_scalar('TrainEpochLoss/StopLoss', avg_stop_loss, current_step)
tb.add_scalar('Time/EpochTime', epoch_time, epoch)
epoch_time = 0
epoch_stats = {"loss_postnet": avg_linear_loss,
"loss_decoder": avg_mel_loss,
"stop_loss": avg_stop_loss,
"epoch_time": epoch_time}
tb_logger.tb_train_epoch_stats(current_step, epoch_stats)
if c.tb_model_param_stats:
tb_logger.tb_model_weights(model, current_step)
return avg_linear_loss, current_step
@ -316,74 +305,45 @@ def evaluate(model, criterion, criterion_st, ap, current_step):
gt_spec = linear_input[idx].data.cpu().numpy()
align_img = alignments[idx].data.cpu().numpy()
const_spec = plot_spectrogram(const_spec, ap)
gt_spec = plot_spectrogram(gt_spec, ap)
align_img = plot_alignment(align_img)
tb.add_figure('ValVisual/Reconstruction', const_spec, current_step)
tb.add_figure('ValVisual/GroundTruth', gt_spec, current_step)
tb.add_figure('ValVisual/ValidationAlignment', align_img,
current_step)
eval_figures = {"prediction": plot_spectrogram(const_spec, ap),
"ground_truth": plot_spectrogram(gt_spec, ap),
"alignment": plot_alignment(align_img)}
tb_logger.tb_eval_figures(current_step, eval_figures)
# Sample audio
audio_signal = linear_output[idx].data.cpu().numpy()
ap.griffin_lim_iters = 60
audio_signal = ap.inv_spectrogram(audio_signal.T)
try:
tb.add_audio(
'ValSampleAudio',
audio_signal,
current_step,
sample_rate=c.audio["sample_rate"])
except:
# sometimes audio signal is out of boundaries
pass
tb_logger.tb_eval_audios(current_step, {"ValAudio": ap.inv_spectrogram(const_spec.T)}, c.audio["sample_rate"])
# compute average losses
avg_linear_loss /= (num_iter + 1)
avg_mel_loss /= (num_iter + 1)
avg_stop_loss /= (num_iter + 1)
avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss
# Plot Learning Stats
tb.add_scalar('ValEpochLoss/TotalLoss', avg_total_loss,
current_step)
tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss,
current_step)
tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step)
tb.add_scalar('ValEpochLoss/Stop_loss', avg_stop_loss,
current_step)
# Plot Validation Stats
epoch_stats = {"loss_postnet": avg_linear_loss,
"loss_decoder": avg_mel_loss,
"stop_loss": avg_stop_loss}
tb_logger.tb_eval_stats(current_step, epoch_stats)
# test sentences
ap.griffin_lim_iters = 60
test_audios = {}
test_figures = {}
for idx, test_sentence in enumerate(test_sentences):
try:
wav, alignment, linear_spec, _, stop_tokens = synthesis(
model, test_sentence, c, use_cuda, ap)
file_path = os.path.join(AUDIO_PATH, str(current_step))
os.makedirs(file_path, exist_ok=True)
file_path = os.path.join(file_path,
"TestSentence_{}.wav".format(idx))
ap.save_wav(wav, file_path)
wav_name = 'TestSentences/{}'.format(idx)
tb.add_audio(
wav_name,
wav,
current_step,
sample_rate=c.audio['sample_rate'])
linear_spec = plot_spectrogram(linear_spec, ap)
align_img = plot_alignment(alignment)
tb.add_figure('TestSentences/{}_Spectrogram'.format(idx),
linear_spec, current_step)
tb.add_figure('TestSentences/{}_Alignment'.format(idx), align_img,
current_step)
test_audios['{}-audio'.format(idx)] = wav
test_figures['{}-prediction'.format(idx)] = plot_spectrogram(linear_spec, ap)
test_figures['{}-alignment'.format(idx)] = plot_alignment(alignment)
except:
print(" !! Error creating Test Sentence -", idx)
traceback.print_exc()
pass
tb_logger.tb_test_audios(current_step, test_audios, c.audio['sample_rate'])
tb_logger.tb_test_figures(current_step, test_figures)
return avg_linear_loss
@ -496,7 +456,7 @@ if __name__ == '__main__':
# setup tensorboard
LOG_DIR = OUT_PATH
tb = SummaryWriter(LOG_DIR)
tb_logger = Logger(LOG_DIR)
# Conditional imports
preprocessor = importlib.import_module('datasets.preprocess')

75
utils/logger.py Normal file
View File

@ -0,0 +1,75 @@
import traceback
from tensorboardX import SummaryWriter
class Logger(object):
def __init__(self, log_dir):
self.writer = SummaryWriter(log_dir)
self.train_stats = {}
self.eval_stats = {}
def tb_model_weights(self, model, step):
layer_num = 1
for name, param in model.named_parameters():
self.writer.add_scalar(
"layer{}-ModelParams/{}/max".format(layer_num, name),
param.max(), step)
self.writer.add_scalar(
"layer{}-ModelParams/{}/min".format(layer_num, name),
param.min(), step)
self.writer.add_scalar(
"layer{}-ModelParams/{}/mean".format(layer_num, name),
param.mean(), step)
self.writer.add_scalar(
"layer{}-ModelParams/{}/std".format(layer_num, name),
param.std(), step)
self.writer.add_histogram(
"layer{}-{}/param".format(layer_num, name), param, step)
self.writer.add_histogram(
"layer{}-{}/grad".format(layer_num, name), param.grad, step)
layer_num += 1
def dict_to_tb_scalar(self, scope_name, stats, step):
for key, value in stats.items():
self.writer.add_scalar('{}/{}'.format(scope_name, key), value, step)
def dict_to_tb_figure(self, scope_name, figures, step):
for key, value in figures.items():
self.writer.add_figure('{}/{}'.format(scope_name, key), value, step)
def dict_to_tb_audios(self, scope_name, audios, step, sample_rate):
for key, value in audios.items():
try:
self.writer.add_audio('{}/{}'.format(scope_name, key), value, step, sample_rate=sample_rate)
except:
traceback.print_exc()
def tb_train_iter_stats(self, step, stats):
self.dict_to_tb_scalar("TrainIterStats", stats, step)
def tb_train_epoch_stats(self, step, stats):
self.dict_to_tb_scalar("TrainEpochStats", stats, step)
def tb_train_figures(self, step, figures):
self.dict_to_tb_figure("TrainFigures", figures, step)
def tb_train_audios(self, step, audios, sample_rate):
self.dict_to_tb_audios("TrainAudios", audios, step, sample_rate)
def tb_eval_stats(self, step, stats):
self.dict_to_tb_scalar("EvalStats", stats, step)
def tb_eval_figures(self, step, figures):
self.dict_to_tb_figure("EvalFigures", figures, step)
def tb_eval_audios(self, step, audios, sample_rate):
self.dict_to_tb_audios("EvalAudios", audios, step, sample_rate)
def tb_test_audios(self, step, audios, sample_rate):
self.dict_to_tb_audios("TestAudios", audios, step, sample_rate)
def tb_test_figures(self, step, figures):
self.dict_to_tb_figure("TestFigures", figures, step)