mirror of https://github.com/coqui-ai/TTS.git

adding cache loader
parent bb8aaadc1f
commit e156738920

.compute (1 changed line)
@@ -3,4 +3,5 @@
 # ls /snakepit/jobs/650/keep/
 source /snakepit/jobs/650/keep/venv/bin/activate
 # source /snakepit/jobs/560/tmp/venv/bin/activate
+python extract_feats.py --data_path /snakepit/shared/data/keithito/LJSpeech-1.1/wavs --out_path /snakepit/shared/data/keithito/LJSpeech-1.1/loader_data/ --config config.json --num_proc 8
 python train.py --config_path config.json --debug true
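The added line precomputes the spectrogram cache before training starts. A minimal sanity-check sketch for that cache, assuming the LJSpeech layout and ".mel.npy" suffix used by extract_feats.py later in this diff:

# Sketch: after running extract_feats.py, every wav should have a cached mel.
# The root path is the one from .compute; directory names are assumptions
# taken from the code below.
import glob
import os

root = '/snakepit/shared/data/keithito/LJSpeech-1.1'
wavs = {os.path.basename(p)[:-len('.wav')]
        for p in glob.glob(os.path.join(root, 'wavs', '*.wav'))}
mels = {os.path.basename(p)[:-len('.mel.npy')]
        for p in glob.glob(os.path.join(root, 'loader_data', '*.mel.npy'))}
print("wavs missing a cached mel:", len(wavs - mels))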
config.json

@@ -31,7 +31,8 @@
     "run_eval": false,
     "data_path": "/snakepit/shared/data/keithito/LJSpeech-1.1/",
     "meta_file_train": "metadata.csv",
-    "meta_file_val": "metadata_val.csv",
+    "meta_file_val": null,
+    "dataset": "LJSpeechCached",
     "min_seq_len": 0,
     "output_path": "experiments/"
 }
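The new "dataset" key is what train.py (further down in this diff) uses to pick the loader module at runtime. A minimal sketch of that lookup, mirroring the importlib lines added to main():

# Mirrors the dynamic dataset selection added to train.py: the config's
# "dataset" value names a module under datasets/, and each such module is
# expected to expose a MyDataset class.
import importlib

def get_dataset_class(name):
    module = importlib.import_module('datasets.' + name)
    return getattr(module, 'MyDataset')

# e.g. get_dataset_class('LJSpeechCached') -> datasets/LJSpeechCached.py:MyDataset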
datasets/LJSpeech.py

@@ -10,14 +10,15 @@ from utils.data import (prepare_data, pad_per_step,
                         prepare_tensor, prepare_stop_target)


-class LJSpeechDataset(Dataset):
+class MyDataset(Dataset):

-    def __init__(self, csv_file, root_dir, outputs_per_step,
+    def __init__(self, root_dir, csv_file, outputs_per_step,
                  text_cleaner, ap, min_seq_len=0):
-
-        with open(csv_file, "r", encoding="utf8") as f:
-            self.frames = [line.split('|') for line in f]
         self.root_dir = root_dir
+        self.wav_dir = os.path.join(root_dir, 'wavs')
+        self.csv_dir = os.path.join(root_dir, csv_file)
+        with open(self.csv_dir, "r", encoding="utf8") as f:
+            self.frames = [line.split('|') for line in f]
         self.outputs_per_step = outputs_per_step
         self.sample_rate = ap.sample_rate
         self.cleaners = text_cleaner
@@ -59,7 +60,7 @@ class LJSpeechDataset(Dataset):
         return len(self.frames)

     def __getitem__(self, idx):
-        wav_name = os.path.join(self.root_dir,
+        wav_name = os.path.join(self.wav_dir,
                                 self.frames[idx][0]) + '.wav'
         text = self.frames[idx][1]
         text = np.asarray(text_to_sequence(
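Both the renamed loader above and the cached loader below parse the metadata file the same way. A self-contained example of what line.split('|') yields for an LJSpeech-style row (the sample text is illustrative):

# LJSpeech metadata rows are pipe-separated: "<wav id>|<text>|<normalized text>".
# frames[idx][0] becomes the wav/feature file stem, frames[idx][1] the text.
line = "LJ001-0001|Printing, in the only sense with which we are at present concerned|Printing, in the only sense with which we are at present concerned"
frame = line.split('|')
print(frame[0])  # -> LJ001-0001 (resolved to wavs/LJ001-0001.wav)
print(frame[1])  # -> the training text fed to text_to_sequence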
datasets/LJSpeechCached.py (new file)

@@ -0,0 +1,145 @@
+import os
+import numpy as np
+import collections
+import librosa
+import torch
+from torch.utils.data import Dataset
+
+from utils.text import text_to_sequence
+from utils.data import (prepare_data, pad_per_step,
+                        prepare_tensor, prepare_stop_target)
+
+
+class MyDataset(Dataset):
+
+    def __init__(self, root_dir, csv_file, outputs_per_step,
+                 text_cleaner, ap, min_seq_len=0):
+        self.root_dir = root_dir
+        self.wav_dir = os.path.join(root_dir, 'wavs')
+        self.feat_dir = os.path.join(root_dir, 'loader_data')
+        self.csv_dir = os.path.join(root_dir, csv_file)
+        with open(self.csv_dir, "r", encoding="utf8") as f:
+            self.frames = [line.split('|') for line in f]
+        self.outputs_per_step = outputs_per_step
+        self.sample_rate = ap.sample_rate
+        self.cleaners = text_cleaner
+        self.min_seq_len = min_seq_len
+        self.items = [None] * len(self.frames)
+        print(" > Reading LJSpeech from - {}".format(root_dir))
+        print(" | > Number of instances : {}".format(len(self.frames)))
+        self._sort_frames()
+
+    def load_wav(self, filename):
+        try:
+            audio = librosa.core.load(filename, sr=self.sample_rate)
+            return audio
+        except RuntimeError as e:
+            print(" !! Cannot read file : {}".format(filename))
+
+    def load_np(self, filename):
+        data = np.load(filename).astype('float32')
+        return data
+
+    def _sort_frames(self):
+        r"""Sort sequences in ascending order"""
+        lengths = np.array([len(ins[1]) for ins in self.frames])
+
+        print(" | > Max length sequence {}".format(np.max(lengths)))
+        print(" | > Min length sequence {}".format(np.min(lengths)))
+        print(" | > Avg length sequence {}".format(np.mean(lengths)))
+
+        idxs = np.argsort(lengths)
+        new_frames = []
+        ignored = []
+        for i, idx in enumerate(idxs):
+            length = lengths[idx]
+            if length < self.min_seq_len:
+                ignored.append(idx)
+            else:
+                new_frames.append(self.frames[idx])
+        print(" | > {} instances are ignored by min_seq_len ({})".format(
+            len(ignored), self.min_seq_len))
+        self.frames = new_frames
+
+    def __len__(self):
+        return len(self.frames)
+
+    def __getitem__(self, idx):
+        if self.items[idx] is None:
+            wav_name = os.path.join(self.wav_dir,
+                                    self.frames[idx][0]) + '.wav'
+            mel_name = os.path.join(self.feat_dir,
+                                    self.frames[idx][0]) + '.mel.npy'
+            linear_name = os.path.join(self.feat_dir,
+                                       self.frames[idx][0]) + '.linear.npy'
+            text = self.frames[idx][1]
+            text = np.asarray(text_to_sequence(
+                text, [self.cleaners]), dtype=np.int32)
+            wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
+            mel = self.load_np(mel_name)
+            linear = self.load_np(linear_name)
+            sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0],
+                      'mel': mel, 'linear': linear}
+            self.items[idx] = sample
+        else:
+            sample = self.items[idx]
+        return sample
+
+    def collate_fn(self, batch):
+        r"""
+        Perform preprocessing and create a final data batch:
+        1. PAD sequences with the longest sequence in the batch
+        2. Convert Audio signal to Spectrograms.
+        3. PAD sequences that can be divided by r.
+        4. Convert Numpy to Torch tensors.
+        """
+
+        # Puts each data field into a tensor with outer dimension batch size
+        if isinstance(batch[0], collections.Mapping):
+            keys = list()
+
+            wav = [d['wav'] for d in batch]
+            item_idxs = [d['item_idx'] for d in batch]
+            text = [d['text'] for d in batch]
+            mel = [d['mel'] for d in batch]
+            linear = [d['linear'] for d in batch]
+
+            text_lengths = np.array([len(x) for x in text])
+            max_text_len = np.max(text_lengths)
+            mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame
+
+            # compute 'stop token' targets
+            stop_targets = [np.array([0.] * (mel_len - 1))
+                            for mel_len in mel_lengths]
+
+            # PAD stop targets
+            stop_targets = prepare_stop_target(
+                stop_targets, self.outputs_per_step)
+
+            # PAD sequences with largest length of the batch
+            text = prepare_data(text).astype(np.int32)
+            wav = prepare_data(wav)
+
+            # PAD features with largest length + a zero frame
+            linear = prepare_tensor(linear, self.outputs_per_step)
+            mel = prepare_tensor(mel, self.outputs_per_step)
+            assert mel.shape[2] == linear.shape[2]
+            timesteps = mel.shape[2]
+
+            # B x T x D
+            linear = linear.transpose(0, 2, 1)
+            mel = mel.transpose(0, 2, 1)
+
+            # convert things to pytorch
+            text_lengths = torch.LongTensor(text_lengths)
+            text = torch.LongTensor(text)
+            linear = torch.FloatTensor(linear)
+            mel = torch.FloatTensor(mel)
+            mel_lengths = torch.LongTensor(mel_lengths)
+            stop_targets = torch.FloatTensor(stop_targets)
+
+            return text, text_lengths, linear, mel, mel_lengths, stop_targets, item_idxs[0]
+
+        raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
+                         found {}"
+                        .format(type(batch[0]))))
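The heart of this new loader is the lazy cache in __getitem__: each item is built once from disk and then served from self.items. A self-contained toy version of the same pattern:

# Toy reduction of MyDataset's caching: the first access computes and stores
# the sample, later accesses return the stored object. This trades memory for
# speed; with DataLoader num_workers > 0, each worker keeps its own copy.
from torch.utils.data import Dataset

class CachedToyDataset(Dataset):
    def __init__(self, n):
        self.items = [None] * n

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        if self.items[idx] is None:
            self.items[idx] = {'value': idx * idx}  # stand-in for disk I/O
        return self.items[idx]

ds = CachedToyDataset(4)
assert ds[2] is ds[2]  # the second access hits the cache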
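collate_fn pads the time axis so the decoder can emit r (outputs_per_step) frames per step. A worked example of the rounding involved, assuming prepare_tensor pads each feature up to the batch maximum rounded to a multiple of r (its exact behavior lives in utils/data.py):

# Rounding a frame count up to a multiple of r; this only illustrates the
# arithmetic behind the "PAD sequences that can be divided by r" step.
def pad_len(length, r):
    return length if length % r == 0 else length + (r - length % r)

assert pad_len(13, 5) == 15   # 13 frames -> padded to 15
assert pad_len(15, 5) == 15   # already divisible, unchanged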
extract_feats.py (new file)

@@ -0,0 +1,88 @@
+'''
+Extract spectrograms and save them to file for training
+'''
+import os
+import sys
+import time
+import glob
+import argparse
+import librosa
+import numpy as np
+import tqdm
+from utils.audio import AudioProcessor
+from utils.generic_utils import load_config
+
+from multiprocessing import Pool
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data_path', type=str,
+                        help='Data folder.')
+    parser.add_argument('--out_path', type=str,
+                        help='Output folder.')
+    parser.add_argument('--config', type=str,
+                        help='conf.json file for run settings.')
+    parser.add_argument("--num_proc", type=int, default=8,
+                        help="number of processes.")
+    args = parser.parse_args()
+
+    DATA_PATH = args.data_path
+    OUT_PATH = args.out_path
+    CONFIG = load_config(args.config)
+
+    print(" > Input path: ", DATA_PATH)
+    print(" > Output path: ", OUT_PATH)
+
+    ap = AudioProcessor(sample_rate = CONFIG.sample_rate,
+                        num_mels = CONFIG.num_mels,
+                        min_level_db = CONFIG.min_level_db,
+                        frame_shift_ms = CONFIG.frame_shift_ms,
+                        frame_length_ms = CONFIG.frame_length_ms,
+                        ref_level_db = CONFIG.ref_level_db,
+                        num_freq = CONFIG.num_freq,
+                        power = CONFIG.power,
+                        min_mel_freq = CONFIG.min_mel_freq,
+                        max_mel_freq = CONFIG.max_mel_freq)
+
+def extract_mel(file_path):
+    # x, fs = sf.read(file_path)
+    x, fs = librosa.load(file_path, CONFIG.sample_rate)
+    mel = ap.melspectrogram(x.astype('float32'))
+    linear = ap.spectrogram(x.astype('float32'))
+    file_name = os.path.basename(file_path).replace(".wav", "")
+    mel_file = file_name + ".mel"
+    linear_file = file_name + ".linear"
+    np.save(os.path.join(OUT_PATH, mel_file), mel, allow_pickle=False)
+    np.save(os.path.join(OUT_PATH, linear_file), linear, allow_pickle=False)
+    mel_len = mel.shape[1]
+    linear_len = linear.shape[1]
+    wav_len = x.shape[0]
+    print(" > " + file_path, flush=True)
+    return file_path, mel_file, linear_file, str(wav_len), str(mel_len), str(linear_len)
+
+glob_path = os.path.join(DATA_PATH, "*.wav")
+print(" > Reading wav: {}".format(glob_path))
+file_names = glob.glob(glob_path, recursive=True)
+
+if __name__ == "__main__":
+    print(" > Number of files: %i" % (len(file_names)))
+    if not os.path.exists(OUT_PATH):
+        os.makedirs(OUT_PATH)
+        print(" > A new folder created at {}".format(OUT_PATH))
+
+    r = []
+    if args.num_proc > 1:
+        print(" > Using {} processes.".format(args.num_proc))
+        with Pool(args.num_proc) as p:
+            r = list(tqdm.tqdm(p.imap(extract_mel, file_names),
+                               total=len(file_names)))
+    else:
+        print(" > Using single process run.")
+        for file_name in file_names:
+            print(" > ", file_name)
+            r.append(extract_mel(file_name))
+
+    file_path = os.path.join(OUT_PATH, "meta_fftnet.csv")
+    file = open(file_path, "w")
+    for line in r:
+        line = ", ".join(line)
+        file.write(line + '\n')
+    file.close()
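The summary file written at the end can be read back for bookkeeping; a sketch, with the path taken from the .compute command and the column order from extract_mel's return value:

# Each row of meta_fftnet.csv is
# "file_path, mel_file, linear_file, wav_len, mel_len, linear_len".
meta = '/snakepit/shared/data/keithito/LJSpeech-1.1/loader_data/meta_fftnet.csv'
with open(meta) as f:
    for row in f:
        path, mel_file, linear_file, wav_len, mel_len, linear_len = \
            [col.strip() for col in row.split(',')]
        print(mel_file, wav_len, mel_len)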
train.py (51 changed lines)

@@ -23,7 +23,6 @@ from utils.generic_utils import (synthesis, remove_experiment_folder,
                                  save_best_model, load_config, lr_decay,
                                  count_parameters, check_update, get_commit_hash)
 from utils.visual import plot_alignment, plot_spectrogram
-from datasets.LJSpeech import LJSpeechDataset
 from models.tacotron import Tacotron
 from layers.losses import L1LossMasked
 from utils.audio import AudioProcessor
@@ -40,7 +39,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
     avg_linear_loss = 0
     avg_mel_loss = 0
     avg_stop_loss = 0
-    print(" | > Epoch {}/{}".format(epoch, c.epochs))
+    print(" | > Epoch {}/{}".format(epoch, c.epochs), flush=True)
     n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
@@ -100,7 +99,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
             grad_norm, skip_flag = check_update(model, 0.5, 100)
             if skip_flag:
                 optimizer.zero_grad()
-                print(" | > Iteration skipped!!")
+                print(" | > Iteration skipped!!", flush=True)
                 continue
             optimizer.step()

@@ -126,7 +125,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
                            stop_loss.item(),
                            grad_norm.item(),
                            grad_norm_st.item(),
-                           step_time))
+                           step_time), flush=True)

         avg_linear_loss += linear_loss.item()
         avg_mel_loss += mel_loss.item()
@@ -185,7 +184,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
                   avg_linear_loss,
                   avg_mel_loss,
                   avg_stop_loss,
-                  epoch_time))
+                  epoch_time), flush=True)

     # Plot Training Epoch Stats
     tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
@@ -320,6 +319,9 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):


 def main(args):
+    dataset = importlib.import_module('datasets.'+c.dataset)
+    Dataset = getattr(dataset, 'MyDataset')
+
     ap = AudioProcessor(sample_rate = c.sample_rate,
                         num_mels = c.num_mels,
                         min_level_db = c.min_level_db,
@@ -332,13 +334,13 @@ def main(args):
                         max_mel_freq = c.max_mel_freq)

     # Setup the dataset
-    train_dataset = LJSpeechDataset(os.path.join(c.data_path, c.meta_file_train),
-                                    os.path.join(c.data_path, 'wavs'),
-                                    c.r,
-                                    c.text_cleaner,
-                                    ap = ap,
-                                    min_seq_len=c.min_seq_len
-                                    )
+    train_dataset = Dataset(c.data_path,
+                            c.meta_file_train,
+                            c.r,
+                            c.text_cleaner,
+                            ap = ap,
+                            min_seq_len=c.min_seq_len
+                            )

     train_loader = DataLoader(train_dataset, batch_size=c.batch_size,
                               shuffle=False, collate_fn=train_dataset.collate_fn,
@@ -346,12 +348,12 @@ def main(args):
                               pin_memory=True)

     if c.run_eval:
-        val_dataset = LJSpeechDataset(os.path.join(c.data_path, c.meta_file_val),
-                                      os.path.join(c.data_path, 'wavs'),
-                                      c.r,
-                                      c.text_cleaner,
-                                      ap = ap
-                                      )
+        val_dataset = Dataset(c.data_path,
+                              c.meta_file_val,
+                              c.r,
+                              c.text_cleaner,
+                              ap = ap
+                              )

         val_loader = DataLoader(val_dataset, batch_size=c.eval_batch_size,
                                 shuffle=False, collate_fn=val_dataset.collate_fn,
@@ -374,6 +376,10 @@ def main(args):
     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
         model.load_state_dict(checkpoint['model'])
+        if use_cuda:
+            model = nn.DataParallel(model.cuda())
+            criterion.cuda()
+            criterion_st.cuda()
         optimizer.load_state_dict(checkpoint['optimizer'])
         optimizer_st.load_state_dict(checkpoint['optimizer_st'])
         for state in optimizer.state.values():
@@ -387,11 +393,10 @@ def main(args):
     else:
         args.restore_step = 0
         print("\n > Starting a new training")
-
-    if use_cuda:
-        model = nn.DataParallel(model.cuda())
-        criterion.cuda()
-        criterion_st.cuda()
+        if use_cuda:
+            model = nn.DataParallel(model.cuda())
+            criterion.cuda()
+            criterion_st.cuda()

     num_params = count_parameters(model)
     print(" | > Model has {} parameters".format(num_params))
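The last two hunks reorder device placement around checkpoint restore: the model is moved to GPU before the optimizer state is loaded, so the relocation loop that follows (for state in optimizer.state.values(): ...) finds everything on the right device. A self-contained sketch of that pattern with a toy model (a hypothetical reduction, not the repo's code):

# Restore-order pattern, reduced to a runnable toy: move the model to the
# device first, then load optimizer state and relocate its tensors.
# Falls back to CPU when CUDA is unavailable.
import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters())
# fake a checkpoint round-trip
checkpoint = {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}

model.load_state_dict(checkpoint['model'])
model.to(device)                               # device move happens first
optimizer.load_state_dict(checkpoint['optimizer'])
for state in optimizer.state.values():         # same loop as in train.py
    for k, v in state.items():
        if torch.is_tensor(v):
            state[k] = v.to(device)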