From e97bb45abae249b28004f635cd95f38a20de9f28 Mon Sep 17 00:00:00 2001 From: erogol Date: Tue, 10 Mar 2020 11:06:25 +0100 Subject: [PATCH] bug fixes and fixing unit tests --- tests/test_loader.py | 5 +++-- train.py | 2 +- utils/synthesis.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_loader.py b/tests/test_loader.py index d8727895..98e0bbce 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -141,7 +141,7 @@ class TestTTSDataset(unittest.TestCase): # check mel-spec correctness mel_spec = mel_input[0].cpu().numpy() - wav = self.ap.inv_mel_spectrogram(mel_spec.T) + wav = self.ap.inv_melspectrogram(mel_spec.T) self.ap.save_wav(wav, OUTPATH + '/mel_inv_dataloader.wav') shutil.copy(item_idx[0], OUTPATH + '/mel_target_dataloader.wav') @@ -199,7 +199,8 @@ class TestTTSDataset(unittest.TestCase): # check the second itme in the batch assert linear_input[1 - idx, -1].sum() == 0 assert mel_input[1 - idx, -1].sum() == 0 - assert stop_target[1 - idx, -1] == 1 + assert stop_target[1, mel_lengths[1]-1] == 1 + assert stop_target[1, mel_lengths[1]:].sum() == 0 assert len(mel_lengths.shape) == 1 # check batch zero-frame conditions (zero-frame disabled) diff --git a/train.py b/train.py index cf073956..f4ea6e70 100644 --- a/train.py +++ b/train.py @@ -470,7 +470,7 @@ def evaluate(model, criterion, ap, global_step, epoch): style_wav = c.get("style_wav_for_test") for idx, test_sentence in enumerate(test_sentences): try: - wav, alignment, decoder_output, postnet_output, stop_tokens, _ = synthesis( + wav, alignment, decoder_output, postnet_output, stop_tokens = synthesis( model, test_sentence, c, diff --git a/utils/synthesis.py b/utils/synthesis.py index b4512dc6..75778805 100644 --- a/utils/synthesis.py +++ b/utils/synthesis.py @@ -58,7 +58,7 @@ def inv_spectrogram(postnet_output, ap, CONFIG): if CONFIG.model in ["Tacotron", "TacotronGST"]: wav = ap.inv_spectrogram(postnet_output.T) else: - wav = ap.inv_mel_spectrogram(postnet_output.T) + wav = ap.inv_melspectrogram(postnet_output.T) return wav