From 0bb8d780e8e5c1bb12b15073af20190f2d7c029c Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Thu, 5 Sep 2019 16:48:36 +0200 Subject: [PATCH] visual.py update --- utils/visual.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/utils/visual.py b/utils/visual.py index 1ee87cfb..825caf52 100644 --- a/utils/visual.py +++ b/utils/visual.py @@ -1,3 +1,4 @@ +import torch import librosa import matplotlib matplotlib.use('Agg') @@ -5,10 +6,14 @@ import matplotlib.pyplot as plt from TTS.utils.text import phoneme_to_sequence, sequence_to_phoneme -def plot_alignment(alignment, info=None): - fig, ax = plt.subplots(figsize=(16, 10)) +def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None): + if isinstance(alignment, torch.Tensor): + alignment_ = alignment.detach().cpu().numpy().squeeze() + else: + alignment_ = alignment + fig, ax = plt.subplots(figsize=fig_size) im = ax.imshow( - alignment.T, aspect='auto', origin='lower', interpolation='none') + alignment_.T, aspect='auto', origin='lower', interpolation='none') fig.colorbar(im, ax=ax) xlabel = 'Decoder timestep' if info is not None: @@ -17,12 +22,18 @@ def plot_alignment(alignment, info=None): plt.ylabel('Encoder timestep') # plt.yticks(range(len(text)), list(text)) plt.tight_layout() + if title is not None: + plt.title(title) return fig -def plot_spectrogram(linear_output, audio): - spectrogram = audio._denormalize(linear_output) - fig = plt.figure(figsize=(16, 10)) +def plot_spectrogram(linear_output, audio, fig_size=(16, 10)): + if isinstance(linear_output, torch.Tensor): + linear_output_ = linear_output.detach().cpu().numpy().squeeze() + else: + linear_output_ = linear_output + spectrogram = audio._denormalize(linear_output_) + fig = plt.figure(figsize=fig_size) plt.imshow(spectrogram.T, aspect="auto", origin="lower") plt.colorbar() plt.tight_layout()