Batch update after data-loss

pull/10/head
Eren Golge 2018-11-02 16:13:51 +01:00
parent f53f9cb360
commit c8a552e627
18 changed files with 1362 additions and 607 deletions

View File

@@ -1,41 +1,47 @@
{
"model_name": "TTS-sigmoid",
"model_description": "Net outputting Sigmoid unit",
"audio_processor": "audio",
"num_mels": 80,
"num_freq": 1025,
"sample_rate": 22000,
"frame_length_ms": 50,
"frame_shift_ms": 12.5,
"preemphasis": 0.97,
"min_level_db": -100,
"ref_level_db": 20,
"model_name": "TTS-master-_tmp",
"model_description": "Higher dropout rate for stopnet and disabled custom initialization, pull current mel prediction to stopnet.",
"audio":{
"audio_processor": "audio", // to use dictate different audio processors, if available.
"num_mels": 80, // size of the mel spec frame.
"num_freq": 1025, // number of stft frequency levels. Size of the linear spectogram frame.
"sample_rate": 22050, // wav sample-rate. If different than the original data, it is resampled.
"frame_length_ms": 50, // stft window length in ms.
"frame_shift_ms": 12.5, // stft window hop-lengh in ms.
"preemphasis": 0.97, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
"min_level_db": -100, // normalization range
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
"power": 1.5, // value to sharpen wav signals after GL algorithm.
"griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
"signal_norm": true, // normalize the spec values in range [0, 1]
"symmetric_norm": false, // move normalization to range [-1, 1]
"max_norm": 1, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
"clip_norm": true, // clip normalized values into the range.
"mel_fmin": null, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
"mel_fmax": null // maximum freq level for mel-spec. Tune for dataset!!
},
"embedding_size": 256,
"text_cleaner": "english_cleaners",
"num_loader_workers": 4,
"epochs": 1000,
"lr": 0.002,
"lr": 0.0015,
"warmup_steps": 4000,
"lr_decay": 0.5,
"decay_step": 100000,
"batch_size": 32,
"eval_batch_size":-1,
"r": 5,
"wd": 0.0001,
"griffin_lim_iters": 60,
"power": 1.5,
"batch_size":32,
"eval_batch_size":32,
"r": 1,
"wd": 0.000001,
"checkpoint": true,
"save_step": 25000,
"save_step": 5000,
"print_step": 10,
"run_eval": false,
"data_path": "/snakepit/shared/data/keithito/LJSpeech-1.1/",
"meta_file_train": "metadata.csv",
"meta_file_val": null,
"dataset": "LJSpeech",
"run_eval": true,
"data_path": "../../Data/LJSpeech-1.1/tts_cache", // can overwritten from command argument
"meta_file_train": "metadata_train.csv", // metafile for training dataloader
"meta_file_val": "metadata_val.csv", // metafile for validation dataloader
"data_loader": "TTSDataset", // dataloader, ["TTSDataset", "TTSDatasetCached", "TTSDatasetMemory"]
"dataset": "ljspeech", // one of TTS.dataset.preprocessors, only valid id dataloader == "TTSDataset", rest uses "tts_cache" by default.
"min_seq_len": 0,
"output_path": "../keep/"
"output_path": "../keep/",
"num_loader_workers": 8
}
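The nested "audio" block now groups every signal-processing option in one place so it can be unpacked straight into the audio processor (extract_features.py below does exactly that via AudioProcessor(**CONFIG.audio)). A minimal sketch of loading such a commented config; the load_config helper here is illustrative, the repo ships its own in utils.generic_utils:

import json
import re

def load_config(path):
    # Standard JSON forbids // comments, so strip them before parsing.
    # (Assumes no string value contains "//", which holds for this file.)
    with open(path) as f:
        text = re.sub(r'//.*', '', f.read())
    return json.loads(text)

config = load_config("config.json")
audio_params = config["audio"]  # ready to pass as AudioProcessor(**audio_params)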

View File

@@ -13,75 +13,72 @@ from utils.data import (prepare_data, pad_per_step, prepare_tensor,
class MyDataset(Dataset):
def __init__(self,
root_dir,
csv_file,
root_path,
meta_file,
outputs_per_step,
text_cleaner,
ap,
preprocessor,
batch_group_size=0,
min_seq_len=0):
self.root_dir = root_dir
self.root_path = root_path
self.batch_group_size = batch_group_size
self.wav_dir = os.path.join(root_dir, 'wavs')
self.csv_dir = os.path.join(root_dir, csv_file)
with open(self.csv_dir, "r", encoding="utf8") as f:
self.frames = [line.split('|') for line in f]
self.items = preprocessor(root_path, meta_file)
self.outputs_per_step = outputs_per_step
self.sample_rate = ap.sample_rate
self.cleaners = text_cleaner
self.min_seq_len = min_seq_len
self.ap = ap
print(" > Reading LJSpeech from - {}".format(root_dir))
print(" | > Number of instances : {}".format(len(self.frames)))
self.sort_frames()
print(" > Reading LJSpeech from - {}".format(root_path))
print(" | > Number of instances : {}".format(len(self.items)))
self.sort_items()
def load_wav(self, filename):
try:
audio = librosa.core.load(filename, sr=self.sample_rate)[0]
audio = self.ap.load_wav(filename)
return audio
except RuntimeError as e:
print(" !! Cannot read file : {}".format(filename))
def sort_frames(self):
def sort_items(self):
r"""Sort text sequences in ascending order"""
lengths = np.array([len(ins[1]) for ins in self.frames])
lengths = np.array([len(ins[0]) for ins in self.items])
print(" | > Max length sequence {}".format(np.max(lengths)))
print(" | > Min length sequence {}".format(np.min(lengths)))
print(" | > Avg length sequence {}".format(np.mean(lengths)))
idxs = np.argsort(lengths)
new_frames = []
new_items = []
ignored = []
for i, idx in enumerate(idxs):
length = lengths[idx]
if length < self.min_seq_len:
ignored.append(idx)
else:
new_frames.append(self.frames[idx])
new_items.append(self.items[idx])
print(" | > {} instances are ignored by min_seq_len ({})".format(
len(ignored), self.min_seq_len))
# shuffle batch groups
if self.batch_group_size > 0:
print(" | > Batch group shuffling is active.")
for i in range(len(new_frames) // self.batch_group_size):
for i in range(len(new_items) // self.batch_group_size):
offset = i * self.batch_group_size
end_offset = offset + self.batch_group_size
temp_frames = new_frames[offset : end_offset]
random.shuffle(temp_frames)
new_frames[offset : end_offset] = temp_frames
self.frames = new_frames
temp_items = new_items[offset : end_offset]
random.shuffle(temp_items)
new_items[offset : end_offset] = temp_items
self.items = new_items
def __len__(self):
return len(self.frames)
return len(self.items)
def __getitem__(self, idx):
wav_name = os.path.join(self.wav_dir, self.frames[idx][0]) + '.wav'
text = self.frames[idx][1]
text, wav_file = self.items[idx]
text = np.asarray(
text_to_sequence(text, [self.cleaners]), dtype=np.int32)
wav = np.asarray(self.load_wav(wav_name), dtype=np.float32)
sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
sample = {'text': text, 'wav': wav, 'item_idx': self.items[idx][1]}
return sample
def collate_fn(self, batch):

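The sort-then-shuffle logic in sort_items() above is the essence of batch group shuffling: items are sorted by text length, then shuffled only inside fixed-size groups, so each batch stays roughly length-homogeneous (less padding) while batch composition still varies between epochs. The same idea as a standalone sketch:

import random

def batch_group_shuffle(items, batch_group_size):
    # items are assumed to be pre-sorted by length, as in sort_items().
    for i in range(len(items) // batch_group_size):
        offset = i * batch_group_size
        group = items[offset:offset + batch_group_size]
        random.shuffle(group)
        items[offset:offset + batch_group_size] = group
    return items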
View File

@@ -1,4 +1,5 @@
import os
import random
import numpy as np
import collections
import librosa
@@ -6,32 +7,34 @@ import torch
from torch.utils.data import Dataset
from utils.text import text_to_sequence
from datasets.preprocess import tts_cache
from utils.data import (prepare_data, pad_per_step, prepare_tensor,
prepare_stop_target)
class MyDataset(Dataset):
# TODO: Not finished yet.
def __init__(self,
root_dir,
csv_file,
root_path,
meta_file,
outputs_per_step,
text_cleaner,
ap,
min_seq_len=0):
self.root_dir = root_dir
self.wav_dir = os.path.join(root_dir, 'wavs')
self.feat_dir = os.path.join(root_dir, 'loader_data')
self.csv_dir = os.path.join(root_dir, csv_file)
with open(self.csv_dir, "r", encoding="utf8") as f:
self.frames = [line.split('|') for line in f]
batch_group_size=0,
min_seq_len=0,
**kwargs
):
self.root_path = root_path
self.batch_group_size = batch_group_size
self.feat_dir = os.path.join(root_path, 'loader_data')
self.items = tts_cache(root_path, meta_file)
self.outputs_per_step = outputs_per_step
self.sample_rate = ap.sample_rate
self.cleaners = text_cleaner
self.min_seq_len = min_seq_len
self.items = [None] * len(self.frames)
print(" > Reading LJSpeech from - {}".format(root_dir))
print(" | > Number of instances : {}".format(len(self.frames)))
self._sort_frames()
print(" > Reading LJSpeech from - {}".format(root_path))
print(" | > Number of instances : {}".format(len(self.items)))
self.sort_items()
def load_wav(self, filename):
try:
@@ -44,9 +47,9 @@ class MyDataset(Dataset):
data = np.load(filename).astype('float32')
return data
def _sort_frames(self):
r"""Sort sequences in ascending order"""
lengths = np.array([len(ins[1]) for ins in self.frames])
def sort_items(self):
r"""Sort text sequences in ascending order"""
lengths = np.array([len(ins[-1]) for ins in self.items])
print(" | > Max length sequence {}".format(np.max(lengths)))
print(" | > Min length sequence {}".format(np.min(lengths)))
@@ -60,37 +63,43 @@ class MyDataset(Dataset):
if length < self.min_seq_len:
ignored.append(idx)
else:
new_frames.append(self.frames[idx])
new_frames.append(self.items[idx])
print(" | > {} instances are ignored by min_seq_len ({})".format(
len(ignored), self.min_seq_len))
self.frames = new_frames
# shuffle batch groups
if self.batch_group_size > 0:
print(" | > Batch group shuffling is active.")
for i in range(len(new_frames) // self.batch_group_size):
offset = i * self.batch_group_size
end_offset = offset + self.batch_group_size
temp_frames = new_frames[offset : end_offset]
random.shuffle(temp_frames)
new_frames[offset : end_offset] = temp_frames
self.items = new_frames
def __len__(self):
return len(self.frames)
return len(self.items)
def __getitem__(self, idx):
if self.items[idx] is None:
wav_name = os.path.join(self.wav_dir, self.frames[idx][0]) + '.wav'
mel_name = os.path.join(self.feat_dir,
self.frames[idx][0]) + '.mel.npy'
linear_name = os.path.join(self.feat_dir,
self.frames[idx][0]) + '.linear.npy'
text = self.frames[idx][1]
text = np.asarray(
text_to_sequence(text, [self.cleaners]), dtype=np.int32)
wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
mel = self.load_np(mel_name)
linear = self.load_np(linear_name)
sample = {
'text': text,
'wav': wav,
'item_idx': self.frames[idx][0],
'mel': mel,
'linear': linear
}
self.items[idx] = sample
wav_name = self.items[idx][0]
mel_name = self.items[idx][1]
linear_name = self.items[idx][2]
text = self.items[idx][-1]
text = np.asarray(
text_to_sequence(text, [self.cleaners]), dtype=np.int32)
if wav_name.split('.')[-1] == 'npy':
wav = self.load_np(wav_name)
else:
sample = self.items[idx]
wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
mel = self.load_np(mel_name)
linear = self.load_np(linear_name)
sample = {
'text': text,
'wav': wav,
'item_idx': self.items[idx][0],
'mel': mel,
'linear': linear
}
return sample
def collate_fn(self, batch):
@@ -132,7 +141,6 @@ class MyDataset(Dataset):
# PAD features with largest length + a zero frame
linear = prepare_tensor(linear, self.outputs_per_step)
mel = prepare_tensor(mel, self.outputs_per_step)
assert mel.shape[2] == linear.shape[2]
timesteps = mel.shape[2]
# B x T x D
@@ -147,8 +155,7 @@ class MyDataset(Dataset):
mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets)
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
0]
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
found {}".format(type(batch[0]))))

View File

@@ -1,162 +1,179 @@
import os
import glob
import random
import numpy as np
import collections
import librosa
import torch
from torch.utils.data import Dataset
from utils.text import text_to_sequence
from utils.data import (prepare_data, pad_per_step, prepare_tensor,
prepare_stop_target)
class MyDataset(Dataset):
def __init__(self,
root_dir,
csv_file,
outputs_per_step,
text_cleaner,
ap,
min_seq_len=0):
self.root_dir = root_dir
self.wav_dir = os.path.join(root_dir, 'wav')
self.wav_files = glob.glob(os.path.join(self.wav_dir, '*.wav'))
self._create_file_dict()
self.csv_dir = os.path.join(root_dir, csv_file)
with open(self.csv_dir, "r", encoding="utf8") as f:
self.frames = [
line.split('\t') for line in f
if line.split('\t')[0] in self.wav_files_dict.keys()
]
self.outputs_per_step = outputs_per_step
self.sample_rate = ap.sample_rate
self.cleaners = text_cleaner
self.min_seq_len = min_seq_len
self.ap = ap
print(" > Reading Kusal from - {}".format(root_dir))
print(" | > Number of instances : {}".format(len(self.frames)))
self._sort_frames()
def load_wav(self, filename):
""" Load audio and trim silence """
try:
audio = librosa.core.load(filename, sr=self.sample_rate)[0]
margin = int(self.sample_rate * 0.1)
audio = audio[margin:-margin]
return self._trim_silence(audio)
except RuntimeError as e:
print(" !! Cannot read file : {}".format(filename))
def _trim_silence(self, wav):
return librosa.effects.trim(
wav, top_db=40, frame_length=1024, hop_length=256)[0]
def _create_file_dict(self):
self.wav_files_dict = {}
for fn in self.wav_files:
parts = fn.split('-')
key = parts[1]
value = fn
try:
self.wav_files_dict[key].append(value)
except:
self.wav_files_dict[key] = [value]
def _sort_frames(self):
r"""Sort sequences in ascending order"""
lengths = np.array([len(ins[2]) for ins in self.frames])
print(" | > Max length sequence {}".format(np.max(lengths)))
print(" | > Min length sequence {}".format(np.min(lengths)))
print(" | > Avg length sequence {}".format(np.mean(lengths)))
idxs = np.argsort(lengths)
new_frames = []
ignored = []
for i, idx in enumerate(idxs):
length = lengths[idx]
if length < self.min_seq_len:
ignored.append(idx)
else:
new_frames.append(self.frames[idx])
print(" | > {} instances are ignored by min_seq_len ({})".format(
len(ignored), self.min_seq_len))
self.frames = new_frames
def __len__(self):
return len(self.frames)
def __getitem__(self, idx):
sidx = self.frames[idx][0]
sidx_files = self.wav_files_dict[sidx]
file_name = random.choice(sidx_files)
wav_name = os.path.join(self.wav_dir, file_name)
text = self.frames[idx][2]
text = np.asarray(
text_to_sequence(text, [self.cleaners]), dtype=np.int32)
wav = np.asarray(self.load_wav(wav_name), dtype=np.float32)
sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
return sample
def collate_fn(self, batch):
r"""
Perform preprocessing and create a final data batch:
1. PAD sequences with the longest sequence in the batch
2. Convert Audio signal to Spectrograms.
3. PAD sequences that can be divided by r.
4. Convert Numpy to Torch tensors.
"""
# Puts each data field into a tensor with outer dimension batch size
if isinstance(batch[0], collections.Mapping):
keys = list()
wav = [d['wav'] for d in batch]
item_idxs = [d['item_idx'] for d in batch]
text = [d['text'] for d in batch]
text_lenghts = np.array([len(x) for x in text])
max_text_len = np.max(text_lenghts)
linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame
# compute 'stop token' targets
stop_targets = [
np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
]
# PAD stop targets
stop_targets = prepare_stop_target(stop_targets,
self.outputs_per_step)
# PAD sequences with largest length of the batch
text = prepare_data(text).astype(np.int32)
wav = prepare_data(wav)
# PAD features with largest length + a zero frame
linear = prepare_tensor(linear, self.outputs_per_step)
mel = prepare_tensor(mel, self.outputs_per_step)
assert mel.shape[2] == linear.shape[2]
timesteps = mel.shape[2]
# B x T x D
linear = linear.transpose(0, 2, 1)
mel = mel.transpose(0, 2, 1)
# convert things to pytorch
text_lenghts = torch.LongTensor(text_lenghts)
text = torch.LongTensor(text)
linear = torch.FloatTensor(linear)
mel = torch.FloatTensor(mel)
mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets)
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
0]
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
found {}".format(type(batch[0]))))
import os
import random
import numpy as np
import collections
import librosa
import torch
from tqdm import tqdm
from torch.utils.data import Dataset
from utils.text import text_to_sequence
from datasets.preprocess import tts_cache
from utils.data import (prepare_data, pad_per_step, prepare_tensor,
prepare_stop_target)
class MyDataset(Dataset):
# TODO: Not finished yet.
def __init__(self,
root_path,
meta_file,
outputs_per_step,
text_cleaner,
ap,
batch_group_size=0,
min_seq_len=0,
**kwargs
):
self.root_path = root_path
self.batch_group_size = batch_group_size
self.feat_dir = os.path.join(root_path, 'loader_data')
self.items = tts_cache(root_path, meta_file)
self.outputs_per_step = outputs_per_step
self.sample_rate = ap.sample_rate
self.cleaners = text_cleaner
self.min_seq_len = min_seq_len
self.wavs = None
self.mels = None
self.linears = None
print(" > Reading LJSpeech from - {}".format(root_path))
print(" | > Number of instances : {}".format(len(self.items)))
self.sort_items()
self.fill_data()
def fill_data(self):
if self.wavs is None and self.mels is None:
self.wavs = []
self.mels = []
self.linears = []
self.texts = []
for item in tqdm(self.items):
wav_file = item[0]
mel_file = item[1]
linear_file = item[2]
text = item[-1]
wav = self.load_np(wav_file)
mel = self.load_np(mel_file)
linear = self.load_np(linear_file)
self.wavs.append(wav)
self.mels.append(mel)
self.linears.append(linear)
self.texts.append(np.asarray(
text_to_sequence(text, [self.cleaners]), dtype=np.int32))
print(" > Data loaded to memory")
def load_wav(self, filename):
try:
audio = librosa.core.load(filename, sr=self.sample_rate)
return audio
except RuntimeError as e:
print(" !! Cannot read file : {}".format(filename))
def load_np(self, filename):
data = np.load(filename).astype('float32')
return data
def sort_items(self):
r"""Sort text sequences in ascending order"""
lengths = np.array([len(ins[-1]) for ins in self.items])
print(" | > Max length sequence {}".format(np.max(lengths)))
print(" | > Min length sequence {}".format(np.min(lengths)))
print(" | > Avg length sequence {}".format(np.mean(lengths)))
idxs = np.argsort(lengths)
new_frames = []
ignored = []
for i, idx in enumerate(idxs):
length = lengths[idx]
if length < self.min_seq_len:
ignored.append(idx)
else:
new_frames.append(self.items[idx])
print(" | > {} instances are ignored by min_seq_len ({})".format(
len(ignored), self.min_seq_len))
# shuffle batch groups
if self.batch_group_size > 0:
print(" | > Batch group shuffling is active.")
for i in range(len(new_frames) // self.batch_group_size):
offset = i * self.batch_group_size
end_offset = offset + self.batch_group_size
temp_frames = new_frames[offset : end_offset]
random.shuffle(temp_frames)
new_frames[offset : end_offset] = temp_frames
self.items = new_frames
def __len__(self):
return len(self.items)
def __getitem__(self, idx):
text = self.texts[idx]
wav = self.wavs[idx]
mel = self.mels[idx]
linear = self.linears[idx]
sample = {
'text': text,
'wav': wav,
'item_idx': self.items[idx][0],
'mel': mel,
'linear': linear
}
return sample
def collate_fn(self, batch):
r"""
Perform preprocessing and create a final data batch:
1. PAD sequences with the longest sequence in the batch
2. Convert Audio signal to Spectrograms.
3. PAD sequences that can be divided by r.
4. Convert Numpy to Torch tensors.
"""
# Puts each data field into a tensor with outer dimension batch size
if isinstance(batch[0], collections.Mapping):
keys = list()
wav = [d['wav'] for d in batch]
item_idxs = [d['item_idx'] for d in batch]
text = [d['text'] for d in batch]
mel = [d['mel'] for d in batch]
linear = [d['linear'] for d in batch]
text_lenghts = np.array([len(x) for x in text])
max_text_len = np.max(text_lenghts)
mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame
# compute 'stop token' targets
stop_targets = [
np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
]
# PAD stop targets
stop_targets = prepare_stop_target(stop_targets,
self.outputs_per_step)
# PAD sequences with largest length of the batch
text = prepare_data(text).astype(np.int32)
wav = prepare_data(wav)
# PAD features with largest length + a zero frame
linear = prepare_tensor(linear, self.outputs_per_step)
mel = prepare_tensor(mel, self.outputs_per_step)
timesteps = mel.shape[2]
# B x T x D
linear = linear.transpose(0, 2, 1)
mel = mel.transpose(0, 2, 1)
# convert things to pytorch
text_lenghts = torch.LongTensor(text_lenghts)
text = torch.LongTensor(text)
linear = torch.FloatTensor(linear)
mel = torch.FloatTensor(mel)
mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets)
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
found {}".format(type(batch[0]))))
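fill_data() above trades RAM for I/O by preloading every wav, mel, and linear spectrogram into memory, which is likely why the class is still marked TODO. A back-of-the-envelope footprint estimate, using hypothetical LJSpeech-scale numbers:

# ~13100 clips, 1025 linear-spectrogram bins, ~500 frames per clip
# (illustrative values only), stored as float32 (4 bytes):
n_clips, num_freq, avg_frames = 13100, 1025, 500
linear_bytes = n_clips * num_freq * avg_frames * 4
print(linear_bytes / 2**30)  # roughly 25 GiB for the linear specs alone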

datasets/preprocess.py Normal file
View File

@@ -0,0 +1,60 @@
import os
import random
def tts_cache(root_path, meta_file):
"""This format is set for the meta-file generated by extract_features.py"""
txt_file = os.path.join(root_path, meta_file)
items = []
with open(txt_file, 'r', encoding='utf8') as f:
for line in f:
cols = line.split('| ')
items.append(cols) # wav_full_path, mel_name, linear_name, wav_len, mel_len, text
random.shuffle(items)
return items
# def tweb(root_path, meta_file):
# # TODO
# pass
# return
# def kusal(root_path, meta_file):
# txt_file = os.path.join(root_path, meta_file)
# texts = []
# wavs = []
# with open(txt_file, "r", encoding="utf8") as f:
# frames = [
# line.split('\t') for line in f
# if line.split('\t')[0] in self.wav_files_dict.keys()
# ]
# # TODO: code the rest
# return {'text': texts, 'wavs': wavs}
def ljspeech(root_path, meta_file):
"""Normalizes the Nancy meta data file to TTS format"""
txt_file = os.path.join(root_path, meta_file)
items = []
with open(txt_file, 'r') as ttf:
for line in ttf:
cols = line.split('|')
wav_file = os.path.join(root_path, 'wavs', cols[0]+'.wav')
text = cols[1]
items.append([text, wav_file])
random.shuffle(items)
return items
def nancy(root_path, meta_file):
"""Normalizes the Nancy meta data file to TTS format"""
txt_file = os.path.join(root_path, meta_file)
items = []
with open(txt_file, 'r') as ttf:
for line in ttf:
id = line.split()[1]
text = line[line.find('"')+1:line.rfind('"')-1]
wav_file = root_path + 'wavn/' + id + '.wav'
items.append([text, wav_file])
random.shuffle(items)
return items
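For reference, tts_cache() consumes the "| "-separated rows that extract_features.py writes. One hypothetical row (paths, lengths, and text are made up for illustration):

# Columns: wav_full_path, mel_name, linear_name, wav_len, mel_len, text
row = ('/data/LJSpeech-1.1/wavs/LJ001-0001.wav| '
       'mel/LJ001-0001_mel.npy| linear/LJ001-0001_linear.npy| '
       '110250| 425| Printing, in the only sense ...')
wav_path, mel_path, linear_path, wav_len, mel_len, text = row.split('| ')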

extract_features.py Normal file
View File

@@ -0,0 +1,138 @@
'''
Extract spectrograms and save them to file for training
'''
import os
import sys
import time
import glob
import argparse
import librosa
import importlib
import numpy as np
import tqdm
from utils.generic_utils import load_config, copy_config_file
from utils.audio import AudioProcessor
from multiprocessing import Pool
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, help='Data folder.')
parser.add_argument('--cache_path', type=str, help='Cache folder, place to output all the intermediate spectrogram files.')
# parser.add_argument('--keep_cache', type=bool, help='If True, it keeps the cache folder.')
# parser.add_argument('--hdf5_path', type=str, help='hdf5 folder.')
parser.add_argument(
'--config', type=str, help='conf.json file for run settings.')
parser.add_argument(
"--num_proc", type=int, default=8, help="number of processes.")
parser.add_argument(
"--trim_silence",
action="store_true",
help="trim silence in the voice clip.")
parser.add_argument("--only_mel", type=bool, default=False, help="If True, only melsceptrogram is extracted.")
parser.add_argument("--dataset", type=str, help="Target dataset to be processed.")
parser.add_argument("--val_split", type=int, default=0, help="Number of instances for validation.")
parser.add_argument("--meta_file", type=str, help="Meta data file to be used for the dataset.")
parser.add_argument("--process_audio", type=bool, default=False, help="Preprocess audio files.")
args = parser.parse_args()
DATA_PATH = args.data_path
CACHE_PATH = args.cache_path
CONFIG = load_config(args.config)
# load the right preprocessor
preprocessor = importlib.import_module('datasets.preprocess')
preprocessor = getattr(preprocessor, args.dataset.lower())
items = preprocessor(args.data_path, args.meta_file)
print(" > Input path: ", DATA_PATH)
print(" > Cache path: ", CACHE_PATH)
# audio = importlib.import_module('utils.' + c.audio_processor)
# AudioProcessor = getattr(audio, 'AudioProcessor')
ap = AudioProcessor(**CONFIG.audio)
def trim_silence(wav):
""" Trim silent parts with a threshold and 0.1 sec margin """
margin = int(ap.sample_rate * 0.1)
wav = wav[margin:-margin]
return librosa.effects.trim(
wav, top_db=40, frame_length=1024, hop_length=256)[0]
def extract_mel(item):
""" Compute spectrograms, length information """
text = item[0]
file_path = item[1]
x = ap.load_wav(file_path, ap.sample_rate)
if args.trim_silence:
x = trim_silence(x)
file_name = os.path.basename(file_path).replace(".wav", "")
mel_file = file_name + "_mel"
mel_path = os.path.join(CACHE_PATH, 'mel', mel_file)
mel = ap.melspectrogram(x.astype('float32')).astype('float32')
np.save(mel_path, mel, allow_pickle=False)
mel_len = mel.shape[1]
wav_len = x.shape[0]
output = [file_path, mel_path+".npy", str(wav_len), str(mel_len), text]
if not args.only_mel:
linear_file = file_name + "_linear"
linear_path = os.path.join(CACHE_PATH, 'linear', linear_file)
linear = ap.spectrogram(x.astype('float32')).astype('float32')
linear_len = linear.shape[1]
np.save(linear_path, linear, allow_pickle=False)
output.insert(2, linear_path+".npy")
if args.process_audio:
audio_file = file_name + "_audio"
audio_path = os.path.join(CACHE_PATH, 'audio', audio_file)
np.save(audio_path, x, allow_pickle=False)
del output[0]
output.insert(0, audio_path+".npy")
assert mel_len == linear_len
return output
if __name__ == "__main__":
print(" > Number of files: %i" % (len(items)))
if not os.path.exists(CACHE_PATH):
os.makedirs(os.path.join(CACHE_PATH, 'mel'))
if not args.only_mel:
os.makedirs(os.path.join(CACHE_PATH, 'linear'))
if args.process_audio:
os.makedirs(os.path.join(CACHE_PATH, 'audio'))
print(" > A new folder created at {}".format(CACHE_PATH))
# Extract features
r = []
if args.num_proc > 1:
print(" > Using {} processes.".format(args.num_proc))
with Pool(args.num_proc) as p:
r = list(
tqdm.tqdm(
p.imap(extract_mel, items),
total=len(items)))
# r = list(p.imap(extract_mel, file_names))
else:
print(" > Using single process run.")
for item in items:
print(" > ", item[1])
r.append(extract_mel(item))
# Save meta data
if args.cache_path is not None:
file_path = os.path.join(CACHE_PATH, "tts_metadata_val.csv")
file = open(file_path, "w")
for line in r[:args.val_split]:
line = "| ".join(line)
file.write(line + '\n')
file.close()
file_path = os.path.join(CACHE_PATH, "tts_metadata.csv")
file = open(file_path, "w")
for line in r[args.val_split:]:
line = "| ".join(line)
file.write(line + '\n')
file.close()
# copy the used config file to output path for sanity
copy_config_file(args.config, CACHE_PATH)
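A typical invocation might look like the following; all paths are placeholders, and the flags are the ones defined by the argparse block above:

python extract_features.py \
    --data_path /data/LJSpeech-1.1 \
    --cache_path /data/LJSpeech-1.1/tts_cache \
    --config config.json \
    --dataset ljspeech \
    --meta_file metadata.csv \
    --num_proc 8 \
    --val_split 500 \
    --trim_silence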

View File

@@ -52,10 +52,25 @@ class LocationSensitiveAttention(nn.Module):
stride=1,
padding=0,
bias=False))
self.loc_linear = nn.Linear(filters, attn_dim)
self.loc_linear = nn.Linear(filters, attn_dim, bias=True)
self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
self.v = nn.Linear(attn_dim, 1, bias=False)
# self.init_layers()
def init_layers(self):
torch.nn.init.xavier_uniform_(
self.loc_linear.weight,
gain=torch.nn.init.calculate_gain('tanh'))
torch.nn.init.xavier_uniform_(
self.query_layer.weight,
gain=torch.nn.init.calculate_gain('tanh'))
torch.nn.init.xavier_uniform_(
self.annot_layer.weight,
gain=torch.nn.init.calculate_gain('tanh'))
torch.nn.init.xavier_uniform_(
self.v.weight,
gain=torch.nn.init.calculate_gain('linear'))
def forward(self, annot, query, loc):
"""

View File

@@ -23,6 +23,13 @@ class Prenet(nn.Module):
])
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.5)
# self.init_layers()
def init_layers(self):
for layer in self.layers:
torch.nn.init.xavier_uniform_(
layer.weight,
gain=torch.nn.init.calculate_gain('relu'))
def forward(self, inputs):
for linear in self.layers:
@@ -55,6 +62,7 @@ class BatchNormConv1d(nn.Module):
stride,
padding,
activation=None):
super(BatchNormConv1d, self).__init__()
self.padding = padding
self.padder = nn.ConstantPad1d(padding, 0)
@@ -68,13 +76,28 @@ class BatchNormConv1d(nn.Module):
# Following tensorflow's default parameters
self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3)
self.activation = activation
# self.init_layers()
def init_layers(self):
if type(self.activation) == torch.nn.ReLU:
w_gain = 'relu'
elif type(self.activation) == torch.nn.Tanh:
w_gain = 'tanh'
elif self.activation is None:
w_gain = 'linear'
else:
raise RuntimeError('Unknown activation function')
torch.nn.init.xavier_uniform_(
self.conv1d.weight,
gain=torch.nn.init.calculate_gain(w_gain))
def forward(self, x):
x = self.padder(x)
x = self.conv1d(x)
x = self.bn(x)
if self.activation is not None:
x = self.activation(x)
return self.bn(x)
return x
class Highway(nn.Module):
@@ -86,6 +109,15 @@ class Highway(nn.Module):
self.T.bias.data.fill_(-1)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
# self.init_layers()
def init_layers(self):
torch.nn.init.xavier_uniform_(
self.H.weight,
gain=torch.nn.init.calculate_gain('relu'))
torch.nn.init.xavier_uniform_(
self.T.weight,
gain=torch.nn.init.calculate_gain('sigmoid'))
def forward(self, inputs):
H = self.relu(self.H(inputs))
@@ -276,7 +308,7 @@ class Decoder(nn.Module):
super(Decoder, self).__init__()
self.r = r
self.in_features = in_features
self.max_decoder_steps = 200
self.max_decoder_steps = 500
self.memory_dim = memory_dim
# memory -> |Prenet| -> processed_memory
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
@@ -294,7 +326,16 @@ class Decoder(nn.Module):
[nn.GRUCell(256, 256) for _ in range(2)])
# RNN_state -> |Linear| -> mel_spec
self.proj_to_mel = nn.Linear(256, memory_dim * r)
self.stopnet = StopNet(r, memory_dim)
self.stopnet = StopNet(256 + memory_dim * r)
# self.init_layers()
def init_layers(self):
torch.nn.init.xavier_uniform_(
self.project_to_decoder_in.weight,
gain=torch.nn.init.calculate_gain('linear'))
torch.nn.init.xavier_uniform_(
self.proj_to_mel.weight,
gain=torch.nn.init.calculate_gain('linear'))
def forward(self, inputs, memory=None, mask=None):
"""
@@ -350,7 +391,7 @@ class Decoder(nn.Module):
memory_input = initial_memory
while True:
if t > 0:
if greedy:
if memory is None:
memory_input = outputs[-1]
else:
memory_input = memory[t - 1]
@@ -375,17 +416,14 @@ class Decoder(nn.Module):
decoder_output = decoder_input
# predict mel vectors from decoder vectors
output = self.proj_to_mel(decoder_output)
output = torch.sigmoid(output)
stop_input = output
# predict stop token
stop_token, stopnet_rnn_hidden = self.stopnet(
stop_input, stopnet_rnn_hidden)
stopnet_input = torch.cat([decoder_input, output], -1)
stop_token = self.stopnet(stopnet_input)
outputs += [output]
attentions += [attention]
stop_tokens += [stop_token]
t += 1
if (not greedy and self.training) or (greedy
and memory is not None):
if memory is not None:
if t >= T_decoder:
break
else:
@@ -394,7 +432,6 @@ class Decoder(nn.Module):
elif t > self.max_decoder_steps:
print(" | > Decoder stopped with 'max_decoder_steps")
break
assert greedy or len(outputs) == T_decoder
# Back to batch first
attentions = torch.stack(attentions).transpose(0, 1)
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
@@ -407,32 +444,22 @@ class StopNet(nn.Module):
Predicting stop-token in decoder.
Args:
r (int): number of output frames of the network.
memory_dim (int): feature dimension for each output frame.
in_features (int): feature dimension of input.
"""
def __init__(self, r, memory_dim):
r"""
Predicts the stop token to stop the decoder at testing time
Args:
r (int): number of network output frames.
memory_dim (int): single feature dim of a single network output frame.
"""
def __init__(self, in_features):
super(StopNet, self).__init__()
self.rnn = nn.GRUCell(memory_dim * r, memory_dim * r)
self.relu = nn.ReLU()
self.linear = nn.Linear(r * memory_dim, 1)
self.dropout = nn.Dropout(0.5)
self.linear = nn.Linear(in_features, 1)
self.sigmoid = nn.Sigmoid()
torch.nn.init.xavier_uniform_(
self.linear.weight,
gain=torch.nn.init.calculate_gain('linear'))
def forward(self, inputs, rnn_hidden):
"""
Args:
inputs: network output tensor with r x memory_dim feature dimension.
rnn_hidden: hidden state of the RNN cell.
"""
rnn_hidden = self.rnn(inputs, rnn_hidden)
outputs = self.relu(rnn_hidden)
def forward(self, inputs):
# rnn_hidden = self.rnn(inputs, rnn_hidden)
# outputs = self.relu(rnn_hidden)
outputs = self.dropout(inputs)
outputs = self.linear(outputs)
outputs = self.sigmoid(outputs)
return outputs, rnn_hidden
return outputs
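The reworked StopNet is now a plain feed-forward classifier: the decoder concatenates its RNN output (256 units) with the current mel prediction (memory_dim * r values), and StopNet maps that to a single stop probability. A shape sketch, assuming memory_dim=80 and r=1 as in the updated config:

import torch

memory_dim, r, batch = 80, 1, 32
stopnet = StopNet(256 + memory_dim * r)  # StopNet as defined above
decoder_output = torch.rand(batch, 256)
mel_output = torch.rand(batch, memory_dim * r)
stop_token = stopnet(torch.cat([decoder_output, mel_output], -1))  # batch x 1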

View File

@@ -80,11 +80,11 @@ setup(
"matplotlib==2.0.2",
"Pillow",
"flask",
"lws",
# "lws",
"tqdm",
],
extras_require={
"bin": [
"tqdm",
"requests",
],
})

tests/audio_tests.py Normal file
View File

@@ -0,0 +1,142 @@
import os
import unittest
import numpy as np
import torch as T
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import load_config
file_path = os.path.dirname(os.path.realpath(__file__))
INPUTPATH = os.path.join(file_path, 'inputs')
OUTPATH = os.path.join(file_path, "outputs/audio_tests")
os.makedirs(OUTPATH, exist_ok=True)
c = load_config(os.path.join(file_path, 'test_config.json'))
class TestAudio(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestAudio, self).__init__(*args, **kwargs)
self.ap = AudioProcessor(**c.audio)
def test_audio_synthesis(self):
""" 1. load wav
2. set normalization parameters
3. extract mel-spec
4. invert to wav and save the output
"""
print(" > Sanity check for the process wav -> mel -> wav")
def _test(max_norm, signal_norm, symmetric_norm, clip_norm):
self.ap.max_norm = max_norm
self.ap.signal_norm = signal_norm
self.ap.symmetric_norm = symmetric_norm
self.ap.clip_norm = clip_norm
wav = self.ap.load_wav(INPUTPATH + "/example_1.wav")
mel = self.ap.melspectrogram(wav)
wav_ = self.ap.inv_mel_spectrogram(mel)
file_name = "/audio_test-melspec_max_norm_{}-signal_norm_{}-symmetric_{}-clip_norm_{}.wav"\
.format(max_norm, signal_norm, symmetric_norm, clip_norm)
print(" | > Creating wav file at : ", file_name)
self.ap.save_wav(wav_, OUTPATH + file_name)
# maxnorm = 1.0
_test(1., False, False, False)
_test(1., True, False, False)
_test(1., True, True, False)
_test(1., True, False, True)
_test(1., True, True, True)
# maxnorm = 4.0
_test(4., False, False, False)
_test(4., True, False, False)
_test(4., True, True, False)
_test(4., True, False, True)
_test(4., True, True, True)
def test_normalize(self):
"""Check normalization and denormalization for range values and consistency """
print(" > Testing normalization and denormalization.")
wav = self.ap.load_wav(INPUTPATH + "/example_1.wav")
self.ap.signal_norm = False
x = self.ap.melspectrogram(wav)
x_old = x
self.ap.signal_norm = True
self.ap.symmetric_norm = False
self.ap.clip_norm = False
self.ap.max_norm = 4.0
x_norm = self.ap._normalize(x)
print(x_norm.max(), " -- ", x_norm.min())
assert (x_old - x).sum() == 0
# check value range
assert x_norm.max() <= self.ap.max_norm + 1, x_norm.max()
assert x_norm.min() >= 0 - 1, x_norm.min()
# check denorm.
x_ = self.ap._denormalize(x_norm)
assert (x - x_).sum() < 1e-3, (x - x_).mean()
self.ap.signal_norm = True
self.ap.symmetric_norm = False
self.ap.clip_norm = True
self.ap.max_norm = 4.0
x_norm = self.ap._normalize(x)
print(x_norm.max(), " -- ", x_norm.min())
assert (x_old - x).sum() == 0
# check value range
assert x_norm.max() <= self.ap.max_norm, x_norm.max()
assert x_norm.min() >= 0, x_norm.min()
# check denorm.
x_ = self.ap._denormalize(x_norm)
assert (x - x_).sum() < 1e-3, (x - x_).mean()
self.ap.signal_norm = True
self.ap.symmetric_norm = True
self.ap.clip_norm = False
self.ap.max_norm = 4.0
x_norm = self.ap._normalize(x)
print(x_norm.max(), " -- ", x_norm.min())
assert (x_old - x).sum() == 0
# check value range
assert x_norm.max() <= self.ap.max_norm + 1, x_norm.max()
assert x_norm.min() >= -self.ap.max_norm - 2, x_norm.min()
assert x_norm.min() <= 0, x_norm.min()
# check denorm.
x_ = self.ap._denormalize(x_norm)
assert (x - x_).sum() < 1e-3, (x - x_).mean()
self.ap.signal_norm = True
self.ap.symmetric_norm = True
self.ap.clip_norm = True
self.ap.max_norm = 4.0
x_norm = self.ap._normalize(x)
print(x_norm.max(), " -- ", x_norm.min())
assert (x_old - x).sum() == 0
# check value range
assert x_norm.max() <= self.ap.max_norm, x_norm.max()
assert x_norm.min() >= -self.ap.max_norm, x_norm.min()
assert x_norm.min() <= 0, x_norm.min()
# check denorm.
x_ = self.ap._denormalize(x_norm)
assert (x - x_).sum() < 1e-3, (x - x_).mean()
self.ap.signal_norm = True
self.ap.symmetric_norm = False
self.ap.max_norm = 1.0
x_norm = self.ap._normalize(x)
print(x_norm.max(), " -- ", x_norm.min())
assert (x_old - x).sum() == 0
assert x_norm.max() <= self.ap.max_norm, x_norm.max()
assert x_norm.min() >= 0, x_norm.min()
x_ = self.ap._denormalize(x_norm)
assert (x - x_).sum() < 1e-3
self.ap.signal_norm = True
self.ap.symmetric_norm = True
self.ap.max_norm = 1.0
x_norm = self.ap._normalize(x)
print(x_norm.max(), " -- ", x_norm.min())
assert (x_old - x).sum() == 0
assert x_norm.max() <= self.ap.max_norm, x_norm.max()
assert x_norm.min() >= -self.ap.max_norm, x_norm.min()
assert x_norm.min() < 0, x_norm.min()
x_ = self.ap._denormalize(x_norm)
assert (x - x_).sum() < 1e-3
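The ranges asserted above follow from a normalization of roughly this shape (a sketch of what utils/audio.py presumably implements; min_level_db and max_norm come from the config):

import numpy as np

def normalize(S, min_level_db=-100, max_norm=4.0,
              symmetric_norm=True, clip_norm=True):
    # Map a dB-scale spectrogram from [min_level_db, 0] into [0, 1] ...
    S_norm = (S - min_level_db) / -min_level_db
    if symmetric_norm:
        # ... then into [-max_norm, max_norm], optionally clipped.
        S_norm = 2 * max_norm * S_norm - max_norm
        return np.clip(S_norm, -max_norm, max_norm) if clip_norm else S_norm
    # ... or into [0, max_norm], optionally clipped.
    S_norm = max_norm * S_norm
    return np.clip(S_norm, 0, max_norm) if clip_norm else S_norm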

View File

@@ -1,50 +1,49 @@
import os
import unittest
import shutil
import numpy as np
from torch.utils.data import DataLoader
from TTS.utils.generic_utils import load_config
from TTS.utils.audio import AudioProcessor
from TTS.datasets import LJSpeech, Kusal
from TTS.datasets import TTSDataset, TTSDatasetCached, TTSDatasetMemory
from TTS.datasets.preprocess import ljspeech, tts_cache
file_path = os.path.dirname(os.path.realpath(__file__))
OUTPATH = os.path.join(file_path, "outputs/loader_tests/")
os.makedirs(OUTPATH, exist_ok=True)
c = load_config(os.path.join(file_path, 'test_config.json'))
ok_kusal = os.path.exists(c.data_path_Kusal)
ok_ljspeech = os.path.exists(c.data_path_LJSpeech)
ok_ljspeech = os.path.exists(c.data_path)
class TestLJSpeechDataset(unittest.TestCase):
class TestTTSDataset(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestLJSpeechDataset, self).__init__(*args, **kwargs)
super(TestTTSDataset, self).__init__(*args, **kwargs)
self.max_loader_iter = 4
self.ap = AudioProcessor(
sample_rate=c.sample_rate,
num_mels=c.num_mels,
min_level_db=c.min_level_db,
frame_shift_ms=c.frame_shift_ms,
frame_length_ms=c.frame_length_ms,
ref_level_db=c.ref_level_db,
num_freq=c.num_freq,
power=c.power,
preemphasis=c.preemphasis)
self.ap = AudioProcessor(**c.audio)
def _create_dataloader(self, batch_size, r, bgs):
dataset = TTSDataset.MyDataset(
c.data_path,
'metadata.csv',
r,
c.text_cleaner,
preprocessor=ljspeech,
ap=self.ap,
batch_group_size=bgs,
min_seq_len=c.min_seq_len)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
collate_fn=dataset.collate_fn,
drop_last=True,
num_workers=c.num_loader_workers)
return dataloader, dataset
def test_loader(self):
if ok_ljspeech:
dataset = LJSpeech.MyDataset(
os.path.join(c.data_path_LJSpeech),
os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
c.r,
c.text_cleaner,
ap=self.ap,
min_seq_len=c.min_seq_len)
dataloader = DataLoader(
dataset,
batch_size=2,
shuffle=True,
collate_fn=dataset.collate_fn,
drop_last=True,
num_workers=c.num_loader_workers)
dataloader, dataset = self._create_dataloader(2, c.r, 0)
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
@@ -63,29 +62,158 @@ class TestLJSpeechDataset(unittest.TestCase):
" !! Negative values in text_input: {}".format(check_count)
# TODO: more assertion here
assert linear_input.shape[0] == c.batch_size
assert linear_input.shape[2] == self.ap.num_freq
assert mel_input.shape[0] == c.batch_size
assert mel_input.shape[2] == c.num_mels
assert mel_input.shape[2] == c.audio['num_mels']
# check normalization ranges
if self.ap.symmetric_norm:
assert mel_input.max() <= self.ap.max_norm
assert mel_input.min() >= -self.ap.max_norm
assert mel_input.min() < 0
else:
assert mel_input.max() <= self.ap.max_norm
assert mel_input.min() >= 0
def test_batch_group_shuffle(self):
if ok_ljspeech:
dataset = LJSpeech.MyDataset(
os.path.join(c.data_path_LJSpeech),
os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
c.r,
c.text_cleaner,
ap=self.ap,
batch_group_size=16,
min_seq_len=c.min_seq_len)
dataloader, dataset = self._create_dataloader(2, c.r, 16)
last_length = 0
frames = dataset.items
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data[0]
text_lengths = data[1]
linear_input = data[2]
mel_input = data[3]
mel_lengths = data[4]
stop_target = data[5]
item_idx = data[6]
dataloader = DataLoader(
dataset,
batch_size=2,
shuffle=True,
collate_fn=dataset.collate_fn,
drop_last=True,
num_workers=c.num_loader_workers)
avg_length = mel_lengths.numpy().mean()
assert avg_length >= last_length
dataloader.dataset.sort_items()
assert frames[0] != dataloader.dataset.items[0]
frames = dataset.frames
def test_padding_and_spec(self):
if ok_ljspeech:
dataloader, dataset = self._create_dataloader(1, 1, 0)
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data[0]
text_lengths = data[1]
linear_input = data[2]
mel_input = data[3]
mel_lengths = data[4]
stop_target = data[5]
item_idx = data[6]
# check mel_spec consistency
wav = self.ap.load_wav(item_idx[0])
mel = self.ap.melspectrogram(wav)
mel_dl = mel_input[0].cpu().numpy()
assert (
abs(mel.T).astype("float32") - abs(mel_dl[:-1])).sum() == 0
# check mel-spec correctness
mel_spec = mel_input[0].cpu().numpy()
wav = self.ap.inv_mel_spectrogram(mel_spec.T)
self.ap.save_wav(wav, OUTPATH + '/mel_inv_dataloader.wav')
shutil.copy(item_idx[0], OUTPATH + '/mel_target_dataloader.wav')
# check linear-spec
linear_spec = linear_input[0].cpu().numpy()
wav = self.ap.inv_spectrogram(linear_spec.T)
self.ap.save_wav(wav, OUTPATH + '/linear_inv_dataloader.wav')
shutil.copy(item_idx[0], OUTPATH + '/linear_target_dataloader.wav')
# check the last time step to be zero padded
assert linear_input[0, -1].sum() == 0
assert linear_input[0, -2].sum() != 0
assert mel_input[0, -1].sum() == 0
assert mel_input[0, -2].sum() != 0
assert stop_target[0, -1] == 1
assert stop_target[0, -2] == 0
assert stop_target.sum() == 1
assert len(mel_lengths.shape) == 1
assert mel_lengths[0] == linear_input[0].shape[0]
assert mel_lengths[0] == mel_input[0].shape[0]
# Test for batch size 2
dataloader, dataset = self._create_dataloader(2, 1, 0)
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data[0]
text_lengths = data[1]
linear_input = data[2]
mel_input = data[3]
mel_lengths = data[4]
stop_target = data[5]
item_idx = data[6]
if mel_lengths[0] > mel_lengths[1]:
idx = 0
else:
idx = 1
# check the first item in the batch
assert linear_input[idx, -1].sum() == 0
assert linear_input[idx, -2].sum() != 0, linear_input
assert mel_input[idx, -1].sum() == 0
assert mel_input[idx, -2].sum() != 0, mel_input
assert stop_target[idx, -1] == 1
assert stop_target[idx, -2] == 0
assert stop_target[idx].sum() == 1
assert len(mel_lengths.shape) == 1
assert mel_lengths[idx] == mel_input[idx].shape[0]
assert mel_lengths[idx] == linear_input[idx].shape[0]
# check the second item in the batch
assert linear_input[1 - idx, -1].sum() == 0
assert mel_input[1 - idx, -1].sum() == 0
assert stop_target[1 - idx, -1] == 1
assert len(mel_lengths.shape) == 1
# check batch conditions
assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
class TestTTSDatasetCached(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestTTSDatasetCached, self).__init__(*args, **kwargs)
self.max_loader_iter = 4
self.c = load_config(os.path.join(c.data_path_cache, 'config.json'))
self.ap = AudioProcessor(**self.c.audio)
def _create_dataloader(self, batch_size, r, bgs):
dataset = TTSDatasetCached.MyDataset(
c.data_path_cache,
'tts_metadata.csv',
r,
c.text_cleaner,
preprocessor=tts_cache,
ap=self.ap,
batch_group_size=bgs,
min_seq_len=c.min_seq_len)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
collate_fn=dataset.collate_fn,
drop_last=True,
num_workers=c.num_loader_workers)
return dataloader, dataset
def test_loader(self):
if ok_ljspeech:
dataloader, dataset = self._create_dataloader(2, c.r, 0)
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
@@ -102,32 +230,21 @@ class TestLJSpeechDataset(unittest.TestCase):
assert check_count == 0, \
" !! Negative values in text_input: {}".format(check_count)
# TODO: more assertion here
assert linear_input.shape[0] == c.batch_size
assert mel_input.shape[0] == c.batch_size
assert mel_input.shape[2] == c.num_mels
dataloader.dataset.sort_frames()
assert frames[0] != dataloader.dataset.frames[0]
assert mel_input.shape[2] == c.audio['num_mels']
if self.ap.symmetric_norm:
assert mel_input.max() <= self.ap.max_norm
assert mel_input.min() >= -self.ap.max_norm
assert mel_input.min() < 0
else:
assert mel_input.max() <= self.ap.max_norm
assert mel_input.min() >= 0
def test_padding(self):
def test_batch_group_shuffle(self):
if ok_ljspeech:
dataset = LJSpeech.MyDataset(
os.path.join(c.data_path_LJSpeech),
os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
1,
c.text_cleaner,
ap=self.ap,
min_seq_len=c.min_seq_len)
# Test for batch size 1
dataloader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
collate_fn=dataset.collate_fn,
drop_last=True,
num_workers=c.num_loader_workers)
dataloader, dataset = self._create_dataloader(2, c.r, 16)
frames = dataset.items
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
@@ -139,11 +256,51 @@ class TestLJSpeechDataset(unittest.TestCase):
stop_target = data[5]
item_idx = data[6]
neg_values = text_input[text_input < 0]
check_count = len(neg_values)
assert check_count == 0, \
" !! Negative values in text_input: {}".format(check_count)
# TODO: more assertion here
assert mel_input.shape[0] == c.batch_size
assert mel_input.shape[2] == c.audio['num_mels']
dataloader.dataset.sort_items()
assert frames[0] != dataloader.dataset.items[0]
def test_padding_and_spec(self):
if ok_ljspeech:
dataloader, dataset = self._create_dataloader(1, 1, 0)
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data[0]
text_lengths = data[1]
linear_input = data[2]
mel_input = data[3]
mel_lengths = data[4]
stop_target = data[5]
item_idx = data[6]
# check mel_spec consistency
if item_idx[0].split('.')[-1] == 'npy':
wav = np.load(item_idx[0])
else:
wav = self.ap.load_wav(item_idx[0])
mel = self.ap.melspectrogram(wav)
mel_dl = mel_input[0].cpu().numpy()
assert (abs(mel.T).astype("float32") - abs(
mel_dl[:-1])).sum() == 0, (
abs(mel.T).astype("float32") - abs(mel_dl[:-1])).sum()
# check mel-spec correctness
mel_spec = mel_input[0].cpu().numpy()
wav = self.ap.inv_mel_spectrogram(mel_spec.T)
self.ap.save_wav(wav,
OUTPATH + '/mel_inv_dataloader_cache.wav')
shutil.copy(item_idx[0], OUTPATH + '/mel_target_dataloader_cache.wav')
# check the last time step to be zero padded
assert mel_input[0, -1].sum() == 0
assert mel_input[0, -2].sum() != 0
assert linear_input[0, -1].sum() == 0
assert linear_input[0, -2].sum() != 0
assert stop_target[0, -1] == 1
assert stop_target[0, -2] == 0
assert stop_target.sum() == 1
@@ -151,14 +308,7 @@ class TestLJSpeechDataset(unittest.TestCase):
assert mel_lengths[0] == mel_input[0].shape[0]
# Test for batch size 2
dataloader = DataLoader(
dataset,
batch_size=2,
shuffle=False,
collate_fn=dataset.collate_fn,
drop_last=False,
num_workers=c.num_loader_workers)
dataloader, dataset = self._create_dataloader(2, 1, 0)
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
@@ -178,8 +328,6 @@ class TestLJSpeechDataset(unittest.TestCase):
# check the first item in the batch
assert mel_input[idx, -1].sum() == 0
assert mel_input[idx, -2].sum() != 0, mel_input
assert linear_input[idx, -1].sum() == 0
assert linear_input[idx, -2].sum() != 0
assert stop_target[idx, -1] == 1
assert stop_target[idx, -2] == 0
assert stop_target[idx].sum() == 1
@@ -188,151 +336,202 @@ class TestLJSpeechDataset(unittest.TestCase):
# check the second item in the batch
assert mel_input[1 - idx, -1].sum() == 0
assert linear_input[1 - idx, -1].sum() == 0
assert stop_target[1 - idx, -1] == 1
assert len(mel_lengths.shape) == 1
# check batch conditions
assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
class TestKusalDataset(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestKusalDataset, self).__init__(*args, **kwargs)
self.max_loader_iter = 4
self.ap = AudioProcessor(
sample_rate=c.sample_rate,
num_mels=c.num_mels,
min_level_db=c.min_level_db,
frame_shift_ms=c.frame_shift_ms,
frame_length_ms=c.frame_length_ms,
ref_level_db=c.ref_level_db,
num_freq=c.num_freq,
power=c.power,
preemphasis=c.preemphasis)
# class TestTTSDatasetMemory(unittest.TestCase):
# def __init__(self, *args, **kwargs):
# super(TestTTSDatasetMemory, self).__init__(*args, **kwargs)
# self.max_loader_iter = 4
# self.c = load_config(os.path.join(c.data_path_cache, 'config.json'))
# self.ap = AudioProcessor(**c.audio)
def test_loader(self):
if ok_kusal:
dataset = Kusal.MyDataset(
os.path.join(c.data_path_Kusal),
os.path.join(c.data_path_Kusal, 'prompts.txt'),
c.r,
c.text_cleaner,
ap=self.ap,
min_seq_len=c.min_seq_len)
# def test_loader(self):
# if ok_ljspeech:
# dataset = TTSDatasetMemory.MyDataset(
# c.data_path_cache,
# 'tts_metadata.csv',
# c.r,
# c.text_cleaner,
# preprocessor=tts_cache,
# ap=self.ap,
# min_seq_len=c.min_seq_len)
dataloader = DataLoader(
dataset,
batch_size=2,
shuffle=True,
collate_fn=dataset.collate_fn,
drop_last=True,
num_workers=c.num_loader_workers)
# dataloader = DataLoader(
# dataset,
# batch_size=2,
# shuffle=True,
# collate_fn=dataset.collate_fn,
# drop_last=True,
# num_workers=c.num_loader_workers)
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data[0]
text_lengths = data[1]
linear_input = data[2]
mel_input = data[3]
mel_lengths = data[4]
stop_target = data[5]
item_idx = data[6]
# for i, data in enumerate(dataloader):
# if i == self.max_loader_iter:
# break
# text_input = data[0]
# text_lengths = data[1]
# linear_input = data[2]
# mel_input = data[3]
# mel_lengths = data[4]
# stop_target = data[5]
# item_idx = data[6]
neg_values = text_input[text_input < 0]
check_count = len(neg_values)
assert check_count == 0, \
" !! Negative values in text_input: {}".format(check_count)
# TODO: more assertion here
assert linear_input.shape[0] == c.batch_size
assert mel_input.shape[0] == c.batch_size
assert mel_input.shape[2] == c.num_mels
# neg_values = text_input[text_input < 0]
# check_count = len(neg_values)
# assert check_count == 0, \
# " !! Negative values in text_input: {}".format(check_count)
# # check mel-spec shape
# assert mel_input.shape[0] == c.batch_size
# assert mel_input.shape[2] == c.audio['num_mels']
# assert mel_input.max() <= self.ap.max_norm
# # check data range
# if self.ap.symmetric_norm:
# assert mel_input.max() <= self.ap.max_norm
# assert mel_input.min() >= -self.ap.max_norm
# assert mel_input.min() < 0
# else:
# assert mel_input.max() <= self.ap.max_norm
# assert mel_input.min() >= 0
def test_padding(self):
if ok_kusal:
dataset = Kusal.MyDataset(
os.path.join(c.data_path_Kusal),
os.path.join(c.data_path_Kusal, 'prompts.txt'),
1,
c.text_cleaner,
ap=self.ap,
min_seq_len=c.min_seq_len)
# def test_batch_group_shuffle(self):
# if ok_ljspeech:
# dataset = TTSDatasetMemory.MyDataset(
# c.data_path_cache,
# 'tts_metadata.csv',
# c.r,
# c.text_cleaner,
# preprocessor=ljspeech,
# ap=self.ap,
# batch_group_size=16,
# min_seq_len=c.min_seq_len)
# Test for batch size 1
dataloader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
collate_fn=dataset.collate_fn,
drop_last=True,
num_workers=c.num_loader_workers)
# dataloader = DataLoader(
# dataset,
# batch_size=2,
# shuffle=True,
# collate_fn=dataset.collate_fn,
# drop_last=True,
# num_workers=c.num_loader_workers)
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data[0]
text_lengths = data[1]
linear_input = data[2]
mel_input = data[3]
mel_lengths = data[4]
stop_target = data[5]
item_idx = data[6]
# frames = dataset.items
# for i, data in enumerate(dataloader):
# if i == self.max_loader_iter:
# break
# text_input = data[0]
# text_lengths = data[1]
# linear_input = data[2]
# mel_input = data[3]
# mel_lengths = data[4]
# stop_target = data[5]
# item_idx = data[6]
# check the last time step to be zero padded
assert mel_input[0, -1].sum() == 0
# assert mel_input[0, -2].sum() != 0
assert linear_input[0, -1].sum() == 0
# assert linear_input[0, -2].sum() != 0
assert stop_target[0, -1] == 1
assert stop_target[0, -2] == 0
assert stop_target.sum() == 1
assert len(mel_lengths.shape) == 1
assert mel_lengths[0] == mel_input[0].shape[0]
# neg_values = text_input[text_input < 0]
# check_count = len(neg_values)
# assert check_count == 0, \
# " !! Negative values in text_input: {}".format(check_count)
# assert mel_input.shape[0] == c.batch_size
# assert mel_input.shape[2] == c.audio['num_mels']
# dataloader.dataset.sort_items()
# assert frames[0] != dataloader.dataset.items[0]
# Test for batch size 2
dataloader = DataLoader(
dataset,
batch_size=2,
shuffle=False,
collate_fn=dataset.collate_fn,
drop_last=False,
num_workers=c.num_loader_workers)
# def test_padding_and_spec(self):
# if ok_ljspeech:
# dataset = TTSDatasetMemory.MyDataset(
# c.data_path_cache,
# 'tts_meta_data.csv',
# 1,
# c.text_cleaner,
# preprocessor=ljspeech,
# ap=self.ap,
# min_seq_len=c.min_seq_len)
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data[0]
text_lengths = data[1]
linear_input = data[2]
mel_input = data[3]
mel_lengths = data[4]
stop_target = data[5]
item_idx = data[6]
# # Test for batch size 1
# dataloader = DataLoader(
# dataset,
# batch_size=1,
# shuffle=False,
# collate_fn=dataset.collate_fn,
# drop_last=True,
# num_workers=c.num_loader_workers)
if mel_lengths[0] > mel_lengths[1]:
idx = 0
else:
idx = 1
# for i, data in enumerate(dataloader):
# if i == self.max_loader_iter:
# break
# text_input = data[0]
# text_lengths = data[1]
# linear_input = data[2]
# mel_input = data[3]
# mel_lengths = data[4]
# stop_target = data[5]
# item_idx = data[6]
# check the first item in the batch
assert mel_input[idx, -1].sum() == 0
assert mel_input[idx, -2].sum() != 0, mel_input
assert linear_input[idx, -1].sum() == 0
assert linear_input[idx, -2].sum() != 0
assert stop_target[idx, -1] == 1
assert stop_target[idx, -2] == 0
assert stop_target[idx].sum() == 1
assert len(mel_lengths.shape) == 1
assert mel_lengths[idx] == mel_input[idx].shape[0]
# # check mel_spec consistency
# if item_idx[0].split('.')[-1] == 'npy':
# wav = np.load(item_idx[0])
# else:
# wav = self.ap.load_wav(item_idx[0])
# mel = self.ap.melspectrogram(wav)
# mel_dl = mel_input[0].cpu().numpy()
# assert (
# abs(mel.T).astype("float32") - abs(mel_dl[:-1])).sum() == 0
# # check the second item in the batch
assert mel_input[1 - idx, -1].sum() == 0
assert linear_input[1 - idx, -1].sum() == 0
assert stop_target[1 - idx, -1] == 1
assert len(mel_lengths.shape) == 1
# # check mel-spec correctness
# mel_spec = mel_input[0].cpu().numpy()
# wav = self.ap.inv_mel_spectrogram(mel_spec.T)
# self.ap.save_wav(wav, OUTPATH + '/mel_inv_dataloader_memo.wav')
# shutil.copy(item_idx[0], OUTPATH + '/mel_target_dataloader_memo.wav')
# check batch conditions
assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
# # check the last time step to be zero padded
# assert mel_input[0, -1].sum() == 0
# assert mel_input[0, -2].sum() != 0
# assert stop_target[0, -1] == 1
# assert stop_target[0, -2] == 0
# assert stop_target.sum() == 1
# assert len(mel_lengths.shape) == 1
# assert mel_lengths[0] == mel_input[0].shape[0]
# # Test for batch size 2
# dataloader = DataLoader(
# dataset,
# batch_size=2,
# shuffle=False,
# collate_fn=dataset.collate_fn,
# drop_last=False,
# num_workers=c.num_loader_workers)
# for i, data in enumerate(dataloader):
# if i == self.max_loader_iter:
# break
# text_input = data[0]
# text_lengths = data[1]
# linear_input = data[2]
# mel_input = data[3]
# mel_lengths = data[4]
# stop_target = data[5]
# item_idx = data[6]
# if mel_lengths[0] > mel_lengths[1]:
# idx = 0
# else:
# idx = 1
# # check the first item in the batch
# assert mel_input[idx, -1].sum() == 0
# assert mel_input[idx, -2].sum() != 0, mel_input
# assert stop_target[idx, -1] == 1
# assert stop_target[idx, -2] == 0
# assert stop_target[idx].sum() == 1
# assert len(mel_lengths.shape) == 1
# assert mel_lengths[idx] == mel_input[idx].shape[0]
# # check the second item in the batch
# assert mel_input[1 - idx, -1].sum() == 0
# assert stop_target[1 - idx, -1] == 1
# assert len(mel_lengths.shape) == 1
# # check batch conditions
# assert (mel_input * stop_target.unsqueeze(2)).sum() == 0

View File

@@ -21,8 +21,8 @@ c = load_config(os.path.join(file_path, 'test_config.json'))
class TacotronTrainTest(unittest.TestCase):
def test_train_step(self):
input = torch.randint(0, 24, (8, 128)).long().to(device)
mel_spec = torch.rand(8, 30, c.num_mels).to(device)
linear_spec = torch.rand(8, 30, c.num_freq).to(device)
mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
linear_spec = torch.rand(8, 30, c.audio['num_freq']).to(device)
mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
stop_targets = torch.zeros(8, 30, 1).float().to(device)
@ -35,7 +35,7 @@ class TacotronTrainTest(unittest.TestCase):
criterion = L1LossMasked().to(device)
criterion_st = nn.BCELoss().to(device)
model = Tacotron(c.embedding_size, c.num_freq, c.num_mels,
model = Tacotron(c.embedding_size, c.audio['num_freq'], c.audio['num_mels'],
c.r).to(device)
model.train()
model_ref = copy.deepcopy(model)

View File

@ -1,16 +1,25 @@
{
"num_mels": 80,
"num_freq": 1025,
"sample_rate": 22050,
"frame_length_ms": 50,
"frame_shift_ms": 12.5,
"preemphasis": 0.97,
"min_level_db": -100,
"ref_level_db": 20,
"audio":{
"audio_processor": "audio", // to use dictate different audio processors, if available.
"num_mels": 80, // size of the mel spec frame.
"num_freq": 1025, // number of stft frequency levels. Size of the linear spectogram frame.
"sample_rate": 22050, // wav sample-rate. If different than the original data, it is resampled.
"frame_length_ms": 50, // stft window length in ms.
"frame_shift_ms": 12.5, // stft window hop-lengh in ms.
"preemphasis": 0.97, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
"min_level_db": -100, // normalization range
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
"power": 1.5, // value to sharpen wav signals after GL algorithm.
"griffin_lim_iters": 30,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
"signal_norm": true, // normalize the spec values in range [0, 1]
"symmetric_norm": true, // move normalization to range [-1, 1]
"clip_norm": true, // clip normalized values into the range.
"max_norm": 4, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
"mel_fmin": 95, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
"mel_fmax": 7600 // maximum freq level for mel-spec. Tune for dataset!!
},
"hidden_size": 128,
"embedding_size": 256,
"min_mel_freq": null,
"max_mel_freq": null,
"text_cleaner": "english_cleaners",
"epochs": 2000,
@ -21,16 +30,11 @@
"r": 5,
"mk": 1.0,
"priority_freq": false,
"griffin_lim_iters": 60,
"power": 1.5,
"num_loader_workers": 4,
"save_step": 200,
"data_path_LJSpeech": "/home/erogol/Data/LJSpeech-1.1",
"data_path_Kusal": "/home/erogol/Data/Kusal",
"data_path": "/home/erogol/Data/LJSpeech-1.1/",
"data_path_cache": "/home/erogol/Data/LJSpeech-1.1/tts_cache/",
"output_path": "result",
"min_seq_len": 0,
"log_dir": "/home/erogol/projects/TTS/logs/"

View File

@ -14,16 +14,16 @@ from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from utils.generic_utils import (
synthesis, remove_experiment_folder, create_experiment_folder,
remove_experiment_folder, create_experiment_folder,
save_checkpoint, save_best_model, load_config, lr_decay, count_parameters,
check_update, get_commit_hash, sequence_mask, AnnealLR)
from utils.visual import plot_alignment, plot_spectrogram
from models.tacotron import Tacotron
from layers.losses import L1LossMasked
from utils.audio import AudioProcessor
from utils.synthesis import synthesis
torch.manual_seed(1)
# torch.set_num_threads(4)
use_cuda = torch.cuda.is_available()
@ -36,7 +36,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
avg_stop_loss = 0
avg_step_time = 0
print(" | > Epoch {}/{}".format(epoch, c.epochs), flush=True)
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
n_priority_freq = int(3000 / (c.audio['sample_rate'] * 0.5) * c.audio['num_freq'])
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
for num_iter, data in enumerate(data_loader):
start_time = time.time()
@ -215,7 +215,7 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist."
]
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
n_priority_freq = int(3000 / (c.audio['sample_rate'] * 0.5) * c.audio['num_freq'])
with torch.no_grad():
if data_loader is not None:
for num_iter, data in enumerate(data_loader):
@ -277,6 +277,7 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
const_spec = linear_output[idx].data.cpu().numpy()
gt_spec = linear_input[idx].data.cpu().numpy()
align_img = alignments[idx].data.cpu().numpy()
const_spec = plot_spectrogram(const_spec, ap)
gt_spec = plot_spectrogram(gt_spec, ap)
align_img = plot_alignment(align_img)
@ -319,8 +320,8 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
ap.griffin_lim_iters = 60
for idx, test_sentence in enumerate(test_sentences):
try:
wav, linear_spec, alignments = synthesis(model, ap, test_sentence,
use_cuda, c.text_cleaner)
wav, alignment, linear_spec, stop_tokens = synthesis(model, test_sentence, c,
use_cuda, ap)
file_path = os.path.join(AUDIO_PATH, str(current_step))
os.makedirs(file_path, exist_ok=True)
@ -330,7 +331,7 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
wav_name = 'TestSentences/{}'.format(idx)
tb.add_audio(
wav_name, wav, current_step, sample_rate=c.sample_rate)
wav_name, wav, current_step, sample_rate=c.audio['sample_rate'])
align_img = alignments[0].data.cpu().numpy()
linear_spec = plot_spectrogram(linear_spec, ap)
align_img = plot_alignment(align_img)
@ -345,28 +346,22 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
def main(args):
dataset = importlib.import_module('datasets.' + c.dataset)
Dataset = getattr(dataset, 'MyDataset')
audio = importlib.import_module('utils.' + c.audio_processor)
preprocessor = importlib.import_module('datasets.preprocess')
preprocessor = getattr(preprocessor, c.dataset.lower())
MyDataset = importlib.import_module('datasets.'+c.data_loader)
MyDataset = getattr(MyDataset, "MyDataset")
audio = importlib.import_module('utils.' + c.audio['audio_processor'])
AudioProcessor = getattr(audio, 'AudioProcessor')
ap = AudioProcessor(
sample_rate=c.sample_rate,
num_mels=c.num_mels,
min_level_db=c.min_level_db,
frame_shift_ms=c.frame_shift_ms,
frame_length_ms=c.frame_length_ms,
ref_level_db=c.ref_level_db,
num_freq=c.num_freq,
power=c.power,
preemphasis=c.preemphasis)
ap = AudioProcessor(**c.audio)
# Setup the dataset
train_dataset = Dataset(
train_dataset = MyDataset(
c.data_path,
c.meta_file_train,
c.r,
c.text_cleaner,
preprocessor=preprocessor,
ap=ap,
batch_group_size=8*c.batch_size,
min_seq_len=c.min_seq_len)
@ -381,8 +376,8 @@ def main(args):
pin_memory=True)
if c.run_eval:
val_dataset = Dataset(
c.data_path, c.meta_file_val, c.r, c.text_cleaner, ap=ap, batch_group_size=0)
val_dataset = MyDataset(
c.data_path, c.meta_file_val, c.r, c.text_cleaner, preprocessor=preprocessor, ap=ap, batch_group_size=0)
val_loader = DataLoader(
val_dataset,
@ -395,7 +390,7 @@ def main(args):
else:
val_loader = None
model = Tacotron(c.embedding_size, ap.num_freq, c.num_mels, c.r)
model = Tacotron(c.embedding_size, ap.num_freq, ap.num_mels, c.r)
print(" | > Num output units : {}".format(ap.num_freq), flush=True)
optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0)
@ -431,7 +426,7 @@ def main(args):
criterion.cuda()
criterion_st.cuda()
scheduler = AnnealLR(optimizer, warmup_steps=c.warmup_steps, last_epoch= (args.restore_step-1))
scheduler = AnnealLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1)
num_params = count_parameters(model)
print(" | > Model has {} parameters".format(num_params), flush=True)
@ -454,7 +449,7 @@ def main(args):
best_loss = save_best_model(model, optimizer, train_loss, best_loss,
OUT_PATH, current_step, epoch)
# shuffle batch groups
train_loader.dataset.sort_frames()
train_loader.dataset.sort_items()
if __name__ == '__main__':
@ -477,8 +472,9 @@ if __name__ == '__main__':
parser.add_argument(
'--data_path',
type=str,
default='',
help='data path to overrite config.json')
help='dataset path.',
default=''
)
args = parser.parse_args()
# setup output paths and read configs
@ -491,7 +487,7 @@ if __name__ == '__main__':
os.makedirs(AUDIO_PATH, exist_ok=True)
shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
if args.data_path != "":
if args.data_path != '':
c.data_path = args.data_path
# setup tensorboard

View File

@ -3,26 +3,34 @@ import librosa
import pickle
import copy
import numpy as np
import scipy
from scipy import signal
_mel_basis = None
from pprint import pprint
from scipy import signal, io
class AudioProcessor(object):
def __init__(self,
sample_rate,
num_mels,
min_level_db,
frame_shift_ms,
frame_length_ms,
ref_level_db,
num_freq,
power,
preemphasis,
griffin_lim_iters=None):
bits=None,
sample_rate=None,
num_mels=None,
min_level_db=None,
frame_shift_ms=None,
frame_length_ms=None,
ref_level_db=None,
num_freq=None,
power=None,
preemphasis=None,
signal_norm=None,
symmetric_norm=None,
max_norm=None,
mel_fmin=None,
mel_fmax=None,
clip_norm=True,
griffin_lim_iters=None,
**kwargs):
print(" > Setting up Audio Processor...")
self.bits = bits
self.sample_rate = sample_rate
self.num_mels = num_mels
self.min_level_db = min_level_db
@ -33,33 +41,79 @@ class AudioProcessor(object):
self.power = power
self.preemphasis = preemphasis
self.griffin_lim_iters = griffin_lim_iters
self.signal_norm = signal_norm
self.symmetric_norm = symmetric_norm
self.mel_fmin = 0 if mel_fmin is None else mel_fmin
self.mel_fmax = mel_fmax
self.max_norm = 1.0 if max_norm is None else float(max_norm)
self.clip_norm = clip_norm
self.n_fft, self.hop_length, self.win_length = self._stft_parameters()
if preemphasis == 0:
print(" | > Preemphasis is deactive.")
print(" | > Audio Processor attributes.")
members = vars(self)
pprint(members)
def save_wav(self, wav, path):
wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
# librosa.output.write_wav(path, wav_norm.astype(np.int16), self.sample_rate)
scipy.io.wavfile.write(path, self.sample_rate, wav_norm.astype(np.int16))
io.wavfile.write(path, self.sample_rate, wav_norm.astype(np.int16))
def _linear_to_mel(self, spectrogram):
global _mel_basis
if _mel_basis is None:
_mel_basis = self._build_mel_basis()
_mel_basis = self._build_mel_basis()
return np.dot(_mel_basis, spectrogram)
def _mel_to_linear(self, mel_spec):
inv_mel_basis = np.linalg.pinv(self._build_mel_basis())
return np.maximum(1e-10, np.dot(inv_mel_basis, mel_spec))
def _build_mel_basis(self, ):
n_fft = (self.num_freq - 1) * 2
if self.mel_fmax is not None:
assert self.mel_fmax <= self.sample_rate // 2
return librosa.filters.mel(
self.sample_rate, n_fft, n_mels=self.num_mels)
self.sample_rate,
n_fft,
n_mels=self.num_mels,
fmin=self.mel_fmin,
fmax=self.mel_fmax)
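For orientation, the resulting filter bank has shape (num_mels, num_freq), so _linear_to_mel is a single matrix product; a quick check with the values from the configs in this commit:

import librosa
basis = librosa.filters.mel(22050, 2048, n_mels=80, fmin=0, fmax=None)
print(basis.shape)  # (80, 1025): one row of triangular filter weights per mel band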
def _normalize(self, S):
return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
"""Put values in [0, self.max_norm] or [-self.max_norm, self.max_norm]"""
if self.signal_norm:
S_norm = ((S - self.min_level_db) / - self.min_level_db)
if self.symmetric_norm:
S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm
if self.clip_norm:
S_norm = np.clip(S_norm, -self.max_norm, self.max_norm)
return S_norm
else:
S_norm = self.max_norm * S_norm
if self.clip_norm:
S_norm = np.clip(S_norm, 0, self.max_norm)
return S_norm
else:
return S
def _denormalize(self, S):
return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db
"""denormalize values"""
S_denorm = S
if self.signal_norm:
if self.symmetric_norm:
if self.clip_norm:
S_denorm = np.clip(S_denorm, -self.max_norm, self.max_norm)
S_denorm = ((S_denorm + self.max_norm) * -self.min_level_db / (2 * self.max_norm)) + self.min_level_db
return S_denorm
else:
if self.clip_norm:
S_denorm = np.clip(S_denorm, 0, self.max_norm)
S_denorm = (S_denorm * -self.min_level_db /
self.max_norm) + self.min_level_db
return S_denorm
else:
return S
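As a sanity check on the symmetric branch, here is the round trip worked by hand with the sample values max_norm=4 and min_level_db=-100 from the test config:

S = -30.0                                      # a dB-scale spectrogram value
S_norm = (S - (-100)) / 100                    # 0.7, squashed into [0, 1]
S_norm = (2 * 4) * S_norm - 4                  # 1.6, rescaled into [-4, 4]
S_back = ((S_norm + 4) * 100 / (2 * 4)) - 100  # -30.0, recovered exactly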
def _stft_parameters(self, ):
"""Compute necessary stft parameters with given time values"""
n_fft = (self.num_freq - 1) * 2
hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate)
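With the audio settings used throughout this commit (sample_rate 22050, 50 ms windows, 12.5 ms shift, num_freq 1025), these formulas resolve to:

n_fft = (1025 - 1) * 2                   # 2048
hop_length = int(12.5 / 1000.0 * 22050)  # 275 samples
win_length = int(50 / 1000.0 * 22050)    # 1102 samples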
@ -92,8 +146,16 @@ class AudioProcessor(object):
S = self._amp_to_db(np.abs(D)) - self.ref_level_db
return self._normalize(S)
def melspectrogram(self, y):
if self.preemphasis != 0:
D = self._stft(self.apply_preemphasis(y))
else:
D = self._stft(y)
S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
return self._normalize(S)
def inv_spectrogram(self, spectrogram):
'''Converts spectrogram to waveform using librosa'''
"""Converts spectrogram to waveform using librosa"""
S = self._denormalize(spectrogram)
S = self._db_to_amp(S + self.ref_level_db) # Convert back to linear
# Reconstruct phase
@ -102,6 +164,16 @@ class AudioProcessor(object):
else:
return self._griffin_lim(S**self.power)
def inv_mel_spectrogram(self, mel_spectrogram):
"""Converts mel spectrogram to waveform using librosa"""
D = self._denormalize(mel_spectrogram)
S = self._db_to_amp(D + self.ref_level_db)
S = self._mel_to_linear(S) # Convert back to linear
if self.preemphasis != 0:
return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
else:
return self._griffin_lim(S**self.power)
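A minimal round-trip sketch, mirroring the commented-out dataloader check above (ap and wav assumed to be set up as elsewhere in this commit):

mel = ap.melspectrogram(wav)           # (num_mels, frames), normalized dB scale
wav_rec = ap.inv_mel_spectrogram(mel)  # phase re-estimated by Griffin-Lim
ap.save_wav(wav_rec, 'mel_inv.wav')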
def _griffin_lim(self, S):
angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
S_complex = np.abs(S).astype(np.complex)
@ -111,20 +183,17 @@ class AudioProcessor(object):
y = self._istft(S_complex * angles)
return y
def melspectrogram(self, y):
if self.preemphasis != 0:
D = self._stft(self.apply_preemphasis(y))
else:
D = self._stft(y)
S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
return self._normalize(S)
def _stft(self, y):
return librosa.stft(
y=y, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length)
y=y,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
)
def _istft(self, y):
return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length)
return librosa.istft(
y, hop_length=self.hop_length, win_length=self.win_length)
def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
window_length = int(self.sample_rate * min_silence_sec)
@ -134,3 +203,37 @@ class AudioProcessor(object):
if np.max(wav[x:x + window_length]) < threshold:
return x + hop_length
return len(wav)
# WaveRNN repo specific functions
# def mulaw_encode(self, wav, qc):
# mu = qc - 1
# wav_abs = np.minimum(np.abs(wav), 1.0)
# magnitude = np.log(1 + mu * wav_abs) / np.log(1. + mu)
# signal = np.sign(wav) * magnitude
# # Quantize signal to the specified number of levels.
# signal = (signal + 1) / 2 * mu + 0.5
# return signal.astype(np.int32)
# def mulaw_decode(self, wav, qc):
# """Recovers waveform from quantized values."""
# mu = qc - 1
# # Map values back to [-1, 1].
# casted = wav.astype(np.float32)
# signal = 2 * (casted / mu) - 1
# # Perform inverse of mu-law transformation.
# magnitude = (1 / mu) * ((1 + mu) ** abs(signal) - 1)
# return np.sign(signal) * magnitude
def load_wav(self, filename, encode=False):
x, sr = librosa.load(filename, sr=self.sample_rate)
assert self.sample_rate == sr
return x
def encode_16bits(self, x):
return np.clip(x * 2**15, -2**15, 2**15 - 1).astype(np.int16)
def quantize(self, x):
return (x + 1.) * (2**self.bits - 1) / 2
def dequantize(self, x):
return 2 * x / (2**self.bits - 1) - 1
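These helpers are plain linear quantizers, unlike the mu-law pair above; e.g. with bits=9 (an assumed WaveRNN-style setting), a sample in [-1, 1] maps into [0, 511] and back:

x = 0.5
q = (x + 1.) * (2**9 - 1) / 2   # 383.25; an int cast would follow in practice
x_rec = 2 * q / (2**9 - 1) - 1  # 0.5, recovered exactly before rounding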

View File

@ -1,4 +1,5 @@
import os
import re
import sys
import glob
import time
@ -21,18 +22,23 @@ class AttrDict(dict):
def load_config(config_path):
config = AttrDict()
config.update(json.load(open(config_path, "r")))
with open(config_path, "r") as f:
input_str = f.read()
input_str = re.sub(r'\\\n', '', input_str)
input_str = re.sub(r'//.*\n', '\n', input_str)
data = json.loads(input_str)
config.update(data)
return config
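The added regexes make the //-commented JSON configs in this commit loadable with the standard json module; a usage sketch (the path is hypothetical):

c = load_config('config.json')  # strips // comments, then json.loads
print(c.batch_size)             # top-level keys readable as attributes
print(c.audio['sample_rate'])   # nested sections stay plain dicts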
def get_commit_hash():
"""https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script"""
try:
subprocess.check_output(['git', 'diff-index', '--quiet',
'HEAD']) # Verify client is clean
except:
raise RuntimeError(
" !! Commit before training to get the commit hash.")
# try:
# subprocess.check_output(['git', 'diff-index', '--quiet',
# 'HEAD']) # Verify client is clean
# except:
# raise RuntimeError(
# " !! Commit before training to get the commit hash.")
commit = subprocess.check_output(['git', 'rev-parse', '--short',
'HEAD']).decode().strip()
print(' > Git Hash: {}'.format(commit))
@ -177,15 +183,3 @@ def sequence_mask(sequence_length, max_len=None):
seq_length_expand = (sequence_length.unsqueeze(1)
.expand_as(seq_range_expand))
return seq_range_expand < seq_length_expand
def synthesis(model, ap, text, use_cuda, text_cleaner):
text_cleaner = [text_cleaner]
seq = np.array(text_to_sequence(text, text_cleaner))
chars_var = torch.from_numpy(seq).unsqueeze(0)
if use_cuda:
chars_var = chars_var.cuda().long()
_, linear_out, alignments, _ = model.forward(chars_var)
linear_out = linear_out[0].data.cpu().numpy()
wav = ap.inv_spectrogram(linear_out.T)
return wav, linear_out, alignments

utils/synthesis.py Normal file
View File

@ -0,0 +1,23 @@
import io
import time
import librosa
import torch
import numpy as np
from .text import text_to_sequence
from .visual import visualize
from matplotlib import pylab as plt
def synthesis(m, s, CONFIG, use_cuda, ap):
""" Given the text, synthesising the audio """
text_cleaner = [CONFIG.text_cleaner]
seq = np.array(text_to_sequence(s, text_cleaner))
chars_var = torch.from_numpy(seq).unsqueeze(0)
if use_cuda:
chars_var = chars_var.cuda()
mel_spec, linear_spec, alignments, stop_tokens = m.forward(chars_var.long())
linear_spec = linear_spec[0].data.cpu().numpy()
alignment = alignments[0].cpu().data.numpy()
wav = ap.inv_spectrogram(linear_spec.T)
# wav = wav[:ap.find_endpoint(wav)]
return wav, alignment, linear_spec, stop_tokens
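A usage sketch, with model, c, use_cuda and ap set up as in train.py:

wav, alignment, linear_spec, stop_tokens = synthesis(
    model, "This is a test sentence.", c, use_cuda, ap)
ap.save_wav(wav, 'test_sentence.wav')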

View File

@ -1,4 +1,5 @@
import numpy as np
import librosa
import librosa.display
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
@ -14,6 +15,7 @@ def plot_alignment(alignment, info=None):
xlabel += '\n\n' + info
plt.xlabel(xlabel)
plt.ylabel('Encoder timestep')
# plt.yticks(range(len(text)), list(text))
plt.tight_layout()
return fig
@ -25,3 +27,28 @@ def plot_spectrogram(linear_output, audio):
plt.colorbar()
plt.tight_layout()
return fig
def visualize(alignment, spectrogram, stop_tokens, text, hop_length, CONFIG):
label_fontsize = 16
plt.figure(figsize=(16, 32))
plt.subplot(3, 1, 1)
plt.imshow(alignment.T, aspect="auto", origin="lower", interpolation=None)
plt.xlabel("Decoder timestamp", fontsize=label_fontsize)
plt.ylabel("Encoder timestamp", fontsize=label_fontsize)
plt.yticks(range(len(text)), list(text))
plt.colorbar()
stop_tokens = stop_tokens.squeeze().detach().to('cpu').numpy()
plt.subplot(3, 1, 2)
plt.plot(range(len(stop_tokens)), list(stop_tokens))
plt.subplot(3, 1, 3)
librosa.display.specshow(spectrogram.T, sr=CONFIG.audio['sample_rate'],
hop_length=hop_length, x_axis="time", y_axis="linear")
plt.xlabel("Time", fontsize=label_fontsize)
plt.ylabel("Hz", fontsize=label_fontsize)
plt.tight_layout()
plt.colorbar()
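A closing usage sketch, with the arguments as returned by synthesis above (the Agg backend is in use, so the figure is saved rather than shown):

visualize(alignment, linear_spec, stop_tokens, "This is a test sentence.",
          ap.hop_length, c)
plt.savefig('synthesis_debug.png')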