From e15673892026c65fbcea55ca507fa25116147639 Mon Sep 17 00:00:00 2001 From: Eren G Date: Wed, 25 Jul 2018 19:14:07 +0200 Subject: [PATCH] adding cache loader --- .compute | 1 + config.json | 3 +- datasets/LJSpeech.py | 13 ++-- datasets/LJSpeechCached.py | 145 +++++++++++++++++++++++++++++++++++++ extract_feats.py | 88 ++++++++++++++++++++++ train.py | 51 +++++++------ 6 files changed, 271 insertions(+), 30 deletions(-) create mode 100644 datasets/LJSpeechCached.py create mode 100644 extract_feats.py diff --git a/.compute b/.compute index 23b24a62..98303c6c 100644 --- a/.compute +++ b/.compute @@ -3,4 +3,5 @@ # ls /snakepit/jobs/650/keep/ source /snakepit/jobs/650/keep/venv/bin/activate # source /snakepit/jobs/560/tmp/venv/bin/activate +python extract_feats.py --data_path /snakepit/shared/data/keithito/LJSpeech-1.1/wavs --out_path /snakepit/shared/data/keithito/LJSpeech-1.1/loader_data/ --config config.json --num_proc 8 python train.py --config_path config.json --debug true diff --git a/config.json b/config.json index e2b0cb4d..4fdd63aa 100644 --- a/config.json +++ b/config.json @@ -31,7 +31,8 @@ "run_eval": false, "data_path": "/snakepit/shared/data/keithito/LJSpeech-1.1/", "meta_file_train": "metadata.csv", - "meta_file_val": "metadata_val.csv", + "meta_file_val": null, + "dataset": "LJSpeechCached", "min_seq_len": 0, "output_path": "experiments/" } diff --git a/datasets/LJSpeech.py b/datasets/LJSpeech.py index f8872fcb..69aff0e0 100644 --- a/datasets/LJSpeech.py +++ b/datasets/LJSpeech.py @@ -10,14 +10,15 @@ from utils.data import (prepare_data, pad_per_step, prepare_tensor, prepare_stop_target) -class LJSpeechDataset(Dataset): +class MyDataset(Dataset): - def __init__(self, csv_file, root_dir, outputs_per_step, + def __init__(self, root_dir, csv_file, outputs_per_step, text_cleaner, ap, min_seq_len=0): - - with open(csv_file, "r", encoding="utf8") as f: - self.frames = [line.split('|') for line in f] self.root_dir = root_dir + self.wav_dir = os.path.join(root_dir, 'wavs') + self.csv_dir = os.path.join(root_dir, csv_file) + with open(self.csv_dir, "r", encoding="utf8") as f: + self.frames = [line.split('|') for line in f] self.outputs_per_step = outputs_per_step self.sample_rate = ap.sample_rate self.cleaners = text_cleaner @@ -59,7 +60,7 @@ class LJSpeechDataset(Dataset): return len(self.frames) def __getitem__(self, idx): - wav_name = os.path.join(self.root_dir, + wav_name = os.path.join(self.wav_dir, self.frames[idx][0]) + '.wav' text = self.frames[idx][1] text = np.asarray(text_to_sequence( diff --git a/datasets/LJSpeechCached.py b/datasets/LJSpeechCached.py new file mode 100644 index 00000000..c0018493 --- /dev/null +++ b/datasets/LJSpeechCached.py @@ -0,0 +1,145 @@ +import os +import numpy as np +import collections +import librosa +import torch +from torch.utils.data import Dataset + +from utils.text import text_to_sequence +from utils.data import (prepare_data, pad_per_step, + prepare_tensor, prepare_stop_target) + + +class MyDataset(Dataset): + + def __init__(self, root_dir, csv_file, outputs_per_step, + text_cleaner, ap, min_seq_len=0): + self.root_dir = root_dir + self.wav_dir = os.path.join(root_dir, 'wavs') + self.feat_dir = os.path.join(root_dir, 'loader_data') + self.csv_dir = os.path.join(root_dir, csv_file) + with open(self.csv_dir, "r", encoding="utf8") as f: + self.frames = [line.split('|') for line in f] + self.outputs_per_step = outputs_per_step + self.sample_rate = ap.sample_rate + self.cleaners = text_cleaner + self.min_seq_len = min_seq_len + self.items = [None] * 
len(self.frames)
+        print(" > Reading LJSpeech from - {}".format(root_dir))
+        print(" | > Number of instances : {}".format(len(self.frames)))
+        self._sort_frames()
+
+    def load_wav(self, filename):
+        try:
+            audio = librosa.core.load(filename, sr=self.sample_rate)
+            return audio
+        except RuntimeError:
+            print(" !! Cannot read file : {}".format(filename))
+            raise  # re-raise so an unreadable file fails loudly instead of returning None
+
+    def load_np(self, filename):
+        data = np.load(filename).astype('float32')
+        return data
+
+    def _sort_frames(self):
+        r"""Sort sequences in ascending order of text length."""
+        lengths = np.array([len(ins[1]) for ins in self.frames])
+
+        print(" | > Max length sequence {}".format(np.max(lengths)))
+        print(" | > Min length sequence {}".format(np.min(lengths)))
+        print(" | > Avg length sequence {}".format(np.mean(lengths)))
+
+        idxs = np.argsort(lengths)
+        new_frames = []
+        ignored = []
+        for idx in idxs:
+            length = lengths[idx]
+            if length < self.min_seq_len:
+                ignored.append(idx)
+            else:
+                new_frames.append(self.frames[idx])
+        print(" | > {} instances are ignored by min_seq_len ({})".format(
+            len(ignored), self.min_seq_len))
+        self.frames = new_frames
+
+    def __len__(self):
+        return len(self.frames)
+
+    def __getitem__(self, idx):
+        if self.items[idx] is None:
+            wav_name = os.path.join(self.wav_dir,
+                                    self.frames[idx][0]) + '.wav'
+            mel_name = os.path.join(self.feat_dir,
+                                    self.frames[idx][0]) + '.mel.npy'
+            linear_name = os.path.join(self.feat_dir,
+                                       self.frames[idx][0]) + '.linear.npy'
+            text = self.frames[idx][1]
+            text = np.asarray(text_to_sequence(
+                text, [self.cleaners]), dtype=np.int32)
+            wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
+            mel = self.load_np(mel_name)
+            linear = self.load_np(linear_name)
+            sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0],
+                      'mel': mel, 'linear': linear}
+            self.items[idx] = sample
+        else:
+            sample = self.items[idx]
+        return sample
+
+    def collate_fn(self, batch):
+        r"""
+            Perform preprocessing and create a final data batch:
+            1. Pad text sequences to the longest sequence in the batch.
+            2. Stack the cached mel and linear spectrograms.
+            3. Pad feature frames so their length is divisible by r.
+            4. Convert numpy arrays to torch tensors.
+        """
+
+        # Puts each data field into a tensor with outer dimension batch size
+        if isinstance(batch[0], collections.Mapping):
+
+            wav = [d['wav'] for d in batch]
+            item_idxs = [d['item_idx'] for d in batch]
+            text = [d['text'] for d in batch]
+            mel = [d['mel'] for d in batch]
+            linear = [d['linear'] for d in batch]
+
+            text_lengths = np.array([len(x) for x in text])
+            max_text_len = np.max(text_lengths)
+            mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame
+
+            # compute 'stop token' targets
+            stop_targets = [np.array([0.]*(mel_len-1))
+                            for mel_len in mel_lengths]
+
+            # PAD stop targets
+            stop_targets = prepare_stop_target(
+                stop_targets, self.outputs_per_step)
+
+            # PAD sequences with largest length of the batch
+            text = prepare_data(text).astype(np.int32)
+            wav = prepare_data(wav)
+
+            # PAD features with largest length + a zero frame
+            linear = prepare_tensor(linear, self.outputs_per_step)
+            mel = prepare_tensor(mel, self.outputs_per_step)
+            assert mel.shape[2] == linear.shape[2]
+            timesteps = mel.shape[2]
+
+            # B x T x D
+            linear = linear.transpose(0, 2, 1)
+            mel = mel.transpose(0, 2, 1)
+
+            # convert things to pytorch
+            text_lengths = torch.LongTensor(text_lengths)
+            text = torch.LongTensor(text)
+            linear = torch.FloatTensor(linear)
+            mel = torch.FloatTensor(mel)
+            mel_lengths = torch.LongTensor(mel_lengths)
+            stop_targets = torch.FloatTensor(stop_targets)
+
+            return text, text_lengths, linear, mel, mel_lengths, stop_targets, item_idxs[0]
+
+        raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
+                         found {}"
+                         .format(type(batch[0]))))
diff --git a/extract_feats.py b/extract_feats.py
new file mode 100644
index 00000000..4d143d7d
--- /dev/null
+++ b/extract_feats.py
@@ -0,0 +1,88 @@
+'''
+Extract spectrograms and save them to file for training
+'''
+import os
+import sys
+import time
+import glob
+import argparse
+import librosa
+import numpy as np
+import tqdm
+from utils.audio import AudioProcessor
+from utils.generic_utils import load_config
+
+from multiprocessing import Pool
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data_path', type=str,
+                        help='Data folder.')
+    parser.add_argument('--out_path', type=str,
+                        help='Output folder.')
+    parser.add_argument('--config', type=str,
+                        help='conf.json file for run settings.')
+    parser.add_argument("--num_proc", type=int, default=8,
+                        help="number of processes.")
+    args = parser.parse_args()
+
+    DATA_PATH = args.data_path
+    OUT_PATH = args.out_path
+    CONFIG = load_config(args.config)
+
+    print(" > Input path: ", DATA_PATH)
+    print(" > Output path: ", OUT_PATH)
+
+    ap = AudioProcessor(sample_rate = CONFIG.sample_rate,
+                        num_mels = CONFIG.num_mels,
+                        min_level_db = CONFIG.min_level_db,
+                        frame_shift_ms = CONFIG.frame_shift_ms,
+                        frame_length_ms = CONFIG.frame_length_ms,
+                        ref_level_db = CONFIG.ref_level_db,
+                        num_freq = CONFIG.num_freq,
+                        power = CONFIG.power,
+                        min_mel_freq = CONFIG.min_mel_freq,
+                        max_mel_freq = CONFIG.max_mel_freq)
+
+    def extract_mel(file_path):
+        # x, fs = sf.read(file_path)
+        x, fs = librosa.load(file_path, CONFIG.sample_rate)
+        mel = ap.melspectrogram(x.astype('float32'))
+        linear = ap.spectrogram(x.astype('float32'))
+        file_name = os.path.basename(file_path).replace(".wav","")
+        mel_file = file_name + ".mel"
+        linear_file = file_name + ".linear"
+        np.save(os.path.join(OUT_PATH, mel_file), mel, allow_pickle=False)
+        np.save(os.path.join(OUT_PATH, linear_file), linear, allow_pickle=False)
+        mel_len = mel.shape[1]
+        linear_len = 
linear.shape[1] + wav_len = x.shape[0] + print(" > " + file_path, flush=True) + return file_path, mel_file, linear_file, str(wav_len), str(mel_len), str(linear_len) + + glob_path = os.path.join(DATA_PATH, "*.wav") + print(" > Reading wav: {}".format(glob_path)) + file_names = glob.glob(glob_path, recursive=True) + + if __name__ == "__main__": + print(" > Number of files: %i"%(len(file_names))) + if not os.path.exists(OUT_PATH): + os.makedirs(OUT_PATH) + print(" > A new folder created at {}".format(OUT_PATH)) + + r = [] + if args.num_proc > 1: + print(" > Using {} processes.".format(args.num_proc)) + with Pool(args.num_proc) as p: + r = list(tqdm.tqdm(p.imap(extract_mel, file_names), total=len(file_names))) + else: + print(" > Using single process run.") + for file_name in file_names: + print(" > ", file_name) + r.append(extract_mel(file_name)) + + file_path = os.path.join(OUT_PATH, "meta_fftnet.csv") + file = open(file_path, "w") + for line in r: + line = ", ".join(line) + file.write(line+'\n') + file.close() diff --git a/train.py b/train.py index 697a233d..c93607df 100644 --- a/train.py +++ b/train.py @@ -23,7 +23,6 @@ from utils.generic_utils import (synthesis, remove_experiment_folder, save_best_model, load_config, lr_decay, count_parameters, check_update, get_commit_hash) from utils.visual import plot_alignment, plot_spectrogram -from datasets.LJSpeech import LJSpeechDataset from models.tacotron import Tacotron from layers.losses import L1LossMasked from utils.audio import AudioProcessor @@ -40,7 +39,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, avg_linear_loss = 0 avg_mel_loss = 0 avg_stop_loss = 0 - print(" | > Epoch {}/{}".format(epoch, c.epochs)) + print(" | > Epoch {}/{}".format(epoch, c.epochs), flush=True) n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq) for num_iter, data in enumerate(data_loader): start_time = time.time() @@ -100,7 +99,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, grad_norm, skip_flag = check_update(model, 0.5, 100) if skip_flag: optimizer.zero_grad() - print(" | > Iteration skipped!!") + print(" | > Iteration skipped!!", flush=True) continue optimizer.step() @@ -126,7 +125,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, stop_loss.item(), grad_norm.item(), grad_norm_st.item(), - step_time)) + step_time), flush=True) avg_linear_loss += linear_loss.item() avg_mel_loss += mel_loss.item() @@ -185,7 +184,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, avg_linear_loss, avg_mel_loss, avg_stop_loss, - epoch_time)) + epoch_time), flush=True) # Plot Training Epoch Stats tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step) @@ -320,6 +319,9 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step): def main(args): + dataset = importlib.import_module('datasets.'+c.dataset) + Dataset = getattr(dataset, 'MyDataset') + ap = AudioProcessor(sample_rate = c.sample_rate, num_mels = c.num_mels, min_level_db = c.min_level_db, @@ -332,13 +334,13 @@ def main(args): max_mel_freq = c.max_mel_freq) # Setup the dataset - train_dataset = LJSpeechDataset(os.path.join(c.data_path, c.meta_file_train), - os.path.join(c.data_path, 'wavs'), - c.r, - c.text_cleaner, - ap = ap, - min_seq_len=c.min_seq_len - ) + train_dataset = Dataset(c.data_path, + c.meta_file_train, + c.r, + c.text_cleaner, + ap = ap, + min_seq_len=c.min_seq_len + ) train_loader = DataLoader(train_dataset, 
batch_size=c.batch_size, shuffle=False, collate_fn=train_dataset.collate_fn, @@ -346,12 +348,12 @@ def main(args): pin_memory=True) if c.run_eval: - val_dataset = LJSpeechDataset(os.path.join(c.data_path, c.meta_file_val), - os.path.join(c.data_path, 'wavs'), - c.r, - c.text_cleaner, - ap = ap - ) + val_dataset = Dataset(c.data_path, + c.meta_file_val, + c.r, + c.text_cleaner, + ap = ap + ) val_loader = DataLoader(val_dataset, batch_size=c.eval_batch_size, shuffle=False, collate_fn=val_dataset.collate_fn, @@ -374,6 +376,10 @@ def main(args): if args.restore_path: checkpoint = torch.load(args.restore_path) model.load_state_dict(checkpoint['model']) + if use_cuda: + model = nn.DataParallel(model.cuda()) + criterion.cuda() + criterion_st.cuda() optimizer.load_state_dict(checkpoint['optimizer']) optimizer_st.load_state_dict(checkpoint['optimizer_st']) for state in optimizer.state.values(): @@ -387,11 +393,10 @@ def main(args): else: args.restore_step = 0 print("\n > Starting a new training") - - if use_cuda: - model = nn.DataParallel(model.cuda()) - criterion.cuda() - criterion_st.cuda() + if use_cuda: + model = nn.DataParallel(model.cuda()) + criterion.cuda() + criterion_st.cuda() num_params = count_parameters(model) print(" | > Model has {} parameters".format(num_params))
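
Note on usage: the cached loader expects features to be extracted up front. The .compute entry
above runs extract_feats.py over the LJSpeech wavs and writes <utt_id>.mel.npy / <utt_id>.linear.npy
files into loader_data/, and config.json selects the cached dataset through its "dataset" field.
The sketch below paraphrases how the train.py changes wire this together; it assumes config.json
provides the referenced fields (batch_size, r, text_cleaner, audio settings), as train.py already
expects, and the config path is illustrative.

    import importlib
    from torch.utils.data import DataLoader
    from utils.audio import AudioProcessor
    from utils.generic_utils import load_config

    c = load_config("config.json")
    dataset_module = importlib.import_module("datasets." + c.dataset)  # e.g. datasets.LJSpeechCached
    MyDataset = getattr(dataset_module, "MyDataset")

    ap = AudioProcessor(sample_rate=c.sample_rate, num_mels=c.num_mels,
                        min_level_db=c.min_level_db, frame_shift_ms=c.frame_shift_ms,
                        frame_length_ms=c.frame_length_ms, ref_level_db=c.ref_level_db,
                        num_freq=c.num_freq, power=c.power,
                        min_mel_freq=c.min_mel_freq, max_mel_freq=c.max_mel_freq)

    # MyDataset reads the cached .npy features lazily and memoizes them in self.items
    train_dataset = MyDataset(c.data_path, c.meta_file_train, c.r,
                              c.text_cleaner, ap=ap, min_seq_len=c.min_seq_len)
    train_loader = DataLoader(train_dataset, batch_size=c.batch_size, shuffle=False,
                              collate_fn=train_dataset.collate_fn, pin_memory=True)

    # each batch follows collate_fn's return signature
    text, text_lengths, linear, mel, mel_lengths, stop_targets, item_idx = next(iter(train_loader))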