From c8a552e62763ed7172b98822bb65838d72fbdb7b Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Fri, 2 Nov 2018 16:13:51 +0100
Subject: [PATCH] Batch update after data-loss

---
 config.json                                   |  70 +-
 datasets/{LJSpeech.py => TTSDataset.py}       |  47 +-
 ...{LJSpeechCached.py => TTSDatasetCached.py} |  93 +--
 datasets/{Kusal.py => TTSDatasetMemory.py}    | 341 +++++-----
 datasets/preprocess.py                        |  60 ++
 extract_features.py                           | 138 ++++
 layers/attention.py                           |  17 +-
 layers/tacotron.py                            |  93 ++-
 setup.py                                      |   4 +-
 tests/audio_tests.py                          | 142 ++++
 tests/loader_tests.py                         | 615 ++++++++++++------
 tests/tacotron_tests.py                       |   6 +-
 tests/test_config.json                        |  38 +-
 train.py                                      |  54 +-
 utils/audio.py                                | 169 ++++-
 utils/generic_utils.py                        |  32 +-
 utils/synthesis.py                            |  23 +
 utils/visual.py                               |  27 +
 18 files changed, 1362 insertions(+), 607 deletions(-)
 rename datasets/{LJSpeech.py => TTSDataset.py} (77%)
 rename datasets/{LJSpeechCached.py => TTSDatasetCached.py} (64%)
 rename datasets/{Kusal.py => TTSDatasetMemory.py} (56%)
 create mode 100644 datasets/preprocess.py
 create mode 100644 extract_features.py
 create mode 100644 tests/audio_tests.py
 create mode 100644 utils/synthesis.py

diff --git a/config.json b/config.json
index f0615e40..1491601d 100644
--- a/config.json
+++ b/config.json
@@ -1,41 +1,47 @@
 {
-    "model_name": "TTS-sigmoid",
-    "model_description": "Net outputting Sigmoid unit",
-    "audio_processor": "audio",
-    "num_mels": 80,
-    "num_freq": 1025,
-    "sample_rate": 22000,
-    "frame_length_ms": 50,
-    "frame_shift_ms": 12.5,
-    "preemphasis": 0.97,
-    "min_level_db": -100,
-    "ref_level_db": 20,
+    "model_name": "TTS-master-_tmp",
+    "model_description": "Higher dropout rate for stopnet and disabled custom initialization, pull current mel prediction to stopnet.",
+
+    "audio":{
+        "audio_processor": "audio",    // audio processor to use, if multiple are available.
+        "num_mels": 80,                // size of the mel spec frame.
+        "num_freq": 1025,              // number of stft frequency levels. Size of the linear spectrogram frame.
+        "sample_rate": 22050,          // wav sample-rate. If different than the original data, it is resampled.
+        "frame_length_ms": 50,         // stft window length in ms.
+        "frame_shift_ms": 12.5,        // stft window hop-length in ms.
+        "preemphasis": 0.97,           // pre-emphasis to reduce spec noise and make it more structured. If 0.0, pre-emphasis is disabled.
+        "min_level_db": -100,          // normalization range
+        "ref_level_db": 20,            // reference level db, theoretically 20 dB is the sound of air.
+        "power": 1.5,                  // value to sharpen wav signals after the Griffin-Lim algorithm.
+        "griffin_lim_iters": 60,       // number of Griffin-Lim iterations. 30-60 is a good range. The larger the value, the slower the generation.
+        "signal_norm": true,           // normalize the spec values in range [0, 1]
+        "symmetric_norm": false,       // move normalization to range [-1, 1]
+        "max_norm": 1,                 // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
+        "clip_norm": true,             // clip normalized values into the range.
+        "mel_fmin": null,              // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
+        "mel_fmax": null               // maximum freq level for mel-spec. Tune for dataset!!
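        // For reference (an editorial sketch, not part of the file): elsewhere in
        // this patch the whole "audio" object above is unpacked straight into the
        // audio processor as keyword arguments, e.g. in extract_features.py:
        //     CONFIG = load_config(args.config)
        //     ap = AudioProcessor(**CONFIG.audio)
        // so, assuming AudioProcessor accepts each of these keys, every entry in
        // this section maps directly to one of its constructor arguments.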
+    },
+
     "embedding_size": 256,
     "text_cleaner": "english_cleaners",
-
-    "num_loader_workers": 4,
-    "epochs": 1000,
-    "lr": 0.002,
+    "lr": 0.0015,
     "warmup_steps": 4000,
-    "lr_decay": 0.5,
-    "decay_step": 100000,
-    "batch_size": 32,
-    "eval_batch_size":-1,
-    "r": 5,
-    "wd": 0.0001,
-
-    "griffin_lim_iters": 60,
-    "power": 1.5,
-
+    "batch_size":32,
+    "eval_batch_size":32,
+    "r": 1,
+    "wd": 0.000001,
     "checkpoint": true,
-    "save_step": 25000,
+    "save_step": 5000,
     "print_step": 10,
-    "run_eval": false,
-    "data_path": "/snakepit/shared/data/keithito/LJSpeech-1.1/",
-    "meta_file_train": "metadata.csv",
-    "meta_file_val": null,
-    "dataset": "LJSpeech",
+
+    "run_eval": true,
+    "data_path": "../../Data/LJSpeech-1.1/tts_cache",    // can be overwritten from a command-line argument
+    "meta_file_train": "metadata_train.csv",             // metafile for the training dataloader
+    "meta_file_val": "metadata_val.csv",                 // metafile for the validation dataloader
+    "data_loader": "TTSDataset",    // dataloader, ["TTSDataset", "TTSDatasetCached", "TTSDatasetMemory"]
+    "dataset": "ljspeech",          // one of TTS.dataset.preprocessors; only valid if data_loader == "TTSDataset", the rest use "tts_cache" by default.
     "min_seq_len": 0,
-    "output_path": "../keep/"
+    "output_path": "../keep/",
+    "num_loader_workers": 8
 }
\ No newline at end of file
diff --git a/datasets/LJSpeech.py b/datasets/TTSDataset.py
similarity index 77%
rename from datasets/LJSpeech.py
rename to datasets/TTSDataset.py
index 495283e1..4b30710d 100644
--- a/datasets/LJSpeech.py
+++ b/datasets/TTSDataset.py
@@ -13,75 +13,72 @@ from utils.data import (prepare_data, pad_per_step, prepare_tensor,

 class MyDataset(Dataset):
     def __init__(self,
-                 root_dir,
-                 csv_file,
+                 root_path,
+                 meta_file,
                  outputs_per_step,
                  text_cleaner,
                  ap,
+                 preprocessor,
                  batch_group_size=0,
                  min_seq_len=0):
-        self.root_dir = root_dir
+        self.root_path = root_path
         self.batch_group_size = batch_group_size
-        self.wav_dir = os.path.join(root_dir, 'wavs')
-        self.csv_dir = os.path.join(root_dir, csv_file)
-        with open(self.csv_dir, "r", encoding="utf8") as f:
-            self.frames = [line.split('|') for line in f]
+        self.items = preprocessor(root_path, meta_file)
         self.outputs_per_step = outputs_per_step
         self.sample_rate = ap.sample_rate
         self.cleaners = text_cleaner
         self.min_seq_len = min_seq_len
         self.ap = ap
-        print(" > Reading LJSpeech from - {}".format(root_dir))
-        print(" | > Number of instances : {}".format(len(self.frames)))
-        self.sort_frames()
+        print(" > Reading LJSpeech from - {}".format(root_path))
+        print(" | > Number of instances : {}".format(len(self.items)))
+        self.sort_items()

     def load_wav(self, filename):
         try:
-            audio = librosa.core.load(filename, sr=self.sample_rate)[0]
+            audio = self.ap.load_wav(filename)
             return audio
         except RuntimeError as e:
             print(" !! 
Cannot read file : {}".format(filename)) - def sort_frames(self): + def sort_items(self): r"""Sort text sequences in ascending order""" - lengths = np.array([len(ins[1]) for ins in self.frames]) + lengths = np.array([len(ins[0]) for ins in self.items]) print(" | > Max length sequence {}".format(np.max(lengths))) print(" | > Min length sequence {}".format(np.min(lengths))) print(" | > Avg length sequence {}".format(np.mean(lengths))) idxs = np.argsort(lengths) - new_frames = [] + new_items = [] ignored = [] for i, idx in enumerate(idxs): length = lengths[idx] if length < self.min_seq_len: ignored.append(idx) else: - new_frames.append(self.frames[idx]) + new_items.append(self.items[idx]) print(" | > {} instances are ignored by min_seq_len ({})".format( len(ignored), self.min_seq_len)) # shuffle batch groups if self.batch_group_size > 0: print(" | > Batch group shuffling is active.") - for i in range(len(new_frames) // self.batch_group_size): + for i in range(len(new_items) // self.batch_group_size): offset = i * self.batch_group_size end_offset = offset + self.batch_group_size - temp_frames = new_frames[offset : end_offset] - random.shuffle(temp_frames) - new_frames[offset : end_offset] = temp_frames - self.frames = new_frames + temp_items = new_items[offset : end_offset] + random.shuffle(temp_items) + new_items[offset : end_offset] = temp_items + self.items = new_items def __len__(self): - return len(self.frames) + return len(self.items) def __getitem__(self, idx): - wav_name = os.path.join(self.wav_dir, self.frames[idx][0]) + '.wav' - text = self.frames[idx][1] + text, wav_file = self.items[idx] text = np.asarray( text_to_sequence(text, [self.cleaners]), dtype=np.int32) - wav = np.asarray(self.load_wav(wav_name), dtype=np.float32) - sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]} + wav = np.asarray(self.load_wav(wav_file), dtype=np.float32) + sample = {'text': text, 'wav': wav, 'item_idx': self.items[idx][1]} return sample def collate_fn(self, batch): diff --git a/datasets/LJSpeechCached.py b/datasets/TTSDatasetCached.py similarity index 64% rename from datasets/LJSpeechCached.py rename to datasets/TTSDatasetCached.py index 09061cc3..57a58f55 100644 --- a/datasets/LJSpeechCached.py +++ b/datasets/TTSDatasetCached.py @@ -1,4 +1,5 @@ import os +import random import numpy as np import collections import librosa @@ -6,32 +7,34 @@ import torch from torch.utils.data import Dataset from utils.text import text_to_sequence +from datasets.preprocess import tts_cache from utils.data import (prepare_data, pad_per_step, prepare_tensor, prepare_stop_target) class MyDataset(Dataset): + # TODO: Not finished yet. 
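    # A minimal sketch of the batch-group shuffle that sort_items() applies here
    # and in TTSDataset (an illustration, not part of the patch): items are first
    # sorted by text length, then shuffled only within fixed-size windows, e.g.
    #
    #     items.sort(key=lambda item: len(item[-1]))      # text is the last column
    #     for off in range(0, len(items) - bgs + 1, bgs):
    #         window = items[off:off + bgs]
    #         random.shuffle(window)
    #         items[off:off + bgs] = window
    #
    # so similarly sized sequences stay adjacent (less padding per batch) while
    # batch composition still varies between epochs.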
def __init__(self, - root_dir, - csv_file, + root_path, + meta_file, outputs_per_step, text_cleaner, ap, - min_seq_len=0): - self.root_dir = root_dir - self.wav_dir = os.path.join(root_dir, 'wavs') - self.feat_dir = os.path.join(root_dir, 'loader_data') - self.csv_dir = os.path.join(root_dir, csv_file) - with open(self.csv_dir, "r", encoding="utf8") as f: - self.frames = [line.split('|') for line in f] + batch_group_size=0, + min_seq_len=0, + **kwargs + ): + self.root_path = root_path + self.batch_group_size = batch_group_size + self.feat_dir = os.path.join(root_path, 'loader_data') + self.items = tts_cache(root_path, meta_file) self.outputs_per_step = outputs_per_step self.sample_rate = ap.sample_rate self.cleaners = text_cleaner self.min_seq_len = min_seq_len - self.items = [None] * len(self.frames) - print(" > Reading LJSpeech from - {}".format(root_dir)) - print(" | > Number of instances : {}".format(len(self.frames))) - self._sort_frames() + print(" > Reading LJSpeech from - {}".format(root_path)) + print(" | > Number of instances : {}".format(len(self.items))) + self.sort_items() def load_wav(self, filename): try: @@ -44,9 +47,9 @@ class MyDataset(Dataset): data = np.load(filename).astype('float32') return data - def _sort_frames(self): - r"""Sort sequences in ascending order""" - lengths = np.array([len(ins[1]) for ins in self.frames]) + def sort_items(self): + r"""Sort text sequences in ascending order""" + lengths = np.array([len(ins[-1]) for ins in self.items]) print(" | > Max length sequence {}".format(np.max(lengths))) print(" | > Min length sequence {}".format(np.min(lengths))) @@ -60,37 +63,43 @@ class MyDataset(Dataset): if length < self.min_seq_len: ignored.append(idx) else: - new_frames.append(self.frames[idx]) + new_frames.append(self.items[idx]) print(" | > {} instances are ignored by min_seq_len ({})".format( len(ignored), self.min_seq_len)) - self.frames = new_frames + # shuffle batch groups + if self.batch_group_size > 0: + print(" | > Batch group shuffling is active.") + for i in range(len(new_frames) // self.batch_group_size): + offset = i * self.batch_group_size + end_offset = offset + self.batch_group_size + temp_frames = new_frames[offset : end_offset] + random.shuffle(temp_frames) + new_frames[offset : end_offset] = temp_frames + self.items = new_frames def __len__(self): - return len(self.frames) + return len(self.items) def __getitem__(self, idx): - if self.items[idx] is None: - wav_name = os.path.join(self.wav_dir, self.frames[idx][0]) + '.wav' - mel_name = os.path.join(self.feat_dir, - self.frames[idx][0]) + '.mel.npy' - linear_name = os.path.join(self.feat_dir, - self.frames[idx][0]) + '.linear.npy' - text = self.frames[idx][1] - text = np.asarray( - text_to_sequence(text, [self.cleaners]), dtype=np.int32) - wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32) - mel = self.load_np(mel_name) - linear = self.load_np(linear_name) - sample = { - 'text': text, - 'wav': wav, - 'item_idx': self.frames[idx][0], - 'mel': mel, - 'linear': linear - } - self.items[idx] = sample + wav_name = self.items[idx][0] + mel_name = self.items[idx][1] + linear_name = self.items[idx][2] + text = self.items[idx][-1] + text = np.asarray( + text_to_sequence(text, [self.cleaners]), dtype=np.int32) + if wav_name.split('.')[-1] == 'npy': + wav = self.load_np(wav_name) else: - sample = self.items[idx] + wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32) + mel = self.load_np(mel_name) + linear = self.load_np(linear_name) + sample = { + 'text': text, + 'wav': wav, + 
'item_idx': self.items[idx][0], + 'mel': mel, + 'linear': linear + } return sample def collate_fn(self, batch): @@ -132,7 +141,6 @@ class MyDataset(Dataset): # PAD features with largest length + a zero frame linear = prepare_tensor(linear, self.outputs_per_step) mel = prepare_tensor(mel, self.outputs_per_step) - assert mel.shape[2] == linear.shape[2] timesteps = mel.shape[2] # B x T x D @@ -147,8 +155,7 @@ class MyDataset(Dataset): mel_lengths = torch.LongTensor(mel_lengths) stop_targets = torch.FloatTensor(stop_targets) - return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[ - 0] + return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs raise TypeError(("batch must contain tensors, numbers, dicts or lists;\ found {}".format(type(batch[0])))) diff --git a/datasets/Kusal.py b/datasets/TTSDatasetMemory.py similarity index 56% rename from datasets/Kusal.py rename to datasets/TTSDatasetMemory.py index b431aee3..1d770dfd 100644 --- a/datasets/Kusal.py +++ b/datasets/TTSDatasetMemory.py @@ -1,162 +1,179 @@ -import os -import glob -import random -import numpy as np -import collections -import librosa -import torch -from torch.utils.data import Dataset - -from utils.text import text_to_sequence -from utils.data import (prepare_data, pad_per_step, prepare_tensor, - prepare_stop_target) - - -class MyDataset(Dataset): - def __init__(self, - root_dir, - csv_file, - outputs_per_step, - text_cleaner, - ap, - min_seq_len=0): - self.root_dir = root_dir - self.wav_dir = os.path.join(root_dir, 'wav') - self.wav_files = glob.glob(os.path.join(self.wav_dir, '*.wav')) - self._create_file_dict() - self.csv_dir = os.path.join(root_dir, csv_file) - with open(self.csv_dir, "r", encoding="utf8") as f: - self.frames = [ - line.split('\t') for line in f - if line.split('\t')[0] in self.wav_files_dict.keys() - ] - self.outputs_per_step = outputs_per_step - self.sample_rate = ap.sample_rate - self.cleaners = text_cleaner - self.min_seq_len = min_seq_len - self.ap = ap - print(" > Reading Kusal from - {}".format(root_dir)) - print(" | > Number of instances : {}".format(len(self.frames))) - self._sort_frames() - - def load_wav(self, filename): - """ Load audio and trim silence """ - try: - audio = librosa.core.load(filename, sr=self.sample_rate)[0] - margin = int(self.sample_rate * 0.1) - audio = audio[margin:-margin] - return self._trim_silence(audio) - except RuntimeError as e: - print(" !! 
Cannot read file : {}".format(filename)) - - def _trim_silence(self, wav): - return librosa.effects.trim( - wav, top_db=40, frame_length=1024, hop_length=256)[0] - - def _create_file_dict(self): - self.wav_files_dict = {} - for fn in self.wav_files: - parts = fn.split('-') - key = parts[1] - value = fn - try: - self.wav_files_dict[key].append(value) - except: - self.wav_files_dict[key] = [value] - - def _sort_frames(self): - r"""Sort sequences in ascending order""" - lengths = np.array([len(ins[2]) for ins in self.frames]) - - print(" | > Max length sequence {}".format(np.max(lengths))) - print(" | > Min length sequence {}".format(np.min(lengths))) - print(" | > Avg length sequence {}".format(np.mean(lengths))) - - idxs = np.argsort(lengths) - new_frames = [] - ignored = [] - for i, idx in enumerate(idxs): - length = lengths[idx] - if length < self.min_seq_len: - ignored.append(idx) - else: - new_frames.append(self.frames[idx]) - print(" | > {} instances are ignored by min_seq_len ({})".format( - len(ignored), self.min_seq_len)) - self.frames = new_frames - - def __len__(self): - return len(self.frames) - - def __getitem__(self, idx): - sidx = self.frames[idx][0] - sidx_files = self.wav_files_dict[sidx] - file_name = random.choice(sidx_files) - wav_name = os.path.join(self.wav_dir, file_name) - text = self.frames[idx][2] - text = np.asarray( - text_to_sequence(text, [self.cleaners]), dtype=np.int32) - wav = np.asarray(self.load_wav(wav_name), dtype=np.float32) - sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]} - return sample - - def collate_fn(self, batch): - r""" - Perform preprocessing and create a final data batch: - 1. PAD sequences with the longest sequence in the batch - 2. Convert Audio signal to Spectrograms. - 3. PAD sequences that can be divided by r. - 4. Convert Numpy to Torch tensors. - """ - - # Puts each data field into a tensor with outer dimension batch size - if isinstance(batch[0], collections.Mapping): - keys = list() - - wav = [d['wav'] for d in batch] - item_idxs = [d['item_idx'] for d in batch] - text = [d['text'] for d in batch] - - text_lenghts = np.array([len(x) for x in text]) - max_text_len = np.max(text_lenghts) - - linear = [self.ap.spectrogram(w).astype('float32') for w in wav] - mel = [self.ap.melspectrogram(w).astype('float32') for w in wav] - mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame - - # compute 'stop token' targets - stop_targets = [ - np.array([0.] 
* (mel_len - 1)) for mel_len in mel_lengths - ] - - # PAD stop targets - stop_targets = prepare_stop_target(stop_targets, - self.outputs_per_step) - - # PAD sequences with largest length of the batch - text = prepare_data(text).astype(np.int32) - wav = prepare_data(wav) - - # PAD features with largest length + a zero frame - linear = prepare_tensor(linear, self.outputs_per_step) - mel = prepare_tensor(mel, self.outputs_per_step) - assert mel.shape[2] == linear.shape[2] - timesteps = mel.shape[2] - - # B x T x D - linear = linear.transpose(0, 2, 1) - mel = mel.transpose(0, 2, 1) - - # convert things to pytorch - text_lenghts = torch.LongTensor(text_lenghts) - text = torch.LongTensor(text) - linear = torch.FloatTensor(linear) - mel = torch.FloatTensor(mel) - mel_lengths = torch.LongTensor(mel_lengths) - stop_targets = torch.FloatTensor(stop_targets) - - return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[ - 0] - - raise TypeError(("batch must contain tensors, numbers, dicts or lists;\ - found {}".format(type(batch[0])))) +import os +import random +import numpy as np +import collections +import librosa +import torch +from tqdm import tqdm +from torch.utils.data import Dataset + +from utils.text import text_to_sequence +from datasets.preprocess import tts_cache +from utils.data import (prepare_data, pad_per_step, prepare_tensor, + prepare_stop_target) + + +class MyDataset(Dataset): + # TODO: Not finished yet. + def __init__(self, + root_path, + meta_file, + outputs_per_step, + text_cleaner, + ap, + batch_group_size=0, + min_seq_len=0, + **kwargs + ): + self.root_path = root_path + self.batch_group_size = batch_group_size + self.feat_dir = os.path.join(root_path, 'loader_data') + self.items = tts_cache(root_path, meta_file) + self.outputs_per_step = outputs_per_step + self.sample_rate = ap.sample_rate + self.cleaners = text_cleaner + self.min_seq_len = min_seq_len + self.wavs = None + self.mels = None + self.linears = None + print(" > Reading LJSpeech from - {}".format(root_path)) + print(" | > Number of instances : {}".format(len(self.items))) + self.sort_items() + self.fill_data() + + def fill_data(self): + if self.wavs is None and self.mels is None: + self.wavs = [] + self.mels = [] + self.linears = [] + self.texts = [] + for item in tqdm(self.items): + wav_file = item[0] + mel_file = item[1] + linear_file = item[2] + text = item[-1] + wav = self.load_np(wav_file) + mel = self.load_np(mel_file) + linear = self.load_np(linear_file) + self.wavs.append(wav) + self.mels.append(mel) + self.linears.append(linear) + self.texts.append(np.asarray( + text_to_sequence(text, [self.cleaners]), dtype=np.int32)) + print(" > Data loaded to memory") + + def load_wav(self, filename): + try: + audio = librosa.core.load(filename, sr=self.sample_rate) + return audio + except RuntimeError as e: + print(" !! 
Cannot read file : {}".format(filename)) + + def load_np(self, filename): + data = np.load(filename).astype('float32') + return data + + def sort_items(self): + r"""Sort text sequences in ascending order""" + lengths = np.array([len(ins[-1]) for ins in self.items]) + + print(" | > Max length sequence {}".format(np.max(lengths))) + print(" | > Min length sequence {}".format(np.min(lengths))) + print(" | > Avg length sequence {}".format(np.mean(lengths))) + + idxs = np.argsort(lengths) + new_frames = [] + ignored = [] + for i, idx in enumerate(idxs): + length = lengths[idx] + if length < self.min_seq_len: + ignored.append(idx) + else: + new_frames.append(self.items[idx]) + print(" | > {} instances are ignored by min_seq_len ({})".format( + len(ignored), self.min_seq_len)) + # shuffle batch groups + if self.batch_group_size > 0: + print(" | > Batch group shuffling is active.") + for i in range(len(new_frames) // self.batch_group_size): + offset = i * self.batch_group_size + end_offset = offset + self.batch_group_size + temp_frames = new_frames[offset : end_offset] + random.shuffle(temp_frames) + new_frames[offset : end_offset] = temp_frames + self.items = new_frames + + def __len__(self): + return len(self.items) + + def __getitem__(self, idx): + text = self.texts[idx] + wav = self.wavs[idx] + mel = self.mels[idx] + linear = self.linears[idx] + sample = { + 'text': text, + 'wav': wav, + 'item_idx': self.items[idx][0], + 'mel': mel, + 'linear': linear + } + return sample + + def collate_fn(self, batch): + r""" + Perform preprocessing and create a final data batch: + 1. PAD sequences with the longest sequence in the batch + 2. Convert Audio signal to Spectrograms. + 3. PAD sequences that can be divided by r. + 4. Convert Numpy to Torch tensors. + """ + + # Puts each data field into a tensor with outer dimension batch size + if isinstance(batch[0], collections.Mapping): + keys = list() + + wav = [d['wav'] for d in batch] + item_idxs = [d['item_idx'] for d in batch] + text = [d['text'] for d in batch] + mel = [d['mel'] for d in batch] + linear = [d['linear'] for d in batch] + + text_lenghts = np.array([len(x) for x in text]) + max_text_len = np.max(text_lenghts) + mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame + + # compute 'stop token' targets + stop_targets = [ + np.array([0.] 
* (mel_len - 1)) for mel_len in mel_lengths
+            ]
+
+            # PAD stop targets
+            stop_targets = prepare_stop_target(stop_targets,
+                                               self.outputs_per_step)
+
+            # PAD sequences with largest length of the batch
+            text = prepare_data(text).astype(np.int32)
+            wav = prepare_data(wav)
+
+            # PAD features with largest length + a zero frame
+            linear = prepare_tensor(linear, self.outputs_per_step)
+            mel = prepare_tensor(mel, self.outputs_per_step)
+            timesteps = mel.shape[2]
+
+            # B x T x D
+            linear = linear.transpose(0, 2, 1)
+            mel = mel.transpose(0, 2, 1)
+
+            # convert things to pytorch
+            text_lenghts = torch.LongTensor(text_lenghts)
+            text = torch.LongTensor(text)
+            linear = torch.FloatTensor(linear)
+            mel = torch.FloatTensor(mel)
+            mel_lengths = torch.LongTensor(mel_lengths)
+            stop_targets = torch.FloatTensor(stop_targets)
+
+            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs
+
+        raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
+                         found {}".format(type(batch[0]))))
diff --git a/datasets/preprocess.py b/datasets/preprocess.py
new file mode 100644
index 00000000..0aabbdf6
--- /dev/null
+++ b/datasets/preprocess.py
@@ -0,0 +1,60 @@
+import os
+import random
+
+def tts_cache(root_path, meta_file):
+    """This format is set for the meta-file generated by extract_features.py"""
+    txt_file = os.path.join(root_path, meta_file)
+    items = []
+    with open(txt_file, 'r', encoding='utf8') as f:
+        for line in f:
+            cols = line.split('| ')
+            items.append(cols)  # wav_full_path, mel_name, linear_name, wav_len, mel_len, text
+    random.shuffle(items)
+    return items
+
+
+# def tweb(root_path, meta_file):
+#     # TODO
+#     pass
+#     return
+
+
+# def kusal(root_path, meta_file):
+#     txt_file = os.path.join(root_path, meta_file)
+#     texts = []
+#     wavs = []
+#     with open(txt_file, "r", encoding="utf8") as f:
+#         frames = [
+#             line.split('\t') for line in f
+#             if line.split('\t')[0] in self.wav_files_dict.keys()
+#         ]
+#     # TODO: code the rest
+#     return {'text': texts, 'wavs': wavs}
+
+
+def ljspeech(root_path, meta_file):
+    """Normalizes the LJSpeech meta data file to TTS format"""
+    txt_file = os.path.join(root_path, meta_file)
+    items = []
+    with open(txt_file, 'r') as ttf:
+        for line in ttf:
+            cols = line.split('|')
+            wav_file = os.path.join(root_path, 'wavs', cols[0]+'.wav')
+            text = cols[1]
+            items.append([text, wav_file])
+    random.shuffle(items)
+    return items
+
+
+def nancy(root_path, meta_file):
+    """Normalizes the Nancy meta data file to TTS format"""
+    txt_file = os.path.join(root_path, meta_file)
+    items = []
+    with open(txt_file, 'r') as ttf:
+        for line in ttf:
+            id = line.split()[1]
+            text = line[line.find('"')+1:line.rfind('"')]
+            wav_file = root_path + 'wavn/' + id + '.wav'
+            items.append([text, wav_file])
+    random.shuffle(items)
+    return items
\ No newline at end of file
diff --git a/extract_features.py b/extract_features.py
new file mode 100644
index 00000000..56629d1d
--- /dev/null
+++ b/extract_features.py
@@ -0,0 +1,138 @@
+'''
+Extract spectrograms and save them to file for training
+'''
+import os
+import sys
+import time
+import glob
+import argparse
+import librosa
+import importlib
+import numpy as np
+import tqdm
+from utils.generic_utils import load_config, copy_config_file
+from utils.audio import AudioProcessor
+
+from multiprocessing import Pool
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data_path', type=str, help='Data folder.')
+    parser.add_argument('--cache_path', type=str, help='Cache folder, place to output all the 
intermediate spectrogram files.')
+    # parser.add_argument('--keep_cache', type=bool, help='If True, it keeps the cache folder.')
+    # parser.add_argument('--hdf5_path', type=str, help='hdf5 folder.')
+    parser.add_argument(
+        '--config', type=str, help='conf.json file for run settings.')
+    parser.add_argument(
+        "--num_proc", type=int, default=8, help="number of processes.")
+    parser.add_argument(
+        "--trim_silence",
+        type=bool,
+        default=False,
+        help="trim silence in the voice clip.")
+    parser.add_argument("--only_mel", type=bool, default=False, help="If True, only the mel-spectrogram is extracted.")
+    parser.add_argument("--dataset", type=str, help="Target dataset to be processed.")
+    parser.add_argument("--val_split", type=int, default=0, help="Number of instances for validation.")
+    parser.add_argument("--meta_file", type=str, help="Meta data file to be used for the dataset.")
+    parser.add_argument("--process_audio", type=bool, default=False, help="Preprocess audio files.")
+    args = parser.parse_args()
+
+    DATA_PATH = args.data_path
+    CACHE_PATH = args.cache_path
+    CONFIG = load_config(args.config)
+
+    # load the right preprocessor
+    preprocessor = importlib.import_module('datasets.preprocess')
+    preprocessor = getattr(preprocessor, args.dataset.lower())
+    items = preprocessor(args.data_path, args.meta_file)
+
+    print(" > Input path: ", DATA_PATH)
+    print(" > Cache path: ", CACHE_PATH)
+
+    # audio = importlib.import_module('utils.' + c.audio_processor)
+    # AudioProcessor = getattr(audio, 'AudioProcessor')
+    ap = AudioProcessor(**CONFIG.audio)
+
+    def trim_silence(wav):
+        """ Trim silent parts with a threshold and 0.1 sec margin """
+        margin = int(ap.sample_rate * 0.1)
+        wav = wav[margin:-margin]
+        return librosa.effects.trim(
+            wav, top_db=40, frame_length=1024, hop_length=256)[0]
+
+    def extract_mel(item):
+        """ Compute spectrograms, length information """
+        text = item[0]
+        file_path = item[1]
+        x = ap.load_wav(file_path, ap.sample_rate)
+        if args.trim_silence:
+            x = trim_silence(x)
+        file_name = os.path.basename(file_path).replace(".wav", "")
+        mel_file = file_name + "_mel"
+        mel_path = os.path.join(CACHE_PATH, 'mel', mel_file)
+        mel = ap.melspectrogram(x.astype('float32')).astype('float32')
+        np.save(mel_path, mel, allow_pickle=False)
+        mel_len = mel.shape[1]
+        wav_len = x.shape[0]
+        output = [file_path, mel_path+".npy", str(wav_len), str(mel_len), text]
+        if not args.only_mel:
+            linear_file = file_name + "_linear"
+            linear_path = os.path.join(CACHE_PATH, 'linear', linear_file)
+            linear = ap.spectrogram(x.astype('float32')).astype('float32')
+            linear_len = linear.shape[1]
+            np.save(linear_path, linear, allow_pickle=False)
+            output.insert(2, linear_path+".npy")
+            assert mel_len == linear_len
+        if args.process_audio:
+            audio_file = file_name + "_audio"
+            audio_path = os.path.join(CACHE_PATH, 'audio', audio_file)
+            np.save(audio_path, x, allow_pickle=False)
+            del output[0]
+            output.insert(0, audio_path+".npy")
+        return output
+
+
+    if __name__ == "__main__":
+        print(" > Number of files: %i" % (len(items)))
+        if not os.path.exists(CACHE_PATH):
+            os.makedirs(os.path.join(CACHE_PATH, 'mel'))
+            if not args.only_mel:
+                os.makedirs(os.path.join(CACHE_PATH, 'linear'))
+            if args.process_audio:
+                os.makedirs(os.path.join(CACHE_PATH, 'audio'))
+            print(" > A new folder created at {}".format(CACHE_PATH))
+
+        # Extract features
+        r = []
+        if args.num_proc > 1:
+            print(" > Using {} processes.".format(args.num_proc))
+            with Pool(args.num_proc) as p:
+                r = list(
+                    tqdm.tqdm(
+                        p.imap(extract_mel, 
items), + total=len(items))) + # r = list(p.imap(extract_mel, file_names)) + else: + print(" > Using single process run.") + for item in items: + print(" > ", item[1]) + r.append(extract_mel(item)) + + # Save meta data + if args.cache_path is not None: + file_path = os.path.join(CACHE_PATH, "tts_metadata_val.csv") + file = open(file_path, "w") + for line in r[:args.val_split]: + line = "| ".join(line) + file.write(line + '\n') + file.close() + + file_path = os.path.join(CACHE_PATH, "tts_metadata.csv") + file = open(file_path, "w") + for line in r[args.val_split:]: + line = "| ".join(line) + file.write(line + '\n') + file.close() + + # copy the used config file to output path for sanity + copy_config_file(args.config, CACHE_PATH) diff --git a/layers/attention.py b/layers/attention.py index bbf83055..5436f110 100644 --- a/layers/attention.py +++ b/layers/attention.py @@ -52,10 +52,25 @@ class LocationSensitiveAttention(nn.Module): stride=1, padding=0, bias=False)) - self.loc_linear = nn.Linear(filters, attn_dim) + self.loc_linear = nn.Linear(filters, attn_dim, bias=True) self.query_layer = nn.Linear(query_dim, attn_dim, bias=True) self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True) self.v = nn.Linear(attn_dim, 1, bias=False) + # self.init_layers() + + def init_layers(self): + torch.nn.init.xavier_uniform_( + self.loc_linear.weight, + gain=torch.nn.init.calculate_gain('tanh')) + torch.nn.init.xavier_uniform_( + self.query_layer.weight, + gain=torch.nn.init.calculate_gain('tanh')) + torch.nn.init.xavier_uniform_( + self.annot_layer.weight, + gain=torch.nn.init.calculate_gain('tanh')) + torch.nn.init.xavier_uniform_( + self.v.weight, + gain=torch.nn.init.calculate_gain('linear')) def forward(self, annot, query, loc): """ diff --git a/layers/tacotron.py b/layers/tacotron.py index 23b705c2..749d7cb3 100644 --- a/layers/tacotron.py +++ b/layers/tacotron.py @@ -23,6 +23,13 @@ class Prenet(nn.Module): ]) self.relu = nn.ReLU() self.dropout = nn.Dropout(0.5) + # self.init_layers() + + def init_layers(self): + for layer in self.layers: + torch.nn.init.xavier_uniform_( + layer.weight, + gain=torch.nn.init.calculate_gain('relu')) def forward(self, inputs): for linear in self.layers: @@ -55,6 +62,7 @@ class BatchNormConv1d(nn.Module): stride, padding, activation=None): + super(BatchNormConv1d, self).__init__() self.padding = padding self.padder = nn.ConstantPad1d(padding, 0) @@ -68,13 +76,28 @@ class BatchNormConv1d(nn.Module): # Following tensorflow's default parameters self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3) self.activation = activation + # self.init_layers() + + def init_layers(self): + if type(self.activation) == torch.nn.ReLU: + w_gain = 'relu' + elif type(self.activation) == torch.nn.Tanh: + w_gain = 'tanh' + elif self.activation is None: + w_gain = 'linear' + else: + raise RuntimeError('Unknown activation function') + torch.nn.init.xavier_uniform_( + self.conv1d.weight, + gain=torch.nn.init.calculate_gain(w_gain)) def forward(self, x): x = self.padder(x) x = self.conv1d(x) + x = self.bn(x) if self.activation is not None: x = self.activation(x) - return self.bn(x) + return x class Highway(nn.Module): @@ -86,6 +109,15 @@ class Highway(nn.Module): self.T.bias.data.fill_(-1) self.relu = nn.ReLU() self.sigmoid = nn.Sigmoid() + # self.init_layers() + + def init_layers(self): + torch.nn.init.xavier_uniform_( + self.H.weight, + gain=torch.nn.init.calculate_gain('relu')) + torch.nn.init.xavier_uniform_( + self.T.weight, + gain=torch.nn.init.calculate_gain('sigmoid')) def 
forward(self, inputs): H = self.relu(self.H(inputs)) @@ -276,7 +308,7 @@ class Decoder(nn.Module): super(Decoder, self).__init__() self.r = r self.in_features = in_features - self.max_decoder_steps = 200 + self.max_decoder_steps = 500 self.memory_dim = memory_dim # memory -> |Prenet| -> processed_memory self.prenet = Prenet(memory_dim * r, out_features=[256, 128]) @@ -294,7 +326,16 @@ class Decoder(nn.Module): [nn.GRUCell(256, 256) for _ in range(2)]) # RNN_state -> |Linear| -> mel_spec self.proj_to_mel = nn.Linear(256, memory_dim * r) - self.stopnet = StopNet(r, memory_dim) + self.stopnet = StopNet(256 + memory_dim * r) + # self.init_layers() + + def init_layers(self): + torch.nn.init.xavier_uniform_( + self.project_to_decoder_in.weight, + gain=torch.nn.init.calculate_gain('linear')) + torch.nn.init.xavier_uniform_( + self.proj_to_mel.weight, + gain=torch.nn.init.calculate_gain('linear')) def forward(self, inputs, memory=None, mask=None): """ @@ -350,7 +391,7 @@ class Decoder(nn.Module): memory_input = initial_memory while True: if t > 0: - if greedy: + if memory is None: memory_input = outputs[-1] else: memory_input = memory[t - 1] @@ -375,17 +416,14 @@ class Decoder(nn.Module): decoder_output = decoder_input # predict mel vectors from decoder vectors output = self.proj_to_mel(decoder_output) - output = torch.sigmoid(output) - stop_input = output # predict stop token - stop_token, stopnet_rnn_hidden = self.stopnet( - stop_input, stopnet_rnn_hidden) + stopnet_input = torch.cat([decoder_input, output], -1) + stop_token = self.stopnet(stopnet_input) outputs += [output] attentions += [attention] stop_tokens += [stop_token] t += 1 - if (not greedy and self.training) or (greedy - and memory is not None): + if memory is not None: if t >= T_decoder: break else: @@ -394,7 +432,6 @@ class Decoder(nn.Module): elif t > self.max_decoder_steps: print(" | > Decoder stopped with 'max_decoder_steps") break - assert greedy or len(outputs) == T_decoder # Back to batch first attentions = torch.stack(attentions).transpose(0, 1) outputs = torch.stack(outputs).transpose(0, 1).contiguous() @@ -407,32 +444,22 @@ class StopNet(nn.Module): Predicting stop-token in decoder. Args: - r (int): number of output frames of the network. - memory_dim (int): feature dimension for each output frame. + in_features (int): feature dimension of input. """ - def __init__(self, r, memory_dim): - r""" - Predicts the stop token to stop the decoder at testing time - - Args: - r (int): number of network output frames. - memory_dim (int): single feature dim of a single network output frame. - """ + def __init__(self, in_features): super(StopNet, self).__init__() - self.rnn = nn.GRUCell(memory_dim * r, memory_dim * r) - self.relu = nn.ReLU() - self.linear = nn.Linear(r * memory_dim, 1) + self.dropout = nn.Dropout(0.5) + self.linear = nn.Linear(in_features, 1) self.sigmoid = nn.Sigmoid() + torch.nn.init.xavier_uniform_( + self.linear.weight, + gain=torch.nn.init.calculate_gain('linear')) - def forward(self, inputs, rnn_hidden): - """ - Args: - inputs: network output tensor with r x memory_dim feature dimension. - rnn_hidden: hidden state of the RNN cell. 
- """ - rnn_hidden = self.rnn(inputs, rnn_hidden) - outputs = self.relu(rnn_hidden) + def forward(self, inputs): + # rnn_hidden = self.rnn(inputs, rnn_hidden) + # outputs = self.relu(rnn_hidden) + outputs = self.dropout(inputs) outputs = self.linear(outputs) outputs = self.sigmoid(outputs) - return outputs, rnn_hidden + return outputs \ No newline at end of file diff --git a/setup.py b/setup.py index 29d4c632..1d091d2b 100644 --- a/setup.py +++ b/setup.py @@ -80,11 +80,11 @@ setup( "matplotlib==2.0.2", "Pillow", "flask", - "lws", + # "lws", + "tqdm", ], extras_require={ "bin": [ - "tqdm", "requests", ], }) diff --git a/tests/audio_tests.py b/tests/audio_tests.py new file mode 100644 index 00000000..fb5f698e --- /dev/null +++ b/tests/audio_tests.py @@ -0,0 +1,142 @@ +import os +import unittest +import numpy as np +import torch as T +from TTS.utils.audio import AudioProcessor +from TTS.utils.generic_utils import load_config + +file_path = os.path.dirname(os.path.realpath(__file__)) +INPUTPATH = os.path.join(file_path, 'inputs') +OUTPATH = os.path.join(file_path, "outputs/audio_tests") +os.makedirs(OUTPATH, exist_ok=True) + +c = load_config(os.path.join(file_path, 'test_config.json')) + + +class TestAudio(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(TestAudio, self).__init__(*args, **kwargs) + self.ap = AudioProcessor(**c.audio) + + def test_audio_synthesis(self): + """ 1. load wav + 2. set normalization parameters + 3. extract mel-spec + 4. invert to wav and save the output + """ + print(" > Sanity check for the process wav -> mel -> wav") + + def _test(max_norm, signal_norm, symmetric_norm, clip_norm): + self.ap.max_norm = max_norm + self.ap.signal_norm = signal_norm + self.ap.symmetric_norm = symmetric_norm + self.ap.clip_norm = clip_norm + wav = self.ap.load_wav(INPUTPATH + "/example_1.wav") + mel = self.ap.melspectrogram(wav) + wav_ = self.ap.inv_mel_spectrogram(mel) + file_name = "/audio_test-melspec_max_norm_{}-signal_norm_{}-symmetric_{}-clip_norm_{}.wav"\ + .format(max_norm, signal_norm, symmetric_norm, clip_norm) + print(" | > Creating wav file at : ", file_name) + self.ap.save_wav(wav_, OUTPATH + file_name) + + # maxnorm = 1.0 + _test(1., False, False, False) + _test(1., True, False, False) + _test(1., True, True, False) + _test(1., True, False, True) + _test(1., True, True, True) + # maxnorm = 4.0 + _test(4., False, False, False) + _test(4., True, False, False) + _test(4., True, True, False) + _test(4., True, False, True) + _test(4., True, True, True) + + def test_normalize(self): + """Check normalization and denormalization for range values and consistency """ + print(" > Testing normalization and denormalization.") + wav = self.ap.load_wav(INPUTPATH + "/example_1.wav") + self.ap.signal_norm = False + x = self.ap.melspectrogram(wav) + x_old = x + + self.ap.signal_norm = True + self.ap.symmetric_norm = False + self.ap.clip_norm = False + self.ap.max_norm = 4.0 + x_norm = self.ap._normalize(x) + print(x_norm.max(), " -- ", x_norm.min()) + assert (x_old - x).sum() == 0 + # check value range + assert x_norm.max() <= self.ap.max_norm + 1, x_norm.max() + assert x_norm.min() >= 0 - 1, x_norm.min() + # check denorm. 
+ x_ = self.ap._denormalize(x_norm) + assert (x - x_).sum() < 1e-3, (x - x_).mean() + + self.ap.signal_norm = True + self.ap.symmetric_norm = False + self.ap.clip_norm = True + self.ap.max_norm = 4.0 + x_norm = self.ap._normalize(x) + print(x_norm.max(), " -- ", x_norm.min()) + assert (x_old - x).sum() == 0 + # check value range + assert x_norm.max() <= self.ap.max_norm, x_norm.max() + assert x_norm.min() >= 0, x_norm.min() + # check denorm. + x_ = self.ap._denormalize(x_norm) + assert (x - x_).sum() < 1e-3, (x - x_).mean() + + self.ap.signal_norm = True + self.ap.symmetric_norm = True + self.ap.clip_norm = False + self.ap.max_norm = 4.0 + x_norm = self.ap._normalize(x) + print(x_norm.max(), " -- ", x_norm.min()) + assert (x_old - x).sum() == 0 + # check value range + assert x_norm.max() <= self.ap.max_norm + 1, x_norm.max() + assert x_norm.min() >= -self.ap.max_norm - 2, x_norm.min() + assert x_norm.min() <= 0, x_norm.min() + # check denorm. + x_ = self.ap._denormalize(x_norm) + assert (x - x_).sum() < 1e-3, (x - x_).mean() + + self.ap.signal_norm = True + self.ap.symmetric_norm = True + self.ap.clip_norm = True + self.ap.max_norm = 4.0 + x_norm = self.ap._normalize(x) + print(x_norm.max(), " -- ", x_norm.min()) + assert (x_old - x).sum() == 0 + # check value range + assert x_norm.max() <= self.ap.max_norm, x_norm.max() + assert x_norm.min() >= -self.ap.max_norm, x_norm.min() + assert x_norm.min() <= 0, x_norm.min() + # check denorm. + x_ = self.ap._denormalize(x_norm) + assert (x - x_).sum() < 1e-3, (x - x_).mean() + + self.ap.signal_norm = True + self.ap.symmetric_norm = False + self.ap.max_norm = 1.0 + x_norm = self.ap._normalize(x) + print(x_norm.max(), " -- ", x_norm.min()) + assert (x_old - x).sum() == 0 + assert x_norm.max() <= self.ap.max_norm, x_norm.max() + assert x_norm.min() >= 0, x_norm.min() + x_ = self.ap._denormalize(x_norm) + assert (x - x_).sum() < 1e-3 + + self.ap.signal_norm = True + self.ap.symmetric_norm = True + self.ap.max_norm = 1.0 + x_norm = self.ap._normalize(x) + print(x_norm.max(), " -- ", x_norm.min()) + assert (x_old - x).sum() == 0 + assert x_norm.max() <= self.ap.max_norm, x_norm.max() + assert x_norm.min() >= -self.ap.max_norm, x_norm.min() + assert x_norm.min() < 0, x_norm.min() + x_ = self.ap._denormalize(x_norm) + assert (x - x_).sum() < 1e-3 diff --git a/tests/loader_tests.py b/tests/loader_tests.py index b80f3e74..81ab59d4 100644 --- a/tests/loader_tests.py +++ b/tests/loader_tests.py @@ -1,50 +1,49 @@ import os import unittest +import shutil import numpy as np from torch.utils.data import DataLoader from TTS.utils.generic_utils import load_config from TTS.utils.audio import AudioProcessor -from TTS.datasets import LJSpeech, Kusal +from TTS.datasets import TTSDataset, TTSDatasetCached, TTSDatasetMemory +from TTS.datasets.preprocess import ljspeech, tts_cache file_path = os.path.dirname(os.path.realpath(__file__)) +OUTPATH = os.path.join(file_path, "outputs/loader_tests/") +os.makedirs(OUTPATH, exist_ok=True) c = load_config(os.path.join(file_path, 'test_config.json')) -ok_kusal = os.path.exists(c.data_path_Kusal) -ok_ljspeech = os.path.exists(c.data_path_LJSpeech) +ok_ljspeech = os.path.exists(c.data_path) -class TestLJSpeechDataset(unittest.TestCase): +class TestTTSDataset(unittest.TestCase): def __init__(self, *args, **kwargs): - super(TestLJSpeechDataset, self).__init__(*args, **kwargs) + super(TestTTSDataset, self).__init__(*args, **kwargs) self.max_loader_iter = 4 - self.ap = AudioProcessor( - sample_rate=c.sample_rate, - num_mels=c.num_mels, - 
min_level_db=c.min_level_db, - frame_shift_ms=c.frame_shift_ms, - frame_length_ms=c.frame_length_ms, - ref_level_db=c.ref_level_db, - num_freq=c.num_freq, - power=c.power, - preemphasis=c.preemphasis) + self.ap = AudioProcessor(**c.audio) + + def _create_dataloader(self, batch_size, r, bgs): + dataset = TTSDataset.MyDataset( + c.data_path, + 'metadata.csv', + r, + c.text_cleaner, + preprocessor=ljspeech, + ap=self.ap, + batch_group_size=bgs, + min_seq_len=c.min_seq_len) + dataloader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + collate_fn=dataset.collate_fn, + drop_last=True, + num_workers=c.num_loader_workers) + return dataloader, dataset def test_loader(self): if ok_ljspeech: - dataset = LJSpeech.MyDataset( - os.path.join(c.data_path_LJSpeech), - os.path.join(c.data_path_LJSpeech, 'metadata.csv'), - c.r, - c.text_cleaner, - ap=self.ap, - min_seq_len=c.min_seq_len) - - dataloader = DataLoader( - dataset, - batch_size=2, - shuffle=True, - collate_fn=dataset.collate_fn, - drop_last=True, - num_workers=c.num_loader_workers) + dataloader, dataset = self._create_dataloader(2, c.r, 0) for i, data in enumerate(dataloader): if i == self.max_loader_iter: @@ -63,29 +62,158 @@ class TestLJSpeechDataset(unittest.TestCase): " !! Negative values in text_input: {}".format(check_count) # TODO: more assertion here assert linear_input.shape[0] == c.batch_size + assert linear_input.shape[2] == self.ap.num_freq assert mel_input.shape[0] == c.batch_size - assert mel_input.shape[2] == c.num_mels + assert mel_input.shape[2] == c.audio['num_mels'] + # check normalization ranges + if self.ap.symmetric_norm: + assert mel_input.max() <= self.ap.max_norm + assert mel_input.min() >= -self.ap.max_norm + assert mel_input.min() < 0 + else: + assert mel_input.max() <= self.ap.max_norm + assert mel_input.min() >= 0 def test_batch_group_shuffle(self): if ok_ljspeech: - dataset = LJSpeech.MyDataset( - os.path.join(c.data_path_LJSpeech), - os.path.join(c.data_path_LJSpeech, 'metadata.csv'), - c.r, - c.text_cleaner, - ap=self.ap, - batch_group_size=16, - min_seq_len=c.min_seq_len) + dataloader, dataset = self._create_dataloader(2, c.r, 16) + last_length = 0 + frames = dataset.items + for i, data in enumerate(dataloader): + if i == self.max_loader_iter: + break + text_input = data[0] + text_lengths = data[1] + linear_input = data[2] + mel_input = data[3] + mel_lengths = data[4] + stop_target = data[5] + item_idx = data[6] - dataloader = DataLoader( - dataset, - batch_size=2, - shuffle=True, - collate_fn=dataset.collate_fn, - drop_last=True, - num_workers=c.num_loader_workers) + avg_length = mel_lengths.numpy().mean() + assert avg_length >= last_length + dataloader.dataset.sort_items() + assert frames[0] != dataloader.dataset.items[0] - frames = dataset.frames + def test_padding_and_spec(self): + if ok_ljspeech: + dataloader, dataset = self._create_dataloader(1, 1, 0) + + for i, data in enumerate(dataloader): + if i == self.max_loader_iter: + break + text_input = data[0] + text_lengths = data[1] + linear_input = data[2] + mel_input = data[3] + mel_lengths = data[4] + stop_target = data[5] + item_idx = data[6] + + # check mel_spec consistency + wav = self.ap.load_wav(item_idx[0]) + mel = self.ap.melspectrogram(wav) + mel_dl = mel_input[0].cpu().numpy() + assert ( + abs(mel.T).astype("float32") - abs(mel_dl[:-1])).sum() == 0 + + # check mel-spec correctness + mel_spec = mel_input[0].cpu().numpy() + wav = self.ap.inv_mel_spectrogram(mel_spec.T) + self.ap.save_wav(wav, OUTPATH + '/mel_inv_dataloader.wav') + 
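            # For reference, how collate_fn builds the stop targets asserted
            # further below (a sketch; pad_len is illustrative): each target
            # starts as zeros of length mel_len - 1 and prepare_stop_target
            # right-pads it with ones, e.g.
            #     t = np.zeros(mel_len - 1)
            #     t = np.pad(t, (0, pad_len), constant_values=1.0)
            # which is why the last step must be 1 and the one before it 0.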
shutil.copy(item_idx[0], OUTPATH + '/mel_target_dataloader.wav') + + # check linear-spec + linear_spec = linear_input[0].cpu().numpy() + wav = self.ap.inv_spectrogram(linear_spec.T) + self.ap.save_wav(wav, OUTPATH + '/linear_inv_dataloader.wav') + shutil.copy(item_idx[0], OUTPATH + '/linear_target_dataloader.wav') + + # check the last time step to be zero padded + assert linear_input[0, -1].sum() == 0 + assert linear_input[0, -2].sum() != 0 + assert mel_input[0, -1].sum() == 0 + assert mel_input[0, -2].sum() != 0 + assert stop_target[0, -1] == 1 + assert stop_target[0, -2] == 0 + assert stop_target.sum() == 1 + assert len(mel_lengths.shape) == 1 + assert mel_lengths[0] == linear_input[0].shape[0] + assert mel_lengths[0] == mel_input[0].shape[0] + + # Test for batch size 2 + dataloader, dataset = self._create_dataloader(2, 1, 0) + + for i, data in enumerate(dataloader): + if i == self.max_loader_iter: + break + text_input = data[0] + text_lengths = data[1] + linear_input = data[2] + mel_input = data[3] + mel_lengths = data[4] + stop_target = data[5] + item_idx = data[6] + + if mel_lengths[0] > mel_lengths[1]: + idx = 0 + else: + idx = 1 + + # check the first item in the batch + assert linear_input[idx, -1].sum() == 0 + assert linear_input[idx, -2].sum() != 0, linear_input + assert mel_input[idx, -1].sum() == 0 + assert mel_input[idx, -2].sum() != 0, mel_input + assert stop_target[idx, -1] == 1 + assert stop_target[idx, -2] == 0 + assert stop_target[idx].sum() == 1 + assert len(mel_lengths.shape) == 1 + assert mel_lengths[idx] == mel_input[idx].shape[0] + assert mel_lengths[idx] == linear_input[idx].shape[0] + + # check the second itme in the batch + assert linear_input[1 - idx, -1].sum() == 0 + assert mel_input[1 - idx, -1].sum() == 0 + assert stop_target[1 - idx, -1] == 1 + assert len(mel_lengths.shape) == 1 + + # check batch conditions + assert (linear_input * stop_target.unsqueeze(2)).sum() == 0 + assert (mel_input * stop_target.unsqueeze(2)).sum() == 0 + + +class TestTTSDatasetCached(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(TestTTSDatasetCached, self).__init__(*args, **kwargs) + self.max_loader_iter = 4 + self.c = load_config(os.path.join(c.data_path_cache, 'config.json')) + self.ap = AudioProcessor(**self.c.audio) + + def _create_dataloader(self, batch_size, r, bgs): + + dataset = TTSDatasetCached.MyDataset( + c.data_path_cache, + 'tts_metadata.csv', + r, + c.text_cleaner, + preprocessor=tts_cache, + ap=self.ap, + batch_group_size=bgs, + min_seq_len=c.min_seq_len) + + dataloader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + collate_fn=dataset.collate_fn, + drop_last=True, + num_workers=c.num_loader_workers) + return dataloader, dataset + + def test_loader(self): + if ok_ljspeech: + dataloader, dataset = self._create_dataloader(2, c.r, 0) for i, data in enumerate(dataloader): if i == self.max_loader_iter: break @@ -102,32 +230,21 @@ class TestLJSpeechDataset(unittest.TestCase): assert check_count == 0, \ " !! 
Negative values in text_input: {}".format(check_count) # TODO: more assertion here - assert linear_input.shape[0] == c.batch_size assert mel_input.shape[0] == c.batch_size - assert mel_input.shape[2] == c.num_mels - dataloader.dataset.sort_frames() - assert frames[0] != dataloader.dataset.frames[0] + assert mel_input.shape[2] == c.audio['num_mels'] + if self.ap.symmetric_norm: + assert mel_input.max() <= self.ap.max_norm + assert mel_input.min() >= -self.ap.max_norm + assert mel_input.min() < 0 + else: + assert mel_input.max() <= self.ap.max_norm + assert mel_input.min() >= 0 - def test_padding(self): + def test_batch_group_shuffle(self): if ok_ljspeech: - dataset = LJSpeech.MyDataset( - os.path.join(c.data_path_LJSpeech), - os.path.join(c.data_path_LJSpeech, 'metadata.csv'), - 1, - c.text_cleaner, - ap=self.ap, - min_seq_len=c.min_seq_len) - - # Test for batch size 1 - dataloader = DataLoader( - dataset, - batch_size=1, - shuffle=False, - collate_fn=dataset.collate_fn, - drop_last=True, - num_workers=c.num_loader_workers) - + dataloader, dataset = self._create_dataloader(2, c.r, 16) + frames = dataset.items for i, data in enumerate(dataloader): if i == self.max_loader_iter: break @@ -139,11 +256,51 @@ class TestLJSpeechDataset(unittest.TestCase): stop_target = data[5] item_idx = data[6] + neg_values = text_input[text_input < 0] + check_count = len(neg_values) + assert check_count == 0, \ + " !! Negative values in text_input: {}".format(check_count) + # TODO: more assertion here + assert mel_input.shape[0] == c.batch_size + assert mel_input.shape[2] == c.audio['num_mels'] + dataloader.dataset.sort_items() + assert frames[0] != dataloader.dataset.items[0] + + def test_padding_and_spec(self): + if ok_ljspeech: + dataloader, dataset = self._create_dataloader(1, 1, 0) + for i, data in enumerate(dataloader): + if i == self.max_loader_iter: + break + text_input = data[0] + text_lengths = data[1] + linear_input = data[2] + mel_input = data[3] + mel_lengths = data[4] + stop_target = data[5] + item_idx = data[6] + + # check mel_spec consistency + if item_idx[0].split('.')[-1] == 'npy': + wav = np.load(item_idx[0]) + else: + wav = self.ap.load_wav(item_idx[0]) + mel = self.ap.melspectrogram(wav) + mel_dl = mel_input[0].cpu().numpy() + assert (abs(mel.T).astype("float32") - abs( + mel_dl[:-1])).sum() == 0, ( + abs(mel.T).astype("float32") - abs(mel_dl[:-1])).sum() + + # check mel-spec correctness + mel_spec = mel_input[0].cpu().numpy() + wav = self.ap.inv_mel_spectrogram(mel_spec.T) + self.ap.save_wav(wav, + OUTPATH + '/mel_inv_dataloader_cache.wav') + shutil.copy(item_idx[0], OUTPATH + '/mel_target_dataloader_cache.wav') + # check the last time step to be zero padded assert mel_input[0, -1].sum() == 0 assert mel_input[0, -2].sum() != 0 - assert linear_input[0, -1].sum() == 0 - assert linear_input[0, -2].sum() != 0 assert stop_target[0, -1] == 1 assert stop_target[0, -2] == 0 assert stop_target.sum() == 1 @@ -151,14 +308,7 @@ class TestLJSpeechDataset(unittest.TestCase): assert mel_lengths[0] == mel_input[0].shape[0] # Test for batch size 2 - dataloader = DataLoader( - dataset, - batch_size=2, - shuffle=False, - collate_fn=dataset.collate_fn, - drop_last=False, - num_workers=c.num_loader_workers) - + dataloader, dataset = self._create_dataloader(2, 1, 0) for i, data in enumerate(dataloader): if i == self.max_loader_iter: break @@ -178,8 +328,6 @@ class TestLJSpeechDataset(unittest.TestCase): # check the first item in the batch assert mel_input[idx, -1].sum() == 0 assert mel_input[idx, -2].sum() != 0, 
mel_input - assert linear_input[idx, -1].sum() == 0 - assert linear_input[idx, -2].sum() != 0 assert stop_target[idx, -1] == 1 assert stop_target[idx, -2] == 0 assert stop_target[idx].sum() == 1 @@ -188,151 +336,202 @@ class TestLJSpeechDataset(unittest.TestCase): # check the second itme in the batch assert mel_input[1 - idx, -1].sum() == 0 - assert linear_input[1 - idx, -1].sum() == 0 assert stop_target[1 - idx, -1] == 1 assert len(mel_lengths.shape) == 1 # check batch conditions assert (mel_input * stop_target.unsqueeze(2)).sum() == 0 - assert (linear_input * stop_target.unsqueeze(2)).sum() == 0 -class TestKusalDataset(unittest.TestCase): - def __init__(self, *args, **kwargs): - super(TestKusalDataset, self).__init__(*args, **kwargs) - self.max_loader_iter = 4 - self.ap = AudioProcessor( - sample_rate=c.sample_rate, - num_mels=c.num_mels, - min_level_db=c.min_level_db, - frame_shift_ms=c.frame_shift_ms, - frame_length_ms=c.frame_length_ms, - ref_level_db=c.ref_level_db, - num_freq=c.num_freq, - power=c.power, - preemphasis=c.preemphasis) +# class TestTTSDatasetMemory(unittest.TestCase): +# def __init__(self, *args, **kwargs): +# super(TestTTSDatasetMemory, self).__init__(*args, **kwargs) +# self.max_loader_iter = 4 +# self.c = load_config(os.path.join(c.data_path_cache, 'config.json')) +# self.ap = AudioProcessor(**c.audio) - def test_loader(self): - if ok_kusal: - dataset = Kusal.MyDataset( - os.path.join(c.data_path_Kusal), - os.path.join(c.data_path_Kusal, 'prompts.txt'), - c.r, - c.text_cleaner, - ap=self.ap, - min_seq_len=c.min_seq_len) +# def test_loader(self): +# if ok_ljspeech: +# dataset = TTSDatasetMemory.MyDataset( +# c.data_path_cache, +# 'tts_metadata.csv', +# c.r, +# c.text_cleaner, +# preprocessor=tts_cache, +# ap=self.ap, +# min_seq_len=c.min_seq_len) - dataloader = DataLoader( - dataset, - batch_size=2, - shuffle=True, - collate_fn=dataset.collate_fn, - drop_last=True, - num_workers=c.num_loader_workers) +# dataloader = DataLoader( +# dataset, +# batch_size=2, +# shuffle=True, +# collate_fn=dataset.collate_fn, +# drop_last=True, +# num_workers=c.num_loader_workers) - for i, data in enumerate(dataloader): - if i == self.max_loader_iter: - break - text_input = data[0] - text_lengths = data[1] - linear_input = data[2] - mel_input = data[3] - mel_lengths = data[4] - stop_target = data[5] - item_idx = data[6] +# for i, data in enumerate(dataloader): +# if i == self.max_loader_iter: +# break +# text_input = data[0] +# text_lengths = data[1] +# linear_input = data[2] +# mel_input = data[3] +# mel_lengths = data[4] +# stop_target = data[5] +# item_idx = data[6] - neg_values = text_input[text_input < 0] - check_count = len(neg_values) - assert check_count == 0, \ - " !! Negative values in text_input: {}".format(check_count) - # TODO: more assertion here - assert linear_input.shape[0] == c.batch_size - assert mel_input.shape[0] == c.batch_size - assert mel_input.shape[2] == c.num_mels +# neg_values = text_input[text_input < 0] +# check_count = len(neg_values) +# assert check_count == 0, \ +# " !! 
Negative values in text_input: {}".format(check_count) +# # check mel-spec shape +# assert mel_input.shape[0] == c.batch_size +# assert mel_input.shape[2] == c.audio['num_mels'] +# assert mel_input.max() <= self.ap.max_norm +# # check data range +# if self.ap.symmetric_norm: +# assert mel_input.max() <= self.ap.max_norm +# assert mel_input.min() >= -self.ap.max_norm +# assert mel_input.min() < 0 +# else: +# assert mel_input.max() <= self.ap.max_norm +# assert mel_input.min() >= 0 - def test_padding(self): - if ok_kusal: - dataset = Kusal.MyDataset( - os.path.join(c.data_path_Kusal), - os.path.join(c.data_path_Kusal, 'prompts.txt'), - 1, - c.text_cleaner, - ap=self.ap, - min_seq_len=c.min_seq_len) +# def test_batch_group_shuffle(self): +# if ok_ljspeech: +# dataset = TTSDatasetMemory.MyDataset( +# c.data_path_cache, +# 'tts_metadata.csv', +# c.r, +# c.text_cleaner, +# preprocessor=ljspeech, +# ap=self.ap, +# batch_group_size=16, +# min_seq_len=c.min_seq_len) - # Test for batch size 1 - dataloader = DataLoader( - dataset, - batch_size=1, - shuffle=False, - collate_fn=dataset.collate_fn, - drop_last=True, - num_workers=c.num_loader_workers) +# dataloader = DataLoader( +# dataset, +# batch_size=2, +# shuffle=True, +# collate_fn=dataset.collate_fn, +# drop_last=True, +# num_workers=c.num_loader_workers) - for i, data in enumerate(dataloader): - if i == self.max_loader_iter: - break - text_input = data[0] - text_lengths = data[1] - linear_input = data[2] - mel_input = data[3] - mel_lengths = data[4] - stop_target = data[5] - item_idx = data[6] +# frames = dataset.items +# for i, data in enumerate(dataloader): +# if i == self.max_loader_iter: +# break +# text_input = data[0] +# text_lengths = data[1] +# linear_input = data[2] +# mel_input = data[3] +# mel_lengths = data[4] +# stop_target = data[5] +# item_idx = data[6] - # check the last time step to be zero padded - assert mel_input[0, -1].sum() == 0 - # assert mel_input[0, -2].sum() != 0 - assert linear_input[0, -1].sum() == 0 - # assert linear_input[0, -2].sum() != 0 - assert stop_target[0, -1] == 1 - assert stop_target[0, -2] == 0 - assert stop_target.sum() == 1 - assert len(mel_lengths.shape) == 1 - assert mel_lengths[0] == mel_input[0].shape[0] +# neg_values = text_input[text_input < 0] +# check_count = len(neg_values) +# assert check_count == 0, \ +# " !! 
Negative values in text_input: {}".format(check_count)
+#                 assert mel_input.shape[0] == c.batch_size
+#                 assert mel_input.shape[2] == c.audio['num_mels']
+#             dataloader.dataset.sort_items()
+#             assert frames[0] != dataloader.dataset.items[0]
 
-            # Test for batch size 2
-            dataloader = DataLoader(
-                dataset,
-                batch_size=2,
-                shuffle=False,
-                collate_fn=dataset.collate_fn,
-                drop_last=False,
-                num_workers=c.num_loader_workers)
+#     def test_padding_and_spec(self):
+#         if ok_ljspeech:
+#             dataset = TTSDatasetMemory.MyDataset(
+#                 c.data_path_cache,
+#                 'tts_metadata.csv',
+#                 1,
+#                 c.text_cleaner,
+#                 preprocessor=ljspeech,
+#                 ap=self.ap,
+#                 min_seq_len=c.min_seq_len)
 
-            for i, data in enumerate(dataloader):
-                if i == self.max_loader_iter:
-                    break
-                text_input = data[0]
-                text_lengths = data[1]
-                linear_input = data[2]
-                mel_input = data[3]
-                mel_lengths = data[4]
-                stop_target = data[5]
-                item_idx = data[6]
+#             # Test for batch size 1
+#             dataloader = DataLoader(
+#                 dataset,
+#                 batch_size=1,
+#                 shuffle=False,
+#                 collate_fn=dataset.collate_fn,
+#                 drop_last=True,
+#                 num_workers=c.num_loader_workers)
 
-            if mel_lengths[0] > mel_lengths[1]:
-                idx = 0
-            else:
-                idx = 1
+#             for i, data in enumerate(dataloader):
+#                 if i == self.max_loader_iter:
+#                     break
+#                 text_input = data[0]
+#                 text_lengths = data[1]
+#                 linear_input = data[2]
+#                 mel_input = data[3]
+#                 mel_lengths = data[4]
+#                 stop_target = data[5]
+#                 item_idx = data[6]
 
-            # check the first item in the batch
-            assert mel_input[idx, -1].sum() == 0
-            assert mel_input[idx, -2].sum() != 0, mel_input
-            assert linear_input[idx, -1].sum() == 0
-            assert linear_input[idx, -2].sum() != 0
-            assert stop_target[idx, -1] == 1
-            assert stop_target[idx, -2] == 0
-            assert stop_target[idx].sum() == 1
-            assert len(mel_lengths.shape) == 1
-            assert mel_lengths[idx] == mel_input[idx].shape[0]
+#                 # check mel_spec consistency
+#                 if item_idx[0].split('.')[-1] == 'npy':
+#                     wav = np.load(item_idx[0])
+#                 else:
+#                     wav = self.ap.load_wav(item_idx[0])
+#                 mel = self.ap.melspectrogram(wav)
+#                 mel_dl = mel_input[0].cpu().numpy()
+#                 assert (
+#                     abs(mel.T).astype("float32") - abs(mel_dl[:-1])).sum() == 0
 
-            # check the second itme in the batch
-            assert mel_input[1 - idx, -1].sum() == 0
-            assert linear_input[1 - idx, -1].sum() == 0
-            assert stop_target[1 - idx, -1] == 1
-            assert len(mel_lengths.shape) == 1
+#                 # check mel-spec correctness
+#                 mel_spec = mel_input[0].cpu().numpy()
+#                 wav = self.ap.inv_mel_spectrogram(mel_spec.T)
+#                 self.ap.save_wav(wav, OUTPATH + '/mel_inv_dataloader_memo.wav')
+#                 shutil.copy(item_idx[0], OUTPATH + '/mel_target_dataloader_memo.wav')
 
-            # check batch conditions
-            assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
-            assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
+#                 # check the last time step to be zero padded
+#                 assert mel_input[0, -1].sum() == 0
+#                 assert mel_input[0, -2].sum() != 0
+#                 assert stop_target[0, -1] == 1
+#                 assert stop_target[0, -2] == 0
+#                 assert stop_target.sum() == 1
+#                 assert len(mel_lengths.shape) == 1
+#                 assert mel_lengths[0] == mel_input[0].shape[0]
+
+#             # Test for batch size 2
+#             dataloader = DataLoader(
+#                 dataset,
+#                 batch_size=2,
+#                 shuffle=False,
+#                 collate_fn=dataset.collate_fn,
+#                 drop_last=False,
+#                 num_workers=c.num_loader_workers)
+
+#             for i, data in enumerate(dataloader):
+#                 if i == self.max_loader_iter:
+#                     break
+#                 text_input = data[0]
+#                 text_lengths = data[1]
+#                 linear_input = data[2]
+#                 mel_input = data[3]
+#                 mel_lengths = data[4]
+#                 stop_target = data[5]
+#                 item_idx = data[6]
+
+#                 if mel_lengths[0] > mel_lengths[1]:
+#                     idx 
= 0
+#                 else:
+#                     idx = 1
+
+#                 # check the first item in the batch
+#                 assert mel_input[idx, -1].sum() == 0
+#                 assert mel_input[idx, -2].sum() != 0, mel_input
+#                 assert stop_target[idx, -1] == 1
+#                 assert stop_target[idx, -2] == 0
+#                 assert stop_target[idx].sum() == 1
+#                 assert len(mel_lengths.shape) == 1
+#                 assert mel_lengths[idx] == mel_input[idx].shape[0]
+
+#                 # check the second item in the batch
+#                 assert mel_input[1 - idx, -1].sum() == 0
+#                 assert stop_target[1 - idx, -1] == 1
+#                 assert len(mel_lengths.shape) == 1
+
+#                 # check batch conditions
+#                 assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
diff --git a/tests/tacotron_tests.py b/tests/tacotron_tests.py
index 37f36fa6..0e872a07 100644
--- a/tests/tacotron_tests.py
+++ b/tests/tacotron_tests.py
@@ -21,8 +21,8 @@ c = load_config(os.path.join(file_path, 'test_config.json'))
 class TacotronTrainTest(unittest.TestCase):
     def test_train_step(self):
         input = torch.randint(0, 24, (8, 128)).long().to(device)
-        mel_spec = torch.rand(8, 30, c.num_mels).to(device)
-        linear_spec = torch.rand(8, 30, c.num_freq).to(device)
+        mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
+        linear_spec = torch.rand(8, 30, c.audio['num_freq']).to(device)
         mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
         stop_targets = torch.zeros(8, 30, 1).float().to(device)
@@ -35,7 +35,7 @@ class TacotronTrainTest(unittest.TestCase):
         criterion = L1LossMasked().to(device)
         criterion_st = nn.BCELoss().to(device)
-        model = Tacotron(c.embedding_size, c.num_freq, c.num_mels,
+        model = Tacotron(c.embedding_size, c.audio['num_freq'], c.audio['num_mels'],
                          c.r).to(device)
         model.train()
         model_ref = copy.deepcopy(model)
diff --git a/tests/test_config.json b/tests/test_config.json
index af0d070d..05a8137d 100644
--- a/tests/test_config.json
+++ b/tests/test_config.json
@@ -1,16 +1,25 @@
 {
-    "num_mels": 80,
-    "num_freq": 1025,
-    "sample_rate": 22050,
-    "frame_length_ms": 50,
-    "frame_shift_ms": 12.5,
-    "preemphasis": 0.97,
-    "min_level_db": -100,
-    "ref_level_db": 20,
+    "audio": {
+        "audio_processor": "audio",  // selects the audio processor to use, if more than one is available.
+        "num_mels": 80,              // size of the mel spec frame.
+        "num_freq": 1025,            // number of STFT frequency levels, i.e. the size of the linear spectrogram frame.
+        "sample_rate": 22050,        // wav sample-rate. If different than the original data, it is resampled.
+        "frame_length_ms": 50,       // STFT window length in ms.
+        "frame_shift_ms": 12.5,      // STFT window hop-length in ms.
+        "preemphasis": 0.97,         // pre-emphasis to reduce spec noise and make it more structured. If 0.0, pre-emphasis is disabled.
+        "min_level_db": -100,        // normalization range
+        "ref_level_db": 20,          // reference level dB; theoretically 20 dB is the sound of air.
+        "power": 1.5,                // value to sharpen wav signals after GL algorithm.
+        "griffin_lim_iters": 30,     // Griffin-Lim iterations. 30-60 is a good range. The larger the value, the slower the generation.
+        "signal_norm": true,         // normalize the spec values into the range [0, 1]
+        "symmetric_norm": true,      // move normalization to the range [-1, 1]
+        "clip_norm": true,           // clip normalized values into the range.
+        "max_norm": 4,               // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
+        "mel_fmin": 95,              // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
+        "mel_fmax": 7600             // maximum freq level for mel-spec. Tune for dataset!!
+    },
     "hidden_size": 128,
     "embedding_size": 256,
-    "min_mel_freq": null,
-    "max_mel_freq": null,
     "text_cleaner": "english_cleaners",
 
     "epochs": 2000,
@@ -21,16 +30,11 @@
     "r": 5,
     "mk": 1.0,
     "priority_freq": false,
-
-
-    "griffin_lim_iters": 60,
-    "power": 1.5,
-
     "num_loader_workers": 4,
 
     "save_step": 200,
-    "data_path_LJSpeech": "/home/erogol/Data/LJSpeech-1.1",
-    "data_path_Kusal": "/home/erogol/Data/Kusal",
+    "data_path": "/home/erogol/Data/LJSpeech-1.1/",
+    "data_path_cache": "/home/erogol/Data/LJSpeech-1.1/tts_cache/",
     "output_path": "result",
     "min_seq_len": 0,
     "log_dir": "/home/erogol/projects/TTS/logs/"
diff --git a/train.py b/train.py
index 327fd147..bbe55515 100644
--- a/train.py
+++ b/train.py
@@ -14,16 +14,16 @@ from torch.utils.data import DataLoader
 from tensorboardX import SummaryWriter
 
 from utils.generic_utils import (
-    synthesis, remove_experiment_folder, create_experiment_folder,
+    remove_experiment_folder, create_experiment_folder,
     save_checkpoint, save_best_model, load_config, lr_decay,
     count_parameters, check_update, get_commit_hash, sequence_mask, AnnealLR)
 from utils.visual import plot_alignment, plot_spectrogram
 from models.tacotron import Tacotron
 from layers.losses import L1LossMasked
 from utils.audio import AudioProcessor
+from utils.synthesis import synthesis
 
 torch.manual_seed(1)
-# torch.set_num_threads(4)
 use_cuda = torch.cuda.is_available()
@@ -36,7 +36,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
     avg_stop_loss = 0
     avg_step_time = 0
     print(" | > Epoch {}/{}".format(epoch, c.epochs), flush=True)
-    n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
+    n_priority_freq = int(3000 / (c.audio['sample_rate'] * 0.5) * c.audio['num_freq'])
     batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
@@ -215,7 +215,7 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
         "I'm sorry Dave. I'm afraid I can't do that.",
         "This cake is great. It's so delicious and moist."
    ]
-    n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
+    n_priority_freq = int(3000 / (c.audio['sample_rate'] * 0.5) * c.audio['num_freq'])
     with torch.no_grad():
         if data_loader is not None:
             for num_iter, data in enumerate(data_loader):
@@ -277,6 +277,7 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
             const_spec = linear_output[idx].data.cpu().numpy()
             gt_spec = linear_input[idx].data.cpu().numpy()
             align_img = alignments[idx].data.cpu().numpy()
+
             const_spec = plot_spectrogram(const_spec, ap)
             gt_spec = plot_spectrogram(gt_spec, ap)
             align_img = plot_alignment(align_img)
@@ -319,8 +320,8 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
     ap.griffin_lim_iters = 60
     for idx, test_sentence in enumerate(test_sentences):
         try:
-            wav, linear_spec, alignments = synthesis(model, ap, test_sentence,
-                                                     use_cuda, c.text_cleaner)
+            wav, alignment, linear_spec, stop_tokens = synthesis(model, test_sentence, c,
+                                                                 use_cuda, ap)
             file_path = os.path.join(AUDIO_PATH, str(current_step))
             os.makedirs(file_path, exist_ok=True)
@@ -330,7 +331,7 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
             wav_name = 'TestSentences/{}'.format(idx)
             tb.add_audio(
-                wav_name, wav, current_step, sample_rate=c.sample_rate)
+                wav_name, wav, current_step, sample_rate=c.audio['sample_rate'])
-            align_img = alignments[0].data.cpu().numpy()
+            align_img = alignment  # synthesis() already returns the test sentence's alignment as numpy
             linear_spec = plot_spectrogram(linear_spec, ap)
             align_img = plot_alignment(align_img)
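Note on the `n_priority_freq` expression used in the two hunks above: the `num_freq` linear-spectrogram bins span 0 Hz to `sample_rate / 2`, so the expression counts how many bins fall below 3 kHz. A quick standalone check with the sample rate and FFT size from the config files in this patch:

```python
# Bins below 3000 Hz, out of num_freq bins covering [0, sample_rate / 2].
sample_rate = 22050
num_freq = 1025
n_priority_freq = int(3000 / (sample_rate * 0.5) * num_freq)
print(n_priority_freq)  # -> 278: the priority loss emphasizes the lowest 278 bins
```
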
@@ -345,28 +346,22 @@
 
 def main(args):
-    dataset = importlib.import_module('datasets.' + c.dataset)
-    Dataset = getattr(dataset, 'MyDataset')
-    audio = importlib.import_module('utils.' + c.audio_processor)
+    preprocessor = importlib.import_module('datasets.preprocess')
+    preprocessor = getattr(preprocessor, c.dataset.lower())
+    MyDataset = importlib.import_module('datasets.' + c.data_loader)
+    MyDataset = getattr(MyDataset, "MyDataset")
+    audio = importlib.import_module('utils.' + c.audio['audio_processor'])
     AudioProcessor = getattr(audio, 'AudioProcessor')
-    ap = AudioProcessor(
-        sample_rate=c.sample_rate,
-        num_mels=c.num_mels,
-        min_level_db=c.min_level_db,
-        frame_shift_ms=c.frame_shift_ms,
-        frame_length_ms=c.frame_length_ms,
-        ref_level_db=c.ref_level_db,
-        num_freq=c.num_freq,
-        power=c.power,
-        preemphasis=c.preemphasis)
+    ap = AudioProcessor(**c.audio)
 
     # Setup the dataset
-    train_dataset = Dataset(
+    train_dataset = MyDataset(
         c.data_path,
         c.meta_file_train,
         c.r,
         c.text_cleaner,
+        preprocessor=preprocessor,
         ap=ap,
         batch_group_size=8*c.batch_size,
         min_seq_len=c.min_seq_len)
@@ -381,8 +376,8 @@ def main(args):
         pin_memory=True)
 
     if c.run_eval:
-        val_dataset = Dataset(
-            c.data_path, c.meta_file_val, c.r, c.text_cleaner, ap=ap, batch_group_size=0)
+        val_dataset = MyDataset(
+            c.data_path, c.meta_file_val, c.r, c.text_cleaner, preprocessor=preprocessor, ap=ap, batch_group_size=0)
 
         val_loader = DataLoader(
             val_dataset,
@@ -395,7 +390,7 @@ def main(args):
     else:
         val_loader = None
 
-    model = Tacotron(c.embedding_size, ap.num_freq, c.num_mels, c.r)
+    model = Tacotron(c.embedding_size, ap.num_freq, ap.num_mels, c.r)
     print(" | > Num output units : {}".format(ap.num_freq), flush=True)
 
     optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0)
@@ -431,7 +426,7 @@ def main(args):
         criterion.cuda()
         criterion_st.cuda()
 
-    scheduler = AnnealLR(optimizer, warmup_steps=c.warmup_steps, last_epoch= (args.restore_step-1))
+    scheduler = AnnealLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1)
 
     num_params = count_parameters(model)
     print(" | > Model has {} parameters".format(num_params), flush=True)
@@ -454,7 +449,7 @@ def main(args):
             best_loss = save_best_model(model, optimizer, train_loss, best_loss,
                                         OUT_PATH, current_step, epoch)
         # shuffle batch groups
-        train_loader.dataset.sort_frames()
+        train_loader.dataset.sort_items()
 
 
 if __name__ == '__main__':
@@ -477,8 +472,9 @@ if __name__ == '__main__':
     parser.add_argument(
         '--data_path',
         type=str,
-        default='',
-        help='data path to overrite config.json')
+        help='dataset path; overrides data_path in config.json.',
+        default=''
+    )
     args = parser.parse_args()
 
     # setup output paths and read configs
@@ -491,7 +487,7 @@ if __name__ == '__main__':
     os.makedirs(AUDIO_PATH, exist_ok=True)
     shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
-    if args.data_path != "":
+    if args.data_path != '':
         c.data_path = args.data_path
 
     # setup tensorboard
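The `ap = AudioProcessor(**c.audio)` call above only works because the reworked `__init__` in `utils/audio.py` below grew a `**kwargs` catch-all: the config's `audio` section also carries keys, such as `audio_processor`, that are not constructor parameters. A minimal sketch using the values from `tests/test_config.json`:

```python
from utils.audio import AudioProcessor

# "audio" section as parsed from tests/test_config.json (comments already stripped)
audio_cfg = {
    "audio_processor": "audio",  # not an __init__ parameter; absorbed by **kwargs
    "num_mels": 80, "num_freq": 1025, "sample_rate": 22050,
    "frame_length_ms": 50, "frame_shift_ms": 12.5, "preemphasis": 0.97,
    "min_level_db": -100, "ref_level_db": 20, "power": 1.5,
    "griffin_lim_iters": 30, "signal_norm": True, "symmetric_norm": True,
    "clip_norm": True, "max_norm": 4, "mel_fmin": 95, "mel_fmax": 7600,
}
ap = AudioProcessor(**audio_cfg)  # would raise TypeError without the **kwargs catch-all
```
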
diff --git a/utils/audio.py b/utils/audio.py
index 061fefc0..a82eaeba 100644
--- a/utils/audio.py
+++ b/utils/audio.py
@@ -3,26 +3,34 @@ import librosa
 import pickle
 import copy
 import numpy as np
-import scipy
-from scipy import signal
-
-_mel_basis = None
+from pprint import pprint
+from scipy import signal, io
 
 
 class AudioProcessor(object):
     def __init__(self,
-                 sample_rate,
-                 num_mels,
-                 min_level_db,
-                 frame_shift_ms,
-                 frame_length_ms,
-                 ref_level_db,
-                 num_freq,
-                 power,
-                 preemphasis,
-                 griffin_lim_iters=None):
+                 bits=None,
+                 sample_rate=None,
+                 num_mels=None,
+                 min_level_db=None,
+                 frame_shift_ms=None,
+                 frame_length_ms=None,
+                 ref_level_db=None,
+                 num_freq=None,
+                 power=None,
+                 preemphasis=None,
+                 signal_norm=None,
+                 symmetric_norm=None,
+                 max_norm=None,
+                 mel_fmin=None,
+                 mel_fmax=None,
+                 clip_norm=True,
+                 griffin_lim_iters=None,
+                 **kwargs):
         print(" > Setting up Audio Processor...")
+
+        self.bits = bits
         self.sample_rate = sample_rate
         self.num_mels = num_mels
         self.min_level_db = min_level_db
@@ -33,33 +41,79 @@ class AudioProcessor(object):
         self.power = power
         self.preemphasis = preemphasis
         self.griffin_lim_iters = griffin_lim_iters
+        self.signal_norm = signal_norm
+        self.symmetric_norm = symmetric_norm
+        self.mel_fmin = 0 if mel_fmin is None else mel_fmin
+        self.mel_fmax = mel_fmax
+        self.max_norm = 1.0 if max_norm is None else float(max_norm)
+        self.clip_norm = clip_norm
         self.n_fft, self.hop_length, self.win_length = self._stft_parameters()
         if preemphasis == 0:
-            print(" | > Preemphasis is deactive.")
+            print(" | > Preemphasis is disabled.")
+        print(" | > Audio Processor attributes.")
+        members = vars(self)
+        pprint(members)
 
     def save_wav(self, wav, path):
         wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
-        # librosa.output.write_wav(path, wav_norm.astype(np.int16), self.sample_rate)
-        scipy.io.wavfile.write(path, self.sample_rate, wav_norm.astype(np.int16))
+        io.wavfile.write(path, self.sample_rate, wav_norm.astype(np.int16))
 
     def _linear_to_mel(self, spectrogram):
-        global _mel_basis
-        if _mel_basis is None:
-            _mel_basis = self._build_mel_basis()
+        _mel_basis = self._build_mel_basis()
         return np.dot(_mel_basis, spectrogram)
 
+    def _mel_to_linear(self, mel_spec):
+        inv_mel_basis = np.linalg.pinv(self._build_mel_basis())
+        return np.maximum(1e-10, np.dot(inv_mel_basis, mel_spec))
+
     def _build_mel_basis(self, ):
         n_fft = (self.num_freq - 1) * 2
+        if self.mel_fmax is not None:
+            assert self.mel_fmax <= self.sample_rate // 2
         return librosa.filters.mel(
-            self.sample_rate, n_fft, n_mels=self.num_mels)
+            self.sample_rate,
+            n_fft,
+            n_mels=self.num_mels,
+            fmin=self.mel_fmin,
+            fmax=self.mel_fmax)
 
     def _normalize(self, S):
-        return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
+        """Put values in [0, self.max_norm] or [-self.max_norm, self.max_norm]"""
+        if self.signal_norm:
+            S_norm = ((S - self.min_level_db) / -self.min_level_db)
+            if self.symmetric_norm:
+                S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm
+                if self.clip_norm:
+                    S_norm = np.clip(S_norm, -self.max_norm, self.max_norm)
+                return S_norm
+            else:
+                S_norm = self.max_norm * S_norm
+                if self.clip_norm:
+                    S_norm = np.clip(S_norm, 0, self.max_norm)
+                return S_norm
+        else:
+            return S
 
     def _denormalize(self, S):
-        return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db
+        """Denormalize values"""
+        S_denorm = S
+        if self.signal_norm:
+            if self.symmetric_norm:
+                if self.clip_norm:
+                    S_denorm = np.clip(S_denorm, -self.max_norm, self.max_norm)
+                S_denorm = ((S_denorm + self.max_norm) * -self.min_level_db / (2 * self.max_norm)) + self.min_level_db
+                return S_denorm
+            else:
+                if self.clip_norm:
+                    S_denorm = np.clip(S_denorm, 0, self.max_norm)
+                S_denorm = (S_denorm * -self.min_level_db /
+                            self.max_norm) + self.min_level_db
+                return S_denorm
+        else:
+            return S
 
     def _stft_parameters(self, ):
+        """Compute the STFT parameters from the ms-based config values"""
         n_fft = (self.num_freq - 1) * 2
         hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
         win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate)
@@ -92,8 +146,16 @@ class AudioProcessor(object):
         S = self._amp_to_db(np.abs(D)) - self.ref_level_db
         return self._normalize(S)
 
+    def melspectrogram(self, y):
+        if self.preemphasis != 0:
+            D = self._stft(self.apply_preemphasis(y))
+        else:
+            D = self._stft(y)
+        S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
+        return self._normalize(S)
+
     def inv_spectrogram(self, spectrogram):
-        '''Converts spectrogram to waveform using librosa'''
+        """Converts spectrogram to waveform using librosa"""
         S = self._denormalize(spectrogram)
         S = self._db_to_amp(S + self.ref_level_db)  # Convert back to linear
         # Reconstruct phase
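Since `inv_spectrogram` feeds model predictions back through `_denormalize`, the `_normalize`/`_denormalize` pair above must round-trip exactly. A standalone sketch of the symmetric branch (constants match `tests/test_config.json`; this mirrors the class methods rather than replacing them):

```python
import numpy as np

min_level_db = -100.0
max_norm = 4.0

def normalize(S):
    # [min_level_db, 0] dB -> [0, 1] -> [-max_norm, max_norm]
    S_norm = (S - min_level_db) / -min_level_db
    return np.clip(2 * max_norm * S_norm - max_norm, -max_norm, max_norm)

def denormalize(S):
    # [-max_norm, max_norm] -> [min_level_db, 0] dB
    S = np.clip(S, -max_norm, max_norm)
    return (S + max_norm) * -min_level_db / (2 * max_norm) + min_level_db

S = np.random.uniform(min_level_db, 0.0, size=(80, 100))
assert np.allclose(denormalize(normalize(S)), S, atol=1e-6)
```
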
@@ -102,6 +164,16 @@ class AudioProcessor(object):
         else:
             return self._griffin_lim(S**self.power)
 
+    def inv_mel_spectrogram(self, mel_spectrogram):
+        """Converts mel spectrogram to waveform using librosa"""
+        D = self._denormalize(mel_spectrogram)
+        S = self._db_to_amp(D + self.ref_level_db)
+        S = self._mel_to_linear(S)  # Convert back to linear
+        if self.preemphasis != 0:
+            return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
+        else:
+            return self._griffin_lim(S**self.power)
+
     def _griffin_lim(self, S):
         angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
         S_complex = np.abs(S).astype(np.complex)
@@ -111,20 +183,17 @@ class AudioProcessor(object):
             y = self._istft(S_complex * angles)
         return y
 
-    def melspectrogram(self, y):
-        if self.preemphasis != 0:
-            D = self._stft(self.apply_preemphasis(y))
-        else:
-            D = self._stft(y)
-        S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
-        return self._normalize(S)
-
     def _stft(self, y):
         return librosa.stft(
-            y=y, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length)
+            y=y,
+            n_fft=self.n_fft,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+        )
 
     def _istft(self, y):
-        return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length)
+        return librosa.istft(
+            y, hop_length=self.hop_length, win_length=self.win_length)
 
     def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
         window_length = int(self.sample_rate * min_silence_sec)
@@ -134,3 +203,37 @@ class AudioProcessor(object):
             if np.max(wav[x:x + window_length]) < threshold:
                 return x + hop_length
         return len(wav)
+
+    # WaveRNN repo specific functions
+    # def mulaw_encode(self, wav, qc):
+    #     mu = qc - 1
+    #     wav_abs = np.minimum(np.abs(wav), 1.0)
+    #     magnitude = np.log(1 + mu * wav_abs) / np.log(1. + mu)
+    #     signal = np.sign(wav) * magnitude
+    #     # Quantize signal to the specified number of levels.
+    #     signal = (signal + 1) / 2 * mu + 0.5
+    #     return signal.astype(np.int32)
+
+    # def mulaw_decode(self, wav, qc):
+    #     """Recovers waveform from quantized values."""
+    #     mu = qc - 1
+    #     # Map values back to [-1, 1].
+    #     casted = wav.astype(np.float32)
+    #     signal = 2 * (casted / mu) - 1
+    #     # Perform inverse of mu-law transformation.
+    #     magnitude = (1 / mu) * ((1 + mu) ** abs(signal) - 1)
+    #     return np.sign(signal) * magnitude
+
+    def load_wav(self, filename, encode=False):
+        x, sr = librosa.load(filename, sr=self.sample_rate)
+        assert self.sample_rate == sr
+        return x
+
+    def encode_16bits(self, x):
+        return np.clip(x * 2**15, -2**15, 2**15 - 1).astype(np.int16)
+
+    def quantize(self, x):
+        return (x + 1.) * (2**self.bits - 1) / 2
+
+    def dequantize(self, x):
+        return 2 * x / (2**self.bits - 1) - 1
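The new `quantize`/`dequantize` helpers above map audio in [-1, 1] to integer levels in [0, 2**bits - 1] and back, so after rounding, the reconstruction error is bounded by one quantization step. A small sanity check (the bit depth here is an arbitrary example value, not one shipped in this patch):

```python
import numpy as np

bits = 10  # example value; AudioProcessor receives this via its `bits` argument

def quantize(x):
    return (x + 1.) * (2**bits - 1) / 2

def dequantize(x):
    return 2 * x / (2**bits - 1) - 1

x = np.linspace(-1.0, 1.0, num=101)
x_hat = dequantize(np.rint(quantize(x)))                 # round to integer levels
assert np.max(np.abs(x - x_hat)) <= 1.0 / (2**bits - 1)  # off by at most one step
```
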
diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index a1f72a3d..c56c3edf 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -1,4 +1,5 @@
 import os
+import re
 import sys
 import glob
 import time
@@ -21,18 +22,23 @@ class AttrDict(dict):
 
 def load_config(config_path):
     config = AttrDict()
-    config.update(json.load(open(config_path, "r")))
+    with open(config_path, "r") as f:
+        input_str = f.read()
+    input_str = re.sub(r'\\\n', '', input_str)
+    input_str = re.sub(r'//.*\n', '\n', input_str)
+    data = json.loads(input_str)
+    config.update(data)
     return config
 
 
 def get_commit_hash():
     """https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script"""
-    try:
-        subprocess.check_output(['git', 'diff-index', '--quiet',
-                                 'HEAD'])  # Verify client is clean
-    except:
-        raise RuntimeError(
-            " !! Commit before training to get the commit hash.")
+    # try:
+    #     subprocess.check_output(['git', 'diff-index', '--quiet',
+    #                              'HEAD'])  # Verify client is clean
+    # except:
+    #     raise RuntimeError(
+    #         " !! Commit before training to get the commit hash.")
     commit = subprocess.check_output(['git', 'rev-parse', '--short',
                                       'HEAD']).decode().strip()
     print(' > Git Hash: {}'.format(commit))
@@ -177,15 +183,3 @@ def sequence_mask(sequence_length, max_len=None):
     seq_length_expand = (sequence_length.unsqueeze(1)
                          .expand_as(seq_range_expand))
     return seq_range_expand < seq_length_expand
-
-
-def synthesis(model, ap, text, use_cuda, text_cleaner):
-    text_cleaner = [text_cleaner]
-    seq = np.array(text_to_sequence(text, text_cleaner))
-    chars_var = torch.from_numpy(seq).unsqueeze(0)
-    if use_cuda:
-        chars_var = chars_var.cuda().long()
-    _, linear_out, alignments, _ = model.forward(chars_var)
-    linear_out = linear_out[0].data.cpu().numpy()
-    wav = ap.inv_spectrogram(linear_out.T)
-    return wav, linear_out, alignments
diff --git a/utils/synthesis.py b/utils/synthesis.py
new file mode 100644
index 00000000..2531473a
--- /dev/null
+++ b/utils/synthesis.py
@@ -0,0 +1,18 @@
+import torch
+import numpy as np
+from .text import text_to_sequence
+
+
+def synthesis(m, s, CONFIG, use_cuda, ap):
+    """Given the text, synthesize the audio."""
+    text_cleaner = [CONFIG.text_cleaner]
+    seq = np.array(text_to_sequence(s, text_cleaner))
+    chars_var = torch.from_numpy(seq).unsqueeze(0)
+    if use_cuda:
+        chars_var = chars_var.cuda()
+    mel_spec, linear_spec, alignments, stop_tokens = m.forward(chars_var.long())
+    linear_spec = linear_spec[0].data.cpu().numpy()
+    alignment = alignments[0].cpu().data.numpy()
+    wav = ap.inv_spectrogram(linear_spec.T)
+    # wav = wav[:ap.find_endpoint(wav)]
+    return wav, alignment, linear_spec, stop_tokens
\ No newline at end of file
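For reference, the new `utils.synthesis.synthesis` helper is meant to be driven like this (a sketch only: the checkpoint path, its `'model'` key, and the loading steps are placeholder assumptions, not part of this patch):

```python
import torch
from models.tacotron import Tacotron
from utils.audio import AudioProcessor
from utils.generic_utils import load_config
from utils.synthesis import synthesis

use_cuda = torch.cuda.is_available()
c = load_config('config.json')
ap = AudioProcessor(**c.audio)

model = Tacotron(c.embedding_size, ap.num_freq, ap.num_mels, c.r)
checkpoint = torch.load('checkpoint.pth.tar')  # placeholder path
model.load_state_dict(checkpoint['model'])     # assumed checkpoint layout
if use_cuda:
    model = model.cuda()
model.eval()

wav, alignment, linear_spec, stop_tokens = synthesis(
    model, "This cake is great. It's so delicious and moist.", c, use_cuda, ap)
ap.save_wav(wav, 'test_sentence.wav')
```
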
diff --git a/utils/visual.py b/utils/visual.py
index 8545ffe5..e61123d8 100644
--- a/utils/visual.py
+++ b/utils/visual.py
@@ -1,4 +1,5 @@
 import numpy as np
 import matplotlib
 matplotlib.use('Agg')
 import matplotlib.pyplot as plt
+import librosa.display  # plain `import librosa` does not expose specshow; import after the Agg backend is set
@@ -14,6 +15,7 @@ def plot_alignment(alignment, info=None):
         xlabel += '\n\n' + info
     plt.xlabel(xlabel)
     plt.ylabel('Encoder timestep')
+    # plt.yticks(range(len(text)), list(text))
     plt.tight_layout()
     return fig
@@ -25,3 +27,28 @@ def plot_spectrogram(linear_output, audio):
     plt.colorbar()
     plt.tight_layout()
     return fig
+
+
+def visualize(alignment, spectrogram, stop_tokens, text, hop_length, CONFIG):
+    label_fontsize = 16
+    plt.figure(figsize=(16, 32))
+
+    plt.subplot(3, 1, 1)
+    plt.imshow(alignment.T, aspect="auto", origin="lower", interpolation=None)
+    plt.xlabel("Decoder timestep", fontsize=label_fontsize)
+    plt.ylabel("Encoder timestep", fontsize=label_fontsize)
+    plt.yticks(range(len(text)), list(text))
+    plt.colorbar()
+
+    stop_tokens = stop_tokens.squeeze().detach().to('cpu').numpy()
+    plt.subplot(3, 1, 2)
+    plt.plot(range(len(stop_tokens)), list(stop_tokens))
+
+    plt.subplot(3, 1, 3)
+    librosa.display.specshow(spectrogram.T, sr=CONFIG.audio['sample_rate'],
+                             hop_length=hop_length, x_axis="time", y_axis="linear")
+    plt.xlabel("Time", fontsize=label_fontsize)
+    plt.ylabel("Hz", fontsize=label_fontsize)
+    plt.colorbar()
+
+    plt.tight_layout()
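One caveat on the comment-aware `load_config` added in `utils/generic_utils.py` earlier in this patch: `re.sub(r'//.*\n', '\n', ...)` runs with no awareness of JSON strings, so it would also truncate any value that itself contains `//`, such as a URL. None of the shipped configs do, but it is easy to trip over. A minimal demonstration with a hypothetical `log_dir` value:

```python
import re

cfg = '''{
    "data_path": "/home/erogol/Data/LJSpeech-1.1/",  // stripped as intended
    "log_dir": "http://example.com/logs"
}'''

# Same two substitutions, in the same order, as load_config.
stripped = re.sub(r'//.*\n', '\n', re.sub(r'\\\n', '', cfg))
print(stripped)
# The `//` inside the URL matches too, leaving `"log_dir": "http:` behind,
# so json.loads(stripped) raises JSONDecodeError on a config like this one.
```
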