adding cache loader

pull/10/head
Eren G 2018-07-25 19:14:07 +02:00
parent bb8aaadc1f
commit e156738920
6 changed files with 271 additions and 30 deletions


@@ -3,4 +3,5 @@
 # ls /snakepit/jobs/650/keep/
 source /snakepit/jobs/650/keep/venv/bin/activate
 # source /snakepit/jobs/560/tmp/venv/bin/activate
+python extract_feats.py --data_path /snakepit/shared/data/keithito/LJSpeech-1.1/wavs --out_path /snakepit/shared/data/keithito/LJSpeech-1.1/loader_data/ --config config.json --num_proc 8
 python train.py --config_path config.json --debug true

config.json

@@ -31,7 +31,8 @@
     "run_eval": false,
     "data_path": "/snakepit/shared/data/keithito/LJSpeech-1.1/",
     "meta_file_train": "metadata.csv",
-    "meta_file_val": "metadata_val.csv",
+    "meta_file_val": null,
+    "dataset": "LJSpeechCached",
     "min_seq_len": 0,
     "output_path": "experiments/"
 }
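
The new "dataset" key is what train.py (further down in this commit) resolves at runtime via importlib. A minimal standalone sketch of that lookup, assuming the repository layout of this commit (not part of the diff itself):

import importlib

dataset_name = 'LJSpeechCached'  # the value of the new "dataset" key
dataset_module = importlib.import_module('datasets.' + dataset_name)
Dataset = getattr(dataset_module, 'MyDataset')  # each dataset module exposes a MyDataset class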

datasets/LJSpeech.py

@@ -10,14 +10,15 @@ from utils.data import (prepare_data, pad_per_step,
                         prepare_tensor, prepare_stop_target)
-class LJSpeechDataset(Dataset):
+class MyDataset(Dataset):
-    def __init__(self, csv_file, root_dir, outputs_per_step,
+    def __init__(self, root_dir, csv_file, outputs_per_step,
                  text_cleaner, ap, min_seq_len=0):
-        with open(csv_file, "r", encoding="utf8") as f:
-            self.frames = [line.split('|') for line in f]
         self.root_dir = root_dir
+        self.wav_dir = os.path.join(root_dir, 'wavs')
+        self.csv_dir = os.path.join(root_dir, csv_file)
+        with open(self.csv_dir, "r", encoding="utf8") as f:
+            self.frames = [line.split('|') for line in f]
         self.outputs_per_step = outputs_per_step
         self.sample_rate = ap.sample_rate
         self.cleaners = text_cleaner
@@ -59,7 +60,7 @@ class LJSpeechDataset(Dataset):
         return len(self.frames)
     def __getitem__(self, idx):
-        wav_name = os.path.join(self.root_dir,
+        wav_name = os.path.join(self.wav_dir,
                                 self.frames[idx][0]) + '.wav'
         text = self.frames[idx][1]
         text = np.asarray(text_to_sequence(

datasets/LJSpeechCached.py (new file, 145 additions)

@@ -0,0 +1,145 @@
import os
import numpy as np
import collections
import librosa
import torch
from torch.utils.data import Dataset
from utils.text import text_to_sequence
from utils.data import (prepare_data, pad_per_step,
                        prepare_tensor, prepare_stop_target)


class MyDataset(Dataset):

    def __init__(self, root_dir, csv_file, outputs_per_step,
                 text_cleaner, ap, min_seq_len=0):
        self.root_dir = root_dir
        self.wav_dir = os.path.join(root_dir, 'wavs')
        self.feat_dir = os.path.join(root_dir, 'loader_data')
        self.csv_dir = os.path.join(root_dir, csv_file)
        with open(self.csv_dir, "r", encoding="utf8") as f:
            self.frames = [line.split('|') for line in f]
        self.outputs_per_step = outputs_per_step
        self.sample_rate = ap.sample_rate
        self.cleaners = text_cleaner
        self.min_seq_len = min_seq_len
        # cache slots, filled lazily on first access to each item
        self.items = [None] * len(self.frames)
        print(" > Reading LJSpeech from - {}".format(root_dir))
        print(" | > Number of instances : {}".format(len(self.frames)))
        self._sort_frames()

    def load_wav(self, filename):
        try:
            audio = librosa.core.load(filename, sr=self.sample_rate)
            return audio
        except RuntimeError:
            print(" !! Cannot read file : {}".format(filename))

    def load_np(self, filename):
        data = np.load(filename).astype('float32')
        return data

    def _sort_frames(self):
        r"""Sort sequences in ascending order by text length."""
        lengths = np.array([len(ins[1]) for ins in self.frames])
        print(" | > Max length sequence {}".format(np.max(lengths)))
        print(" | > Min length sequence {}".format(np.min(lengths)))
        print(" | > Avg length sequence {}".format(np.mean(lengths)))
        idxs = np.argsort(lengths)
        new_frames = []
        ignored = []
        for i, idx in enumerate(idxs):
            length = lengths[idx]
            if length < self.min_seq_len:
                ignored.append(idx)
            else:
                new_frames.append(self.frames[idx])
        print(" | > {} instances are ignored by min_seq_len ({})".format(
            len(ignored), self.min_seq_len))
        self.frames = new_frames

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, idx):
        if self.items[idx] is None:
            wav_name = os.path.join(self.wav_dir,
                                    self.frames[idx][0]) + '.wav'
            mel_name = os.path.join(self.feat_dir,
                                    self.frames[idx][0]) + '.mel.npy'
            linear_name = os.path.join(self.feat_dir,
                                       self.frames[idx][0]) + '.linear.npy'
            text = self.frames[idx][1]
            text = np.asarray(text_to_sequence(
                text, [self.cleaners]), dtype=np.int32)
            wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
            mel = self.load_np(mel_name)
            linear = self.load_np(linear_name)
            sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0],
                      'mel': mel, 'linear': linear}
            self.items[idx] = sample
        else:
            sample = self.items[idx]
        return sample

    def collate_fn(self, batch):
        r"""
        Perform preprocessing and create a final data batch:
        1. PAD sequences with the longest sequence in the batch
        2. Convert Audio signal to Spectrograms.
        3. PAD sequences that can be divided by r.
        4. Convert Numpy to Torch tensors.
        """
        # Puts each data field into a tensor with outer dimension batch size
        if isinstance(batch[0], collections.Mapping):
            wav = [d['wav'] for d in batch]
            item_idxs = [d['item_idx'] for d in batch]
            text = [d['text'] for d in batch]
            mel = [d['mel'] for d in batch]
            linear = [d['linear'] for d in batch]
            text_lengths = np.array([len(x) for x in text])
            max_text_len = np.max(text_lengths)
            mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame
            # compute 'stop token' targets
            stop_targets = [np.array([0.] * (mel_len - 1))
                            for mel_len in mel_lengths]
            # PAD stop targets
            stop_targets = prepare_stop_target(
                stop_targets, self.outputs_per_step)
            # PAD sequences with largest length of the batch
            text = prepare_data(text).astype(np.int32)
            wav = prepare_data(wav)
            # PAD features with largest length + a zero frame
            linear = prepare_tensor(linear, self.outputs_per_step)
            mel = prepare_tensor(mel, self.outputs_per_step)
            assert mel.shape[2] == linear.shape[2]
            timesteps = mel.shape[2]
            # B x T x D
            linear = linear.transpose(0, 2, 1)
            mel = mel.transpose(0, 2, 1)
            # convert things to pytorch
            text_lengths = torch.LongTensor(text_lengths)
            text = torch.LongTensor(text)
            linear = torch.FloatTensor(linear)
            mel = torch.FloatTensor(mel)
            mel_lengths = torch.LongTensor(mel_lengths)
            stop_targets = torch.FloatTensor(stop_targets)
            return text, text_lengths, linear, mel, mel_lengths, stop_targets, item_idxs[0]
        raise TypeError(("batch must contain tensors, numbers, dicts or lists;"
                         " found {}".format(type(batch[0]))))
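
For orientation, a minimal usage sketch (illustrative, not part of this commit) wiring MyDataset into a DataLoader the same way the train.py hunks below do. It assumes features were already written to <data_path>/loader_data/ by extract_feats.py and that config.json carries the fields used throughout this commit:

from torch.utils.data import DataLoader
from datasets.LJSpeechCached import MyDataset
from utils.audio import AudioProcessor
from utils.generic_utils import load_config

c = load_config('config.json')
ap = AudioProcessor(sample_rate=c.sample_rate, num_mels=c.num_mels,
                    min_level_db=c.min_level_db, frame_shift_ms=c.frame_shift_ms,
                    frame_length_ms=c.frame_length_ms, ref_level_db=c.ref_level_db,
                    num_freq=c.num_freq, power=c.power,
                    min_mel_freq=c.min_mel_freq, max_mel_freq=c.max_mel_freq)
dataset = MyDataset(c.data_path, c.meta_file_train, c.r,
                    c.text_cleaner, ap=ap, min_seq_len=c.min_seq_len)
loader = DataLoader(dataset, batch_size=c.batch_size, shuffle=False,
                    collate_fn=dataset.collate_fn, pin_memory=True)
# one batch, in the tuple layout produced by collate_fn above
text, text_lengths, linear, mel, mel_lengths, stop_targets, item_idx = next(iter(loader))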

extract_feats.py (new file, 88 additions)

@@ -0,0 +1,88 @@
'''
Extract spectrograms and save them to file for training
'''
import os
import sys
import time
import glob
import argparse
import librosa
import numpy as np
import tqdm
from utils.audio import AudioProcessor
from utils.generic_utils import load_config
from multiprocessing import Pool

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str,
                        help='Data folder.')
    parser.add_argument('--out_path', type=str,
                        help='Output folder.')
    parser.add_argument('--config', type=str,
                        help='conf.json file for run settings.')
    parser.add_argument("--num_proc", type=int, default=8,
                        help="number of processes.")
    args = parser.parse_args()
    # paths must be bound before they are printed
    DATA_PATH = args.data_path
    OUT_PATH = args.out_path
    print(" > Input path: ", DATA_PATH)
    print(" > Output path: ", OUT_PATH)
    CONFIG = load_config(args.config)
    ap = AudioProcessor(sample_rate=CONFIG.sample_rate,
                        num_mels=CONFIG.num_mels,
                        min_level_db=CONFIG.min_level_db,
                        frame_shift_ms=CONFIG.frame_shift_ms,
                        frame_length_ms=CONFIG.frame_length_ms,
                        ref_level_db=CONFIG.ref_level_db,
                        num_freq=CONFIG.num_freq,
                        power=CONFIG.power,
                        min_mel_freq=CONFIG.min_mel_freq,
                        max_mel_freq=CONFIG.max_mel_freq)

    def extract_mel(file_path):
        # x, fs = sf.read(file_path)
        x, fs = librosa.load(file_path, sr=CONFIG.sample_rate)
        mel = ap.melspectrogram(x.astype('float32'))
        linear = ap.spectrogram(x.astype('float32'))
        file_name = os.path.basename(file_path).replace(".wav", "")
        mel_file = file_name + ".mel"
        linear_file = file_name + ".linear"
        np.save(os.path.join(OUT_PATH, mel_file), mel, allow_pickle=False)
        np.save(os.path.join(OUT_PATH, linear_file), linear, allow_pickle=False)
        mel_len = mel.shape[1]
        linear_len = linear.shape[1]
        wav_len = x.shape[0]
        print(" > " + file_path, flush=True)
        return file_path, mel_file, linear_file, str(wav_len), str(mel_len), str(linear_len)

    glob_path = os.path.join(DATA_PATH, "*.wav")
    print(" > Reading wav: {}".format(glob_path))
    file_names = glob.glob(glob_path, recursive=True)
    print(" > Number of files: %i" % (len(file_names)))
    if not os.path.exists(OUT_PATH):
        os.makedirs(OUT_PATH)
        print(" > A new folder created at {}".format(OUT_PATH))
    r = []
    if args.num_proc > 1:
        print(" > Using {} processes.".format(args.num_proc))
        with Pool(args.num_proc) as p:
            r = list(tqdm.tqdm(p.imap(extract_mel, file_names),
                               total=len(file_names)))
    else:
        print(" > Using single process run.")
        for file_name in file_names:
            print(" > ", file_name)
            r.append(extract_mel(file_name))
    # write an index of the extracted features next to them
    file_path = os.path.join(OUT_PATH, "meta_fftnet.csv")
    with open(file_path, "w") as f:
        for line in r:
            f.write(", ".join(line) + '\n')
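
The script leaves one <utt_id>.mel.npy and one <utt_id>.linear.npy per wav in --out_path, plus a meta_fftnet.csv index. A short sketch (illustrative; the utterance id is a placeholder) of reading a cached pair back, matching load_np in the cached dataset above:

import numpy as np

mel = np.load('loader_data/LJ001-0001.mel.npy').astype('float32')        # (num_mels, frames)
linear = np.load('loader_data/LJ001-0001.linear.npy').astype('float32')  # (num_freq, frames)
assert mel.shape[1] == linear.shape[1]  # same frame count per utterance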

train.py

@@ -23,7 +23,6 @@ from utils.generic_utils import (synthesis, remove_experiment_folder,
                                  save_best_model, load_config, lr_decay,
                                  count_parameters, check_update, get_commit_hash)
 from utils.visual import plot_alignment, plot_spectrogram
-from datasets.LJSpeech import LJSpeechDataset
 from models.tacotron import Tacotron
 from layers.losses import L1LossMasked
 from utils.audio import AudioProcessor
@@ -40,7 +39,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
     avg_linear_loss = 0
     avg_mel_loss = 0
     avg_stop_loss = 0
-    print(" | > Epoch {}/{}".format(epoch, c.epochs))
+    print(" | > Epoch {}/{}".format(epoch, c.epochs), flush=True)
     n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
@@ -100,7 +99,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
         grad_norm, skip_flag = check_update(model, 0.5, 100)
         if skip_flag:
             optimizer.zero_grad()
-            print(" | > Iteration skipped!!")
+            print(" | > Iteration skipped!!", flush=True)
             continue
         optimizer.step()
@@ -126,7 +125,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
                    stop_loss.item(),
                    grad_norm.item(),
                    grad_norm_st.item(),
-                   step_time))
+                   step_time), flush=True)
         avg_linear_loss += linear_loss.item()
         avg_mel_loss += mel_loss.item()
@@ -185,7 +184,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
                  avg_linear_loss,
                  avg_mel_loss,
                  avg_stop_loss,
-                 epoch_time))
+                 epoch_time), flush=True)
     # Plot Training Epoch Stats
     tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
@@ -320,6 +319,9 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
 def main(args):
+    dataset = importlib.import_module('datasets.' + c.dataset)
+    Dataset = getattr(dataset, 'MyDataset')
     ap = AudioProcessor(sample_rate=c.sample_rate,
                         num_mels=c.num_mels,
                         min_level_db=c.min_level_db,
@@ -332,13 +334,13 @@ def main(args):
                         max_mel_freq=c.max_mel_freq)
     # Setup the dataset
-    train_dataset = LJSpeechDataset(os.path.join(c.data_path, c.meta_file_train),
-                                    os.path.join(c.data_path, 'wavs'),
-                                    c.r,
-                                    c.text_cleaner,
-                                    ap=ap,
-                                    min_seq_len=c.min_seq_len
-                                    )
+    train_dataset = Dataset(c.data_path,
+                            c.meta_file_train,
+                            c.r,
+                            c.text_cleaner,
+                            ap=ap,
+                            min_seq_len=c.min_seq_len
+                            )
     train_loader = DataLoader(train_dataset, batch_size=c.batch_size,
                               shuffle=False, collate_fn=train_dataset.collate_fn,
@@ -346,12 +348,12 @@ def main(args):
                               pin_memory=True)
     if c.run_eval:
-        val_dataset = LJSpeechDataset(os.path.join(c.data_path, c.meta_file_val),
-                                      os.path.join(c.data_path, 'wavs'),
-                                      c.r,
-                                      c.text_cleaner,
-                                      ap=ap
-                                      )
+        val_dataset = Dataset(c.data_path,
+                              c.meta_file_val,
+                              c.r,
+                              c.text_cleaner,
+                              ap=ap
+                              )
         val_loader = DataLoader(val_dataset, batch_size=c.eval_batch_size,
                                 shuffle=False, collate_fn=val_dataset.collate_fn,
@@ -374,6 +376,10 @@ def main(args):
     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
         model.load_state_dict(checkpoint['model'])
+        if use_cuda:
+            model = nn.DataParallel(model.cuda())
+            criterion.cuda()
+            criterion_st.cuda()
         optimizer.load_state_dict(checkpoint['optimizer'])
         optimizer_st.load_state_dict(checkpoint['optimizer_st'])
         for state in optimizer.state.values():
@@ -387,11 +393,10 @@ def main(args):
     else:
         args.restore_step = 0
         print("\n > Starting a new training")
-    if use_cuda:
-        model = nn.DataParallel(model.cuda())
-        criterion.cuda()
-        criterion_st.cuda()
+        if use_cuda:
+            model = nn.DataParallel(model.cuda())
+            criterion.cuda()
+            criterion_st.cuda()
     num_params = count_parameters(model)
     print(" | > Model has {} parameters".format(num_params))