mirror of https://github.com/coqui-ai/TTS.git
Batch update after data-loss
parent f53f9cb360 · commit c8a552e627

config.json (70 changes)
@@ -1,41 +1,47 @@
 {
-    "model_name": "TTS-sigmoid",
-    "model_description": "Net outputting Sigmoid unit",
-    "audio_processor": "audio",
-    "num_mels": 80,
-    "num_freq": 1025,
-    "sample_rate": 22000,
-    "frame_length_ms": 50,
-    "frame_shift_ms": 12.5,
-    "preemphasis": 0.97,
-    "min_level_db": -100,
-    "ref_level_db": 20,
+    "model_name": "TTS-master-_tmp",
+    "model_description": "Higher dropout rate for stopnet and disabled custom initialization, pull current mel prediction to stopnet.",
+
+    "audio":{
+        "audio_processor": "audio",  // to dictate different audio processors, if available.
+        "num_mels": 80,              // size of the mel spec frame.
+        "num_freq": 1025,            // number of stft frequency levels. Size of the linear spectrogram frame.
+        "sample_rate": 22050,        // wav sample-rate. If different than the original data, it is resampled.
+        "frame_length_ms": 50,       // stft window length in ms.
+        "frame_shift_ms": 12.5,      // stft window hop-length in ms.
+        "preemphasis": 0.97,         // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis is applied.
+        "min_level_db": -100,        // normalization range
+        "ref_level_db": 20,          // reference level db; theoretically 20 db is the sound of air.
+        "power": 1.5,                // value to sharpen wav signals after the GL algorithm.
+        "griffin_lim_iters": 60,     // number of griffin-lim iterations. 30-60 is a good range. The larger the value, the slower the generation.
+        "signal_norm": true,         // normalize the spec values in range [0, 1]
+        "symmetric_norm": false,     // move normalization to range [-1, 1]
+        "max_norm": 1,               // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
+        "clip_norm": true,           // clip normalized values into the range.
+        "mel_fmin": null,            // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
+        "mel_fmax": null             // maximum freq level for mel-spec. Tune for dataset!!
+    },

     "embedding_size": 256,
     "text_cleaner": "english_cleaners",

-    "num_loader_workers": 4,
-
     "epochs": 1000,
-    "lr": 0.002,
+    "lr": 0.0015,
     "warmup_steps": 4000,
     "lr_decay": 0.5,
     "decay_step": 100000,
-    "batch_size": 32,
-    "eval_batch_size": -1,
-    "r": 5,
-    "wd": 0.0001,
-
-    "griffin_lim_iters": 60,
-    "power": 1.5,
-
+    "batch_size": 32,
+    "eval_batch_size": 32,
+    "r": 1,
+    "wd": 0.000001,
     "checkpoint": true,
-    "save_step": 25000,
+    "save_step": 5000,
     "print_step": 10,
-    "run_eval": false,
-    "data_path": "/snakepit/shared/data/keithito/LJSpeech-1.1/",
-    "meta_file_train": "metadata.csv",
-    "meta_file_val": null,
-    "dataset": "LJSpeech",
+    "run_eval": true,
+    "data_path": "../../Data/LJSpeech-1.1/tts_cache",  // can be overwritten from the command line.
+    "meta_file_train": "metadata_train.csv",           // metafile for the training dataloader.
+    "meta_file_val": "metadata_val.csv",               // metafile for the validation dataloader.
+    "data_loader": "TTSDataset",                       // dataloader, ["TTSDataset", "TTSDatasetCached", "TTSDatasetMemory"]
+    "dataset": "ljspeech",                             // one of TTS.dataset.preprocessors; only valid if dataloader == "TTSDataset", the rest use "tts_cache" by default.
     "min_seq_len": 0,
-    "output_path": "../keep/"
+    "output_path": "../keep/",
+    "num_loader_workers": 8
 }
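Note: the new config nests all signal-processing settings under an "audio" key and annotates them with // comments, which plain json.loads rejects. A minimal sketch of how such a file can be parsed; load_config_sketch and its comment-stripping regex are illustrative assumptions, not the repo's actual utils.generic_utils.load_config (which evidently also provides attribute access such as CONFIG.audio):

    import json
    import re

    def load_config_sketch(path):
        """Parse a config file like the one above. The // comments are not
        valid JSON, so strip them first (this naive regex would also strip
        '//' inside string values, which this config happens to avoid)."""
        with open(path, "r") as f:
            raw = f.read()
        cleaned = re.sub(r"//.*$", "", raw, flags=re.MULTILINE)
        return json.loads(cleaned)

    # config["audio"] can then be splatted into the audio processor,
    # as the diffs below do with AudioProcessor(**CONFIG.audio).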

@@ -13,75 +13,72 @@ from utils.data import (prepare_data, pad_per_step, prepare_tensor,

 class MyDataset(Dataset):
     def __init__(self,
-                 root_dir,
-                 csv_file,
+                 root_path,
+                 meta_file,
                  outputs_per_step,
                  text_cleaner,
                  ap,
+                 preprocessor,
                  batch_group_size=0,
                  min_seq_len=0):
-        self.root_dir = root_dir
+        self.root_path = root_path
         self.batch_group_size = batch_group_size
-        self.wav_dir = os.path.join(root_dir, 'wavs')
-        self.csv_dir = os.path.join(root_dir, csv_file)
-        with open(self.csv_dir, "r", encoding="utf8") as f:
-            self.frames = [line.split('|') for line in f]
+        self.items = preprocessor(root_path, meta_file)
         self.outputs_per_step = outputs_per_step
         self.sample_rate = ap.sample_rate
         self.cleaners = text_cleaner
         self.min_seq_len = min_seq_len
         self.ap = ap
-        print(" > Reading LJSpeech from - {}".format(root_dir))
-        print(" | > Number of instances : {}".format(len(self.frames)))
-        self.sort_frames()
+        print(" > Reading LJSpeech from - {}".format(root_path))
+        print(" | > Number of instances : {}".format(len(self.items)))
+        self.sort_items()

     def load_wav(self, filename):
         try:
-            audio = librosa.core.load(filename, sr=self.sample_rate)[0]
+            audio = self.ap.load_wav(filename)
             return audio
         except RuntimeError as e:
             print(" !! Cannot read file : {}".format(filename))

-    def sort_frames(self):
+    def sort_items(self):
         r"""Sort text sequences in ascending order"""
-        lengths = np.array([len(ins[1]) for ins in self.frames])
+        lengths = np.array([len(ins[0]) for ins in self.items])

         print(" | > Max length sequence {}".format(np.max(lengths)))
         print(" | > Min length sequence {}".format(np.min(lengths)))
         print(" | > Avg length sequence {}".format(np.mean(lengths)))

         idxs = np.argsort(lengths)
-        new_frames = []
+        new_items = []
         ignored = []
         for i, idx in enumerate(idxs):
             length = lengths[idx]
             if length < self.min_seq_len:
                 ignored.append(idx)
             else:
-                new_frames.append(self.frames[idx])
+                new_items.append(self.items[idx])
         print(" | > {} instances are ignored by min_seq_len ({})".format(
             len(ignored), self.min_seq_len))
         # shuffle batch groups
         if self.batch_group_size > 0:
             print(" | > Batch group shuffling is active.")
-            for i in range(len(new_frames) // self.batch_group_size):
+            for i in range(len(new_items) // self.batch_group_size):
                 offset = i * self.batch_group_size
                 end_offset = offset + self.batch_group_size
-                temp_frames = new_frames[offset : end_offset]
-                random.shuffle(temp_frames)
-                new_frames[offset : end_offset] = temp_frames
-        self.frames = new_frames
+                temp_items = new_items[offset : end_offset]
+                random.shuffle(temp_items)
+                new_items[offset : end_offset] = temp_items
+        self.items = new_items

     def __len__(self):
-        return len(self.frames)
+        return len(self.items)

     def __getitem__(self, idx):
-        wav_name = os.path.join(self.wav_dir, self.frames[idx][0]) + '.wav'
-        text = self.frames[idx][1]
+        text, wav_file = self.items[idx]
         text = np.asarray(
             text_to_sequence(text, [self.cleaners]), dtype=np.int32)
-        wav = np.asarray(self.load_wav(wav_name), dtype=np.float32)
-        sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
+        wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
+        sample = {'text': text, 'wav': wav, 'item_idx': self.items[idx][1]}
         return sample

     def collate_fn(self, batch):
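Note: after this change, dataset rows come from a preprocessor callable instead of a hard-coded CSV path. A sketch of the wiring, using the module names that appear in the test imports later in this commit (TTSDataset.MyDataset, datasets.preprocess.ljspeech); paths and AudioProcessor keyword arguments are placeholders taken from the config above:

    from torch.utils.data import DataLoader
    from datasets.preprocess import ljspeech
    from utils.audio import AudioProcessor
    from datasets import TTSDataset  # module name as used in the tests below

    # AudioProcessor kwargs mirror the "audio" section of config.json
    ap = AudioProcessor(sample_rate=22050, num_mels=80, num_freq=1025,
                        frame_length_ms=50, frame_shift_ms=12.5,
                        preemphasis=0.97, min_level_db=-100, ref_level_db=20,
                        power=1.5, griffin_lim_iters=60, signal_norm=True,
                        symmetric_norm=False, max_norm=1, clip_norm=True,
                        mel_fmin=None, mel_fmax=None)
    dataset = TTSDataset.MyDataset('/path/to/LJSpeech-1.1', 'metadata.csv',
                                   outputs_per_step=1,
                                   text_cleaner='english_cleaners',
                                   ap=ap, preprocessor=ljspeech,
                                   batch_group_size=0, min_seq_len=0)
    loader = DataLoader(dataset, batch_size=32, shuffle=False,
                        collate_fn=dataset.collate_fn, drop_last=True)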

@@ -1,4 +1,5 @@
 import os
+import random
 import numpy as np
 import collections
 import librosa

@@ -6,32 +7,34 @@ import torch
 from torch.utils.data import Dataset

 from utils.text import text_to_sequence
+from datasets.preprocess import tts_cache
 from utils.data import (prepare_data, pad_per_step, prepare_tensor,
                         prepare_stop_target)


 class MyDataset(Dataset):
+    # TODO: Not finished yet.
     def __init__(self,
-                 root_dir,
-                 csv_file,
+                 root_path,
+                 meta_file,
                  outputs_per_step,
                  text_cleaner,
                  ap,
-                 min_seq_len=0):
-        self.root_dir = root_dir
-        self.wav_dir = os.path.join(root_dir, 'wavs')
-        self.feat_dir = os.path.join(root_dir, 'loader_data')
-        self.csv_dir = os.path.join(root_dir, csv_file)
-        with open(self.csv_dir, "r", encoding="utf8") as f:
-            self.frames = [line.split('|') for line in f]
+                 batch_group_size=0,
+                 min_seq_len=0,
+                 **kwargs
+                 ):
+        self.root_path = root_path
+        self.batch_group_size = batch_group_size
+        self.feat_dir = os.path.join(root_path, 'loader_data')
+        self.items = tts_cache(root_path, meta_file)
         self.outputs_per_step = outputs_per_step
         self.sample_rate = ap.sample_rate
         self.cleaners = text_cleaner
         self.min_seq_len = min_seq_len
-        self.items = [None] * len(self.frames)
-        print(" > Reading LJSpeech from - {}".format(root_dir))
-        print(" | > Number of instances : {}".format(len(self.frames)))
-        self._sort_frames()
+        print(" > Reading LJSpeech from - {}".format(root_path))
+        print(" | > Number of instances : {}".format(len(self.items)))
+        self.sort_items()

     def load_wav(self, filename):
         try:

@@ -44,9 +47,9 @@ class MyDataset(Dataset):
         data = np.load(filename).astype('float32')
         return data

-    def _sort_frames(self):
-        r"""Sort sequences in ascending order"""
-        lengths = np.array([len(ins[1]) for ins in self.frames])
+    def sort_items(self):
+        r"""Sort text sequences in ascending order"""
+        lengths = np.array([len(ins[-1]) for ins in self.items])

         print(" | > Max length sequence {}".format(np.max(lengths)))
         print(" | > Min length sequence {}".format(np.min(lengths)))

@@ -60,37 +63,43 @@ class MyDataset(Dataset):
             if length < self.min_seq_len:
                 ignored.append(idx)
             else:
-                new_frames.append(self.frames[idx])
+                new_frames.append(self.items[idx])
         print(" | > {} instances are ignored by min_seq_len ({})".format(
             len(ignored), self.min_seq_len))
-        self.frames = new_frames
+        # shuffle batch groups
+        if self.batch_group_size > 0:
+            print(" | > Batch group shuffling is active.")
+            for i in range(len(new_frames) // self.batch_group_size):
+                offset = i * self.batch_group_size
+                end_offset = offset + self.batch_group_size
+                temp_frames = new_frames[offset : end_offset]
+                random.shuffle(temp_frames)
+                new_frames[offset : end_offset] = temp_frames
+        self.items = new_frames

     def __len__(self):
-        return len(self.frames)
+        return len(self.items)

     def __getitem__(self, idx):
-        if self.items[idx] is None:
-            wav_name = os.path.join(self.wav_dir, self.frames[idx][0]) + '.wav'
-            mel_name = os.path.join(self.feat_dir,
-                                    self.frames[idx][0]) + '.mel.npy'
-            linear_name = os.path.join(self.feat_dir,
-                                       self.frames[idx][0]) + '.linear.npy'
-            text = self.frames[idx][1]
-            text = np.asarray(
-                text_to_sequence(text, [self.cleaners]), dtype=np.int32)
-            wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
-            mel = self.load_np(mel_name)
-            linear = self.load_np(linear_name)
-            sample = {
-                'text': text,
-                'wav': wav,
-                'item_idx': self.frames[idx][0],
-                'mel': mel,
-                'linear': linear
-            }
-            self.items[idx] = sample
-        sample = self.items[idx]
+        wav_name = self.items[idx][0]
+        mel_name = self.items[idx][1]
+        linear_name = self.items[idx][2]
+        text = self.items[idx][-1]
+        text = np.asarray(
+            text_to_sequence(text, [self.cleaners]), dtype=np.int32)
+        if wav_name.split('.')[-1] == 'npy':
+            wav = self.load_np(wav_name)
+        else:
+            wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
+        mel = self.load_np(mel_name)
+        linear = self.load_np(linear_name)
+        sample = {
+            'text': text,
+            'wav': wav,
+            'item_idx': self.items[idx][0],
+            'mel': mel,
+            'linear': linear
+        }
         return sample

     def collate_fn(self, batch):

@@ -132,7 +141,6 @@ class MyDataset(Dataset):
         # PAD features with largest length + a zero frame
         linear = prepare_tensor(linear, self.outputs_per_step)
         mel = prepare_tensor(mel, self.outputs_per_step)
-        assert mel.shape[2] == linear.shape[2]
         timesteps = mel.shape[2]

         # B x T x D

@@ -147,8 +155,7 @@ class MyDataset(Dataset):
         mel_lengths = torch.LongTensor(mel_lengths)
         stop_targets = torch.FloatTensor(stop_targets)

-        return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
-            0]
+        return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs

         raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
                          found {}".format(type(batch[0]))))

@@ -1,162 +1,179 @@
-import os
-import glob
-import random
-import numpy as np
-import collections
-import librosa
-import torch
-from torch.utils.data import Dataset
-
-from utils.text import text_to_sequence
-from utils.data import (prepare_data, pad_per_step, prepare_tensor,
-                        prepare_stop_target)
-
-
-class MyDataset(Dataset):
-    def __init__(self,
-                 root_dir,
-                 csv_file,
-                 outputs_per_step,
-                 text_cleaner,
-                 ap,
-                 min_seq_len=0):
-        self.root_dir = root_dir
-        self.wav_dir = os.path.join(root_dir, 'wav')
-        self.wav_files = glob.glob(os.path.join(self.wav_dir, '*.wav'))
-        self._create_file_dict()
-        self.csv_dir = os.path.join(root_dir, csv_file)
-        with open(self.csv_dir, "r", encoding="utf8") as f:
-            self.frames = [
-                line.split('\t') for line in f
-                if line.split('\t')[0] in self.wav_files_dict.keys()
-            ]
-        self.outputs_per_step = outputs_per_step
-        self.sample_rate = ap.sample_rate
-        self.cleaners = text_cleaner
-        self.min_seq_len = min_seq_len
-        self.ap = ap
-        print(" > Reading Kusal from - {}".format(root_dir))
-        print(" | > Number of instances : {}".format(len(self.frames)))
-        self._sort_frames()
-
-    def load_wav(self, filename):
-        """ Load audio and trim silence """
-        try:
-            audio = librosa.core.load(filename, sr=self.sample_rate)[0]
-            margin = int(self.sample_rate * 0.1)
-            audio = audio[margin:-margin]
-            return self._trim_silence(audio)
-        except RuntimeError as e:
-            print(" !! Cannot read file : {}".format(filename))
-
-    def _trim_silence(self, wav):
-        return librosa.effects.trim(
-            wav, top_db=40, frame_length=1024, hop_length=256)[0]
-
-    def _create_file_dict(self):
-        self.wav_files_dict = {}
-        for fn in self.wav_files:
-            parts = fn.split('-')
-            key = parts[1]
-            value = fn
-            try:
-                self.wav_files_dict[key].append(value)
-            except:
-                self.wav_files_dict[key] = [value]
-
-    def _sort_frames(self):
-        r"""Sort sequences in ascending order"""
-        lengths = np.array([len(ins[2]) for ins in self.frames])
-
-        print(" | > Max length sequence {}".format(np.max(lengths)))
-        print(" | > Min length sequence {}".format(np.min(lengths)))
-        print(" | > Avg length sequence {}".format(np.mean(lengths)))
-
-        idxs = np.argsort(lengths)
-        new_frames = []
-        ignored = []
-        for i, idx in enumerate(idxs):
-            length = lengths[idx]
-            if length < self.min_seq_len:
-                ignored.append(idx)
-            else:
-                new_frames.append(self.frames[idx])
-        print(" | > {} instances are ignored by min_seq_len ({})".format(
-            len(ignored), self.min_seq_len))
-        self.frames = new_frames
-
-    def __len__(self):
-        return len(self.frames)
-
-    def __getitem__(self, idx):
-        sidx = self.frames[idx][0]
-        sidx_files = self.wav_files_dict[sidx]
-        file_name = random.choice(sidx_files)
-        wav_name = os.path.join(self.wav_dir, file_name)
-        text = self.frames[idx][2]
-        text = np.asarray(
-            text_to_sequence(text, [self.cleaners]), dtype=np.int32)
-        wav = np.asarray(self.load_wav(wav_name), dtype=np.float32)
-        sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
-        return sample
-
-    def collate_fn(self, batch):
-        r"""
-        Perform preprocessing and create a final data batch:
-        1. PAD sequences with the longest sequence in the batch
-        2. Convert Audio signal to Spectrograms.
-        3. PAD sequences that can be divided by r.
-        4. Convert Numpy to Torch tensors.
-        """
-
-        # Puts each data field into a tensor with outer dimension batch size
-        if isinstance(batch[0], collections.Mapping):
-            keys = list()
-
-            wav = [d['wav'] for d in batch]
-            item_idxs = [d['item_idx'] for d in batch]
-            text = [d['text'] for d in batch]
-
-            text_lenghts = np.array([len(x) for x in text])
-            max_text_len = np.max(text_lenghts)
-
-            linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
-            mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
-            mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame
-
-            # compute 'stop token' targets
-            stop_targets = [
-                np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
-            ]
-
-            # PAD stop targets
-            stop_targets = prepare_stop_target(stop_targets,
-                                               self.outputs_per_step)
-
-            # PAD sequences with largest length of the batch
-            text = prepare_data(text).astype(np.int32)
-            wav = prepare_data(wav)
-
-            # PAD features with largest length + a zero frame
-            linear = prepare_tensor(linear, self.outputs_per_step)
-            mel = prepare_tensor(mel, self.outputs_per_step)
-            assert mel.shape[2] == linear.shape[2]
-            timesteps = mel.shape[2]
-
-            # B x T x D
-            linear = linear.transpose(0, 2, 1)
-            mel = mel.transpose(0, 2, 1)
-
-            # convert things to pytorch
-            text_lenghts = torch.LongTensor(text_lenghts)
-            text = torch.LongTensor(text)
-            linear = torch.FloatTensor(linear)
-            mel = torch.FloatTensor(mel)
-            mel_lengths = torch.LongTensor(mel_lengths)
-            stop_targets = torch.FloatTensor(stop_targets)
-
-            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
-                0]
-
-        raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
-                         found {}".format(type(batch[0]))))
+import os
+import random
+import numpy as np
+import collections
+import librosa
+import torch
+from tqdm import tqdm
+from torch.utils.data import Dataset
+
+from utils.text import text_to_sequence
+from datasets.preprocess import tts_cache
+from utils.data import (prepare_data, pad_per_step, prepare_tensor,
+                        prepare_stop_target)
+
+
+class MyDataset(Dataset):
+    # TODO: Not finished yet.
+    def __init__(self,
+                 root_path,
+                 meta_file,
+                 outputs_per_step,
+                 text_cleaner,
+                 ap,
+                 batch_group_size=0,
+                 min_seq_len=0,
+                 **kwargs
+                 ):
+        self.root_path = root_path
+        self.batch_group_size = batch_group_size
+        self.feat_dir = os.path.join(root_path, 'loader_data')
+        self.items = tts_cache(root_path, meta_file)
+        self.outputs_per_step = outputs_per_step
+        self.sample_rate = ap.sample_rate
+        self.cleaners = text_cleaner
+        self.min_seq_len = min_seq_len
+        self.wavs = None
+        self.mels = None
+        self.linears = None
+        print(" > Reading LJSpeech from - {}".format(root_path))
+        print(" | > Number of instances : {}".format(len(self.items)))
+        self.sort_items()
+        self.fill_data()
+
+    def fill_data(self):
+        if self.wavs is None and self.mels is None:
+            self.wavs = []
+            self.mels = []
+            self.linears = []
+            self.texts = []
+            for item in tqdm(self.items):
+                wav_file = item[0]
+                mel_file = item[1]
+                linear_file = item[2]
+                text = item[-1]
+                wav = self.load_np(wav_file)
+                mel = self.load_np(mel_file)
+                linear = self.load_np(linear_file)
+                self.wavs.append(wav)
+                self.mels.append(mel)
+                self.linears.append(linear)
+                self.texts.append(np.asarray(
+                    text_to_sequence(text, [self.cleaners]), dtype=np.int32))
+            print(" > Data loaded to memory")
+
+    def load_wav(self, filename):
+        try:
+            audio = librosa.core.load(filename, sr=self.sample_rate)
+            return audio
+        except RuntimeError as e:
+            print(" !! Cannot read file : {}".format(filename))
+
+    def load_np(self, filename):
+        data = np.load(filename).astype('float32')
+        return data
+
+    def sort_items(self):
+        r"""Sort text sequences in ascending order"""
+        lengths = np.array([len(ins[-1]) for ins in self.items])
+
+        print(" | > Max length sequence {}".format(np.max(lengths)))
+        print(" | > Min length sequence {}".format(np.min(lengths)))
+        print(" | > Avg length sequence {}".format(np.mean(lengths)))
+
+        idxs = np.argsort(lengths)
+        new_frames = []
+        ignored = []
+        for i, idx in enumerate(idxs):
+            length = lengths[idx]
+            if length < self.min_seq_len:
+                ignored.append(idx)
+            else:
+                new_frames.append(self.items[idx])
+        print(" | > {} instances are ignored by min_seq_len ({})".format(
+            len(ignored), self.min_seq_len))
+        # shuffle batch groups
+        if self.batch_group_size > 0:
+            print(" | > Batch group shuffling is active.")
+            for i in range(len(new_frames) // self.batch_group_size):
+                offset = i * self.batch_group_size
+                end_offset = offset + self.batch_group_size
+                temp_frames = new_frames[offset : end_offset]
+                random.shuffle(temp_frames)
+                new_frames[offset : end_offset] = temp_frames
+        self.items = new_frames
+
+    def __len__(self):
+        return len(self.items)
+
+    def __getitem__(self, idx):
+        text = self.texts[idx]
+        wav = self.wavs[idx]
+        mel = self.mels[idx]
+        linear = self.linears[idx]
+        sample = {
+            'text': text,
+            'wav': wav,
+            'item_idx': self.items[idx][0],
+            'mel': mel,
+            'linear': linear
+        }
+        return sample
+
+    def collate_fn(self, batch):
+        r"""
+        Perform preprocessing and create a final data batch:
+        1. PAD sequences with the longest sequence in the batch
+        2. Convert Audio signal to Spectrograms.
+        3. PAD sequences that can be divided by r.
+        4. Convert Numpy to Torch tensors.
+        """
+
+        # Puts each data field into a tensor with outer dimension batch size
+        if isinstance(batch[0], collections.Mapping):
+            keys = list()
+
+            wav = [d['wav'] for d in batch]
+            item_idxs = [d['item_idx'] for d in batch]
+            text = [d['text'] for d in batch]
+            mel = [d['mel'] for d in batch]
+            linear = [d['linear'] for d in batch]
+
+            text_lenghts = np.array([len(x) for x in text])
+            max_text_len = np.max(text_lenghts)
+            mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame
+
+            # compute 'stop token' targets
+            stop_targets = [
+                np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
+            ]
+
+            # PAD stop targets
+            stop_targets = prepare_stop_target(stop_targets,
+                                               self.outputs_per_step)
+
+            # PAD sequences with largest length of the batch
+            text = prepare_data(text).astype(np.int32)
+            wav = prepare_data(wav)
+
+            # PAD features with largest length + a zero frame
+            linear = prepare_tensor(linear, self.outputs_per_step)
+            mel = prepare_tensor(mel, self.outputs_per_step)
+            timesteps = mel.shape[2]
+
+            # B x T x D
+            linear = linear.transpose(0, 2, 1)
+            mel = mel.transpose(0, 2, 1)
+
+            # convert things to pytorch
+            text_lenghts = torch.LongTensor(text_lenghts)
+            text = torch.LongTensor(text)
+            linear = torch.FloatTensor(linear)
+            mel = torch.FloatTensor(mel)
+            mel_lengths = torch.LongTensor(mel_lengths)
+            stop_targets = torch.FloatTensor(stop_targets)
+
+            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs
+
+        raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
+                         found {}".format(type(batch[0]))))

@@ -0,0 +1,60 @@
+import os
+import random
+
+
+def tts_cache(root_path, meta_file):
+    """This format is set for the meta-file generated by extract_features.py"""
+    txt_file = os.path.join(root_path, meta_file)
+    items = []
+    with open(txt_file, 'r', encoding='utf8') as f:
+        for line in f:
+            cols = line.split('| ')
+            items.append(cols)  # wav_full_path, mel_name, linear_name, wav_len, mel_len, text
+    random.shuffle(items)
+    return items
+
+
+# def tweb(root_path, meta_file):
+#     # TODO
+#     pass
+#     return
+
+
+# def kusal(root_path, meta_file):
+#     txt_file = os.path.join(root_path, meta_file)
+#     texts = []
+#     wavs = []
+#     with open(txt_file, "r", encoding="utf8") as f:
+#         frames = [
+#             line.split('\t') for line in f
+#             if line.split('\t')[0] in self.wav_files_dict.keys()
+#         ]
+#     # TODO: code the rest
+#     return {'text': texts, 'wavs': wavs}
+
+
+def ljspeech(root_path, meta_file):
+    """Normalizes the LJSpeech meta data file to TTS format"""
+    txt_file = os.path.join(root_path, meta_file)
+    items = []
+    with open(txt_file, 'r') as ttf:
+        for line in ttf:
+            cols = line.split('|')
+            wav_file = os.path.join(root_path, 'wavs', cols[0] + '.wav')
+            text = cols[1]
+            items.append([text, wav_file])
+    random.shuffle(items)
+    return items
+
+
+def nancy(root_path, meta_file):
+    """Normalizes the Nancy meta data file to TTS format"""
+    txt_file = os.path.join(root_path, meta_file)
+    items = []
+    with open(txt_file, 'r') as ttf:
+        for line in ttf:
+            id = line.split()[1]
+            text = line[line.find('"') + 1:line.rfind('"') - 1]
+            wav_file = root_path + 'wavn/' + id + '.wav'
+            items.append([text, wav_file])
+    random.shuffle(items)
+    return items
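Note: tts_cache() pairs with the "| ".join(...) writer in extract_features.py below. A quick sketch of the line format it expects; field names follow the inline comment above, and the sample paths and values are made up for illustration:

    # One cached-metadata line, as written by extract_features.py:
    line = "| ".join([
        "/cache/audio/LJ001-0001_audio.npy",   # wav_full_path (or the raw .wav)
        "/cache/mel/LJ001-0001_mel.npy",       # mel_name
        "/cache/linear/LJ001-0001_linear.npy", # linear_name
        "97600",                               # wav_len (samples)
        "339",                                 # mel_len (frames)
        "Sample transcript text.",             # text (must not contain '| ')
    ])
    cols = line.split('| ')
    assert cols[-1] == "Sample transcript text."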

@@ -0,0 +1,138 @@
+'''
+Extract spectrograms and save them to file for training
+'''
+import os
+import sys
+import time
+import glob
+import argparse
+import librosa
+import importlib
+import numpy as np
+import tqdm
+from utils.generic_utils import load_config, copy_config_file
+from utils.audio import AudioProcessor
+
+from multiprocessing import Pool
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data_path', type=str, help='Data folder.')
+    parser.add_argument('--cache_path', type=str, help='Cache folder, place to output all the intermediate spectrogram files.')
+    # parser.add_argument('--keep_cache', type=bool, help='If True, it keeps the cache folder.')
+    # parser.add_argument('--hdf5_path', type=str, help='hdf5 folder.')
+    parser.add_argument(
+        '--config', type=str, help='conf.json file for run settings.')
+    parser.add_argument(
+        "--num_proc", type=int, default=8, help="number of processes.")
+    parser.add_argument(
+        "--trim_silence",
+        type=bool,
+        default=False,
+        help="trim silence in the voice clip.")
+    parser.add_argument("--only_mel", type=bool, default=False, help="If True, only the mel spectrogram is extracted.")
+    parser.add_argument("--dataset", type=str, help="Target dataset to be processed.")
+    parser.add_argument("--val_split", type=int, default=0, help="Number of instances for validation.")
+    parser.add_argument("--meta_file", type=str, help="Meta data file to be used for the dataset.")
+    parser.add_argument("--process_audio", type=bool, default=False, help="Preprocess audio files.")
+    args = parser.parse_args()
+
+    DATA_PATH = args.data_path
+    CACHE_PATH = args.cache_path
+    CONFIG = load_config(args.config)
+
+    # load the right preprocessor
+    preprocessor = importlib.import_module('datasets.preprocess')
+    preprocessor = getattr(preprocessor, args.dataset.lower())
+    items = preprocessor(args.data_path, args.meta_file)
+
+    print(" > Input path: ", DATA_PATH)
+    print(" > Cache path: ", CACHE_PATH)
+
+    # audio = importlib.import_module('utils.' + c.audio_processor)
+    # AudioProcessor = getattr(audio, 'AudioProcessor')
+    ap = AudioProcessor(**CONFIG.audio)
+
+
+def trim_silence(wav):
+    """ Trim silent parts with a threshold and 0.1 sec margin """
+    margin = int(ap.sample_rate * 0.1)
+    wav = wav[margin:-margin]
+    return librosa.effects.trim(
+        wav, top_db=40, frame_length=1024, hop_length=256)[0]
+
+
+def extract_mel(item):
+    """ Compute spectrograms, length information """
+    text = item[0]
+    file_path = item[1]
+    x = ap.load_wav(file_path, ap.sample_rate)
+    if args.trim_silence:
+        x = trim_silence(x)
+    file_name = os.path.basename(file_path).replace(".wav", "")
+    mel_file = file_name + "_mel"
+    mel_path = os.path.join(CACHE_PATH, 'mel', mel_file)
+    mel = ap.melspectrogram(x.astype('float32')).astype('float32')
+    np.save(mel_path, mel, allow_pickle=False)
+    mel_len = mel.shape[1]
+    wav_len = x.shape[0]
+    output = [file_path, mel_path + ".npy", str(wav_len), str(mel_len), text]
+    if not args.only_mel:
+        linear_file = file_name + "_linear"
+        linear_path = os.path.join(CACHE_PATH, 'linear', linear_file)
+        linear = ap.spectrogram(x.astype('float32')).astype('float32')
+        linear_len = linear.shape[1]
+        np.save(linear_path, linear, allow_pickle=False)
+        output.insert(2, linear_path + ".npy")
+        assert mel_len == linear_len
+    if args.process_audio:
+        audio_file = file_name + "_audio"
+        audio_path = os.path.join(CACHE_PATH, 'audio', audio_file)
+        np.save(audio_path, x, allow_pickle=False)
+        del output[0]
+        output.insert(0, audio_path + ".npy")
+    return output
+
+
+if __name__ == "__main__":
+    print(" > Number of files: %i" % (len(items)))
+    if not os.path.exists(CACHE_PATH):
+        os.makedirs(os.path.join(CACHE_PATH, 'mel'))
+        if not args.only_mel:
+            os.makedirs(os.path.join(CACHE_PATH, 'linear'))
+        if args.process_audio:
+            os.makedirs(os.path.join(CACHE_PATH, 'audio'))
+        print(" > A new folder created at {}".format(CACHE_PATH))
+
+    # Extract features
+    r = []
+    if args.num_proc > 1:
+        print(" > Using {} processes.".format(args.num_proc))
+        with Pool(args.num_proc) as p:
+            r = list(
+                tqdm.tqdm(
+                    p.imap(extract_mel, items),
+                    total=len(items)))
+            # r = list(p.imap(extract_mel, file_names))
+    else:
+        print(" > Using single process run.")
+        for item in items:
+            print(" > ", item[1])
+            r.append(extract_mel(item))
+
+    # Save meta data
+    if args.cache_path is not None:
+        file_path = os.path.join(CACHE_PATH, "tts_metadata_val.csv")
+        file = open(file_path, "w")
+        for line in r[:args.val_split]:
+            line = "| ".join(line)
+            file.write(line + '\n')
+        file.close()
+
+        file_path = os.path.join(CACHE_PATH, "tts_metadata.csv")
+        file = open(file_path, "w")
+        for line in r[args.val_split:]:
+            line = "| ".join(line)
+            file.write(line + '\n')
+        file.close()
+
+    # copy the used config file to output path for sanity
+    copy_config_file(args.config, CACHE_PATH)
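Note: an example invocation, using only the flags defined by the argparse setup above; the paths are placeholders. Be aware that argparse's type=bool treats any non-empty string, including "False", as True, so --trim_silence and friends are effectively on/off by presence of a value:

    python extract_features.py --data_path /data/LJSpeech-1.1 \
        --cache_path /data/LJSpeech-1.1/tts_cache --config config.json \
        --dataset ljspeech --meta_file metadata.csv --num_proc 8 --val_split 100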

@@ -52,10 +52,25 @@ class LocationSensitiveAttention(nn.Module):
                 stride=1,
                 padding=0,
                 bias=False))
-        self.loc_linear = nn.Linear(filters, attn_dim)
+        self.loc_linear = nn.Linear(filters, attn_dim, bias=True)
         self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
         self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
         self.v = nn.Linear(attn_dim, 1, bias=False)
+        # self.init_layers()
+
+    def init_layers(self):
+        torch.nn.init.xavier_uniform_(
+            self.loc_linear.weight,
+            gain=torch.nn.init.calculate_gain('tanh'))
+        torch.nn.init.xavier_uniform_(
+            self.query_layer.weight,
+            gain=torch.nn.init.calculate_gain('tanh'))
+        torch.nn.init.xavier_uniform_(
+            self.annot_layer.weight,
+            gain=torch.nn.init.calculate_gain('tanh'))
+        torch.nn.init.xavier_uniform_(
+            self.v.weight,
+            gain=torch.nn.init.calculate_gain('linear'))

    def forward(self, annot, query, loc):
        """

@@ -23,6 +23,13 @@ class Prenet(nn.Module):
         ])
         self.relu = nn.ReLU()
         self.dropout = nn.Dropout(0.5)
+        # self.init_layers()
+
+    def init_layers(self):
+        for layer in self.layers:
+            torch.nn.init.xavier_uniform_(
+                layer.weight,
+                gain=torch.nn.init.calculate_gain('relu'))

     def forward(self, inputs):
         for linear in self.layers:

@@ -55,6 +62,7 @@ class BatchNormConv1d(nn.Module):
                 stride,
                 padding,
                 activation=None):

         super(BatchNormConv1d, self).__init__()
         self.padding = padding
+        self.padder = nn.ConstantPad1d(padding, 0)

@@ -68,13 +76,28 @@ class BatchNormConv1d(nn.Module):
         # Following tensorflow's default parameters
         self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3)
         self.activation = activation
+        # self.init_layers()
+
+    def init_layers(self):
+        if type(self.activation) == torch.nn.ReLU:
+            w_gain = 'relu'
+        elif type(self.activation) == torch.nn.Tanh:
+            w_gain = 'tanh'
+        elif self.activation is None:
+            w_gain = 'linear'
+        else:
+            raise RuntimeError('Unknown activation function')
+        torch.nn.init.xavier_uniform_(
+            self.conv1d.weight,
+            gain=torch.nn.init.calculate_gain(w_gain))

     def forward(self, x):
+        x = self.padder(x)
         x = self.conv1d(x)
+        x = self.bn(x)
         if self.activation is not None:
             x = self.activation(x)
-        return self.bn(x)
+        return x


 class Highway(nn.Module):

@@ -86,6 +109,15 @@ class Highway(nn.Module):
         self.T.bias.data.fill_(-1)
         self.relu = nn.ReLU()
         self.sigmoid = nn.Sigmoid()
+        # self.init_layers()
+
+    def init_layers(self):
+        torch.nn.init.xavier_uniform_(
+            self.H.weight,
+            gain=torch.nn.init.calculate_gain('relu'))
+        torch.nn.init.xavier_uniform_(
+            self.T.weight,
+            gain=torch.nn.init.calculate_gain('sigmoid'))

     def forward(self, inputs):
         H = self.relu(self.H(inputs))

@@ -276,7 +308,7 @@ class Decoder(nn.Module):
         super(Decoder, self).__init__()
         self.r = r
         self.in_features = in_features
-        self.max_decoder_steps = 200
+        self.max_decoder_steps = 500
         self.memory_dim = memory_dim
         # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(memory_dim * r, out_features=[256, 128])

@@ -294,7 +326,16 @@ class Decoder(nn.Module):
             [nn.GRUCell(256, 256) for _ in range(2)])
         # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
-        self.stopnet = StopNet(r, memory_dim)
+        self.stopnet = StopNet(256 + memory_dim * r)
+        # self.init_layers()
+
+    def init_layers(self):
+        torch.nn.init.xavier_uniform_(
+            self.project_to_decoder_in.weight,
+            gain=torch.nn.init.calculate_gain('linear'))
+        torch.nn.init.xavier_uniform_(
+            self.proj_to_mel.weight,
+            gain=torch.nn.init.calculate_gain('linear'))

     def forward(self, inputs, memory=None, mask=None):
         """

@@ -350,7 +391,7 @@ class Decoder(nn.Module):
         memory_input = initial_memory
         while True:
             if t > 0:
-                if greedy:
+                if memory is None:
                     memory_input = outputs[-1]
                 else:
                     memory_input = memory[t - 1]

@@ -375,17 +416,14 @@ class Decoder(nn.Module):
             decoder_output = decoder_input
             # predict mel vectors from decoder vectors
             output = self.proj_to_mel(decoder_output)
-            output = torch.sigmoid(output)
-            stop_input = output
             # predict stop token
-            stop_token, stopnet_rnn_hidden = self.stopnet(
-                stop_input, stopnet_rnn_hidden)
+            stopnet_input = torch.cat([decoder_input, output], -1)
+            stop_token = self.stopnet(stopnet_input)
             outputs += [output]
             attentions += [attention]
             stop_tokens += [stop_token]
             t += 1
-            if (not greedy and self.training) or (greedy
-                                                  and memory is not None):
+            if memory is not None:
                 if t >= T_decoder:
                     break
             else:

@@ -394,7 +432,6 @@ class Decoder(nn.Module):
             elif t > self.max_decoder_steps:
                 print(" | > Decoder stopped with 'max_decoder_steps'")
                 break
-        assert greedy or len(outputs) == T_decoder
         # Back to batch first
         attentions = torch.stack(attentions).transpose(0, 1)
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()

@@ -407,32 +444,22 @@ class StopNet(nn.Module):
     Predicting stop-token in decoder.

     Args:
-        r (int): number of output frames of the network.
-        memory_dim (int): feature dimension for each output frame.
+        in_features (int): feature dimension of input.
     """

-    def __init__(self, r, memory_dim):
-        r"""
-        Predicts the stop token to stop the decoder at testing time
-
-        Args:
-            r (int): number of network output frames.
-            memory_dim (int): single feature dim of a single network output frame.
-        """
+    def __init__(self, in_features):
         super(StopNet, self).__init__()
-        self.rnn = nn.GRUCell(memory_dim * r, memory_dim * r)
-        self.relu = nn.ReLU()
-        self.linear = nn.Linear(r * memory_dim, 1)
+        self.dropout = nn.Dropout(0.5)
+        self.linear = nn.Linear(in_features, 1)
         self.sigmoid = nn.Sigmoid()
+        torch.nn.init.xavier_uniform_(
+            self.linear.weight,
+            gain=torch.nn.init.calculate_gain('linear'))

-    def forward(self, inputs, rnn_hidden):
-        """
-        Args:
-            inputs: network output tensor with r x memory_dim feature dimension.
-            rnn_hidden: hidden state of the RNN cell.
-        """
-        rnn_hidden = self.rnn(inputs, rnn_hidden)
-        outputs = self.relu(rnn_hidden)
+    def forward(self, inputs):
+        # rnn_hidden = self.rnn(inputs, rnn_hidden)
+        # outputs = self.relu(rnn_hidden)
+        outputs = self.dropout(inputs)
         outputs = self.linear(outputs)
         outputs = self.sigmoid(outputs)
-        return outputs, rnn_hidden
+        return outputs
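Note: after this change, StopNet is a stateless dropout-linear-sigmoid head over the concatenated decoder state and current frame prediction, instead of a per-step GRU. A sketch of its use with the dimensions implied by the Decoder code above (256-dim decoder state, memory_dim * r frame output); batch size and memory_dim values here are arbitrary examples:

    import torch

    B, decoder_dim, memory_dim, r = 4, 256, 80, 1
    stopnet = StopNet(decoder_dim + memory_dim * r)  # StopNet as defined above

    decoder_input = torch.randn(B, decoder_dim)   # decoder RNN state
    output = torch.randn(B, memory_dim * r)       # predicted frame(s)
    stop_token = stopnet(torch.cat([decoder_input, output], -1))  # B x 1, in (0, 1)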

setup.py (4 changes)

@@ -80,11 +80,11 @@ setup(
         "matplotlib==2.0.2",
         "Pillow",
         "flask",
-        "lws",
+        # "lws",
         "tqdm",
     ],
     extras_require={
         "bin": [
             "tqdm",
             "requests",
         ],
     })

@@ -0,0 +1,142 @@
+import os
+import unittest
+import numpy as np
+import torch as T
+from TTS.utils.audio import AudioProcessor
+from TTS.utils.generic_utils import load_config
+
+file_path = os.path.dirname(os.path.realpath(__file__))
+INPUTPATH = os.path.join(file_path, 'inputs')
+OUTPATH = os.path.join(file_path, "outputs/audio_tests")
+os.makedirs(OUTPATH, exist_ok=True)
+
+c = load_config(os.path.join(file_path, 'test_config.json'))
+
+
+class TestAudio(unittest.TestCase):
+    def __init__(self, *args, **kwargs):
+        super(TestAudio, self).__init__(*args, **kwargs)
+        self.ap = AudioProcessor(**c.audio)
+
+    def test_audio_synthesis(self):
+        """ 1. load wav
+            2. set normalization parameters
+            3. extract mel-spec
+            4. invert to wav and save the output
+        """
+        print(" > Sanity check for the process wav -> mel -> wav")
+
+        def _test(max_norm, signal_norm, symmetric_norm, clip_norm):
+            self.ap.max_norm = max_norm
+            self.ap.signal_norm = signal_norm
+            self.ap.symmetric_norm = symmetric_norm
+            self.ap.clip_norm = clip_norm
+            wav = self.ap.load_wav(INPUTPATH + "/example_1.wav")
+            mel = self.ap.melspectrogram(wav)
+            wav_ = self.ap.inv_mel_spectrogram(mel)
+            file_name = "/audio_test-melspec_max_norm_{}-signal_norm_{}-symmetric_{}-clip_norm_{}.wav"\
+                .format(max_norm, signal_norm, symmetric_norm, clip_norm)
+            print(" | > Creating wav file at : ", file_name)
+            self.ap.save_wav(wav_, OUTPATH + file_name)
+
+        # maxnorm = 1.0
+        _test(1., False, False, False)
+        _test(1., True, False, False)
+        _test(1., True, True, False)
+        _test(1., True, False, True)
+        _test(1., True, True, True)
+        # maxnorm = 4.0
+        _test(4., False, False, False)
+        _test(4., True, False, False)
+        _test(4., True, True, False)
+        _test(4., True, False, True)
+        _test(4., True, True, True)
+
+    def test_normalize(self):
+        """Check normalization and denormalization for range values and consistency"""
+        print(" > Testing normalization and denormalization.")
+        wav = self.ap.load_wav(INPUTPATH + "/example_1.wav")
+        self.ap.signal_norm = False
+        x = self.ap.melspectrogram(wav)
+        x_old = x
+
+        self.ap.signal_norm = True
+        self.ap.symmetric_norm = False
+        self.ap.clip_norm = False
+        self.ap.max_norm = 4.0
+        x_norm = self.ap._normalize(x)
+        print(x_norm.max(), " -- ", x_norm.min())
+        assert (x_old - x).sum() == 0
+        # check value range
+        assert x_norm.max() <= self.ap.max_norm + 1, x_norm.max()
+        assert x_norm.min() >= 0 - 1, x_norm.min()
+        # check denorm.
+        x_ = self.ap._denormalize(x_norm)
+        assert (x - x_).sum() < 1e-3, (x - x_).mean()
+
+        self.ap.signal_norm = True
+        self.ap.symmetric_norm = False
+        self.ap.clip_norm = True
+        self.ap.max_norm = 4.0
+        x_norm = self.ap._normalize(x)
+        print(x_norm.max(), " -- ", x_norm.min())
+        assert (x_old - x).sum() == 0
+        # check value range
+        assert x_norm.max() <= self.ap.max_norm, x_norm.max()
+        assert x_norm.min() >= 0, x_norm.min()
+        # check denorm.
+        x_ = self.ap._denormalize(x_norm)
+        assert (x - x_).sum() < 1e-3, (x - x_).mean()
+
+        self.ap.signal_norm = True
+        self.ap.symmetric_norm = True
+        self.ap.clip_norm = False
+        self.ap.max_norm = 4.0
+        x_norm = self.ap._normalize(x)
+        print(x_norm.max(), " -- ", x_norm.min())
+        assert (x_old - x).sum() == 0
+        # check value range
+        assert x_norm.max() <= self.ap.max_norm + 1, x_norm.max()
+        assert x_norm.min() >= -self.ap.max_norm - 2, x_norm.min()
+        assert x_norm.min() <= 0, x_norm.min()
+        # check denorm.
+        x_ = self.ap._denormalize(x_norm)
+        assert (x - x_).sum() < 1e-3, (x - x_).mean()
+
+        self.ap.signal_norm = True
+        self.ap.symmetric_norm = True
+        self.ap.clip_norm = True
+        self.ap.max_norm = 4.0
+        x_norm = self.ap._normalize(x)
+        print(x_norm.max(), " -- ", x_norm.min())
+        assert (x_old - x).sum() == 0
+        # check value range
+        assert x_norm.max() <= self.ap.max_norm, x_norm.max()
+        assert x_norm.min() >= -self.ap.max_norm, x_norm.min()
+        assert x_norm.min() <= 0, x_norm.min()
+        # check denorm.
+        x_ = self.ap._denormalize(x_norm)
+        assert (x - x_).sum() < 1e-3, (x - x_).mean()
+
+        self.ap.signal_norm = True
+        self.ap.symmetric_norm = False
+        self.ap.max_norm = 1.0
+        x_norm = self.ap._normalize(x)
+        print(x_norm.max(), " -- ", x_norm.min())
+        assert (x_old - x).sum() == 0
+        assert x_norm.max() <= self.ap.max_norm, x_norm.max()
+        assert x_norm.min() >= 0, x_norm.min()
+        x_ = self.ap._denormalize(x_norm)
+        assert (x - x_).sum() < 1e-3
+
+        self.ap.signal_norm = True
+        self.ap.symmetric_norm = True
+        self.ap.max_norm = 1.0
+        x_norm = self.ap._normalize(x)
+        print(x_norm.max(), " -- ", x_norm.min())
+        assert (x_old - x).sum() == 0
+        assert x_norm.max() <= self.ap.max_norm, x_norm.max()
+        assert x_norm.min() >= -self.ap.max_norm, x_norm.min()
+        assert x_norm.min() < 0, x_norm.min()
+        x_ = self.ap._denormalize(x_norm)
+        assert (x - x_).sum() < 1e-3
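Note: the ranges asserted above correspond to min-max scaling of the dB spectrogram. The AudioProcessor implementation is not part of this diff, so the following is only a sketch of the presumed _normalize semantics, under the assumption that S_db is already referenced to min_level_db:

    def normalize_sketch(S_db, min_level_db=-100, max_norm=4.0,
                         symmetric=False, clip=False):
        # scale dB values into [0, 1] first
        x = (S_db - min_level_db) / -min_level_db
        if symmetric:
            x = 2 * max_norm * x - max_norm          # -> roughly [-max_norm, max_norm]
            return x.clip(-max_norm, max_norm) if clip else x
        x = max_norm * x                             # -> roughly [0, max_norm]
        return x.clip(0, max_norm) if clip else x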

@@ -1,50 +1,49 @@
 import os
 import unittest
+import shutil
 import numpy as np

 from torch.utils.data import DataLoader
 from TTS.utils.generic_utils import load_config
 from TTS.utils.audio import AudioProcessor
-from TTS.datasets import LJSpeech, Kusal
+from TTS.datasets import TTSDataset, TTSDatasetCached, TTSDatasetMemory
+from TTS.datasets.preprocess import ljspeech, tts_cache

 file_path = os.path.dirname(os.path.realpath(__file__))
 OUTPATH = os.path.join(file_path, "outputs/loader_tests/")
 os.makedirs(OUTPATH, exist_ok=True)
 c = load_config(os.path.join(file_path, 'test_config.json'))
-ok_kusal = os.path.exists(c.data_path_Kusal)
-ok_ljspeech = os.path.exists(c.data_path_LJSpeech)
+ok_ljspeech = os.path.exists(c.data_path)


-class TestLJSpeechDataset(unittest.TestCase):
+class TestTTSDataset(unittest.TestCase):
     def __init__(self, *args, **kwargs):
-        super(TestLJSpeechDataset, self).__init__(*args, **kwargs)
+        super(TestTTSDataset, self).__init__(*args, **kwargs)
         self.max_loader_iter = 4
-        self.ap = AudioProcessor(
-            sample_rate=c.sample_rate,
-            num_mels=c.num_mels,
-            min_level_db=c.min_level_db,
-            frame_shift_ms=c.frame_shift_ms,
-            frame_length_ms=c.frame_length_ms,
-            ref_level_db=c.ref_level_db,
-            num_freq=c.num_freq,
-            power=c.power,
-            preemphasis=c.preemphasis)
+        self.ap = AudioProcessor(**c.audio)

+    def _create_dataloader(self, batch_size, r, bgs):
+        dataset = TTSDataset.MyDataset(
+            c.data_path,
+            'metadata.csv',
+            r,
+            c.text_cleaner,
+            preprocessor=ljspeech,
+            ap=self.ap,
+            batch_group_size=bgs,
+            min_seq_len=c.min_seq_len)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            collate_fn=dataset.collate_fn,
+            drop_last=True,
+            num_workers=c.num_loader_workers)
+        return dataloader, dataset

     def test_loader(self):
         if ok_ljspeech:
-            dataset = LJSpeech.MyDataset(
-                os.path.join(c.data_path_LJSpeech),
-                os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
-                c.r,
-                c.text_cleaner,
-                ap=self.ap,
-                min_seq_len=c.min_seq_len)
-
-            dataloader = DataLoader(
-                dataset,
-                batch_size=2,
-                shuffle=True,
-                collate_fn=dataset.collate_fn,
-                drop_last=True,
-                num_workers=c.num_loader_workers)
+            dataloader, dataset = self._create_dataloader(2, c.r, 0)

             for i, data in enumerate(dataloader):
                 if i == self.max_loader_iter:

@@ -63,29 +62,158 @@ class TestLJSpeechDataset(unittest.TestCase):
                     " !! Negative values in text_input: {}".format(check_count)
                 # TODO: more assertion here
                 assert linear_input.shape[0] == c.batch_size
                 assert linear_input.shape[2] == self.ap.num_freq
                 assert mel_input.shape[0] == c.batch_size
-                assert mel_input.shape[2] == c.num_mels
+                assert mel_input.shape[2] == c.audio['num_mels']
+                # check normalization ranges
+                if self.ap.symmetric_norm:
+                    assert mel_input.max() <= self.ap.max_norm
+                    assert mel_input.min() >= -self.ap.max_norm
+                    assert mel_input.min() < 0
+                else:
+                    assert mel_input.max() <= self.ap.max_norm
+                    assert mel_input.min() >= 0

     def test_batch_group_shuffle(self):
         if ok_ljspeech:
-            dataset = LJSpeech.MyDataset(
-                os.path.join(c.data_path_LJSpeech),
-                os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
-                c.r,
-                c.text_cleaner,
-                ap=self.ap,
-                batch_group_size=16,
-                min_seq_len=c.min_seq_len)
-
-            dataloader = DataLoader(
-                dataset,
-                batch_size=2,
-                shuffle=True,
-                collate_fn=dataset.collate_fn,
-                drop_last=True,
-                num_workers=c.num_loader_workers)
-
-            frames = dataset.frames
+            dataloader, dataset = self._create_dataloader(2, c.r, 16)
+            last_length = 0
+            frames = dataset.items
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
+
+                avg_length = mel_lengths.numpy().mean()
+                assert avg_length >= last_length
+            dataloader.dataset.sort_items()
+            assert frames[0] != dataloader.dataset.items[0]
+
+    def test_padding_and_spec(self):
+        if ok_ljspeech:
+            dataloader, dataset = self._create_dataloader(1, 1, 0)
+
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
+
+                # check mel_spec consistency
+                wav = self.ap.load_wav(item_idx[0])
+                mel = self.ap.melspectrogram(wav)
+                mel_dl = mel_input[0].cpu().numpy()
+                assert (
+                    abs(mel.T).astype("float32") - abs(mel_dl[:-1])).sum() == 0
+
+                # check mel-spec correctness
+                mel_spec = mel_input[0].cpu().numpy()
+                wav = self.ap.inv_mel_spectrogram(mel_spec.T)
+                self.ap.save_wav(wav, OUTPATH + '/mel_inv_dataloader.wav')
+                shutil.copy(item_idx[0], OUTPATH + '/mel_target_dataloader.wav')
+
+                # check linear-spec
+                linear_spec = linear_input[0].cpu().numpy()
+                wav = self.ap.inv_spectrogram(linear_spec.T)
+                self.ap.save_wav(wav, OUTPATH + '/linear_inv_dataloader.wav')
+                shutil.copy(item_idx[0], OUTPATH + '/linear_target_dataloader.wav')
+
+                # check the last time step to be zero padded
+                assert linear_input[0, -1].sum() == 0
+                assert linear_input[0, -2].sum() != 0
+                assert mel_input[0, -1].sum() == 0
+                assert mel_input[0, -2].sum() != 0
+                assert stop_target[0, -1] == 1
+                assert stop_target[0, -2] == 0
+                assert stop_target.sum() == 1
+                assert len(mel_lengths.shape) == 1
+                assert mel_lengths[0] == linear_input[0].shape[0]
+                assert mel_lengths[0] == mel_input[0].shape[0]
+
+            # Test for batch size 2
+            dataloader, dataset = self._create_dataloader(2, 1, 0)
+
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
+
+                if mel_lengths[0] > mel_lengths[1]:
+                    idx = 0
+                else:
+                    idx = 1
+
+                # check the first item in the batch
+                assert linear_input[idx, -1].sum() == 0
+                assert linear_input[idx, -2].sum() != 0, linear_input
+                assert mel_input[idx, -1].sum() == 0
+                assert mel_input[idx, -2].sum() != 0, mel_input
+                assert stop_target[idx, -1] == 1
+                assert stop_target[idx, -2] == 0
+                assert stop_target[idx].sum() == 1
+                assert len(mel_lengths.shape) == 1
+                assert mel_lengths[idx] == mel_input[idx].shape[0]
+                assert mel_lengths[idx] == linear_input[idx].shape[0]
+
+                # check the second item in the batch
+                assert linear_input[1 - idx, -1].sum() == 0
+                assert mel_input[1 - idx, -1].sum() == 0
+                assert stop_target[1 - idx, -1] == 1
+                assert len(mel_lengths.shape) == 1
+
+                # check batch conditions
+                assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
+                assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
+
+
+class TestTTSDatasetCached(unittest.TestCase):
+    def __init__(self, *args, **kwargs):
+        super(TestTTSDatasetCached, self).__init__(*args, **kwargs)
+        self.max_loader_iter = 4
+        self.c = load_config(os.path.join(c.data_path_cache, 'config.json'))
+        self.ap = AudioProcessor(**self.c.audio)
+
+    def _create_dataloader(self, batch_size, r, bgs):
+        dataset = TTSDatasetCached.MyDataset(
+            c.data_path_cache,
+            'tts_metadata.csv',
+            r,
+            c.text_cleaner,
+            preprocessor=tts_cache,
+            ap=self.ap,
+            batch_group_size=bgs,
+            min_seq_len=c.min_seq_len)
+
+        dataloader = DataLoader(
+            dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            collate_fn=dataset.collate_fn,
+            drop_last=True,
+            num_workers=c.num_loader_workers)
+        return dataloader, dataset
+
+    def test_loader(self):
+        if ok_ljspeech:
+            dataloader, dataset = self._create_dataloader(2, c.r, 0)
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break

@@ -102,32 +230,21 @@ class TestLJSpeechDataset(unittest.TestCase):
             assert check_count == 0, \
                 " !! Negative values in text_input: {}".format(check_count)
             # TODO: more assertion here
-            assert linear_input.shape[0] == c.batch_size
             assert mel_input.shape[0] == c.batch_size
-            assert mel_input.shape[2] == c.num_mels
-            dataloader.dataset.sort_frames()
-            assert frames[0] != dataloader.dataset.frames[0]
+            assert mel_input.shape[2] == c.audio['num_mels']
+
+            if self.ap.symmetric_norm:
+                assert mel_input.max() <= self.ap.max_norm
+                assert mel_input.min() >= -self.ap.max_norm
+                assert mel_input.min() < 0
+            else:
+                assert mel_input.max() <= self.ap.max_norm
+                assert mel_input.min() >= 0

-    def test_padding(self):
+    def test_batch_group_shuffle(self):
         if ok_ljspeech:
-            dataset = LJSpeech.MyDataset(
-                os.path.join(c.data_path_LJSpeech),
-                os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
-                1,
-                c.text_cleaner,
-                ap=self.ap,
-                min_seq_len=c.min_seq_len)
-
-            # Test for batch size 1
-            dataloader = DataLoader(
-                dataset,
-                batch_size=1,
-                shuffle=False,
-                collate_fn=dataset.collate_fn,
-                drop_last=True,
-                num_workers=c.num_loader_workers)
-
+            dataloader, dataset = self._create_dataloader(2, c.r, 16)
+            frames = dataset.items
             for i, data in enumerate(dataloader):
                 if i == self.max_loader_iter:
                     break

@@ -139,11 +256,51 @@ class TestLJSpeechDataset(unittest.TestCase):
                 stop_target = data[5]
                 item_idx = data[6]

                 neg_values = text_input[text_input < 0]
                 check_count = len(neg_values)
                 assert check_count == 0, \
                     " !! Negative values in text_input: {}".format(check_count)
                 # TODO: more assertion here
                 assert mel_input.shape[0] == c.batch_size
                 assert mel_input.shape[2] == c.audio['num_mels']
+            dataloader.dataset.sort_items()
+            assert frames[0] != dataloader.dataset.items[0]
+
+    def test_padding_and_spec(self):
+        if ok_ljspeech:
+            dataloader, dataset = self._create_dataloader(1, 1, 0)
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
+
+                # check mel_spec consistency
+                if item_idx[0].split('.')[-1] == 'npy':
+                    wav = np.load(item_idx[0])
+                else:
+                    wav = self.ap.load_wav(item_idx[0])
+                mel = self.ap.melspectrogram(wav)
+                mel_dl = mel_input[0].cpu().numpy()
+                assert (abs(mel.T).astype("float32") - abs(
+                    mel_dl[:-1])).sum() == 0, (
+                        abs(mel.T).astype("float32") - abs(mel_dl[:-1])).sum()
+
+                # check mel-spec correctness
+                mel_spec = mel_input[0].cpu().numpy()
+                wav = self.ap.inv_mel_spectrogram(mel_spec.T)
+                self.ap.save_wav(wav,
+                                 OUTPATH + '/mel_inv_dataloader_cache.wav')
+                shutil.copy(item_idx[0], OUTPATH + '/mel_target_dataloader_cache.wav')
+
+                # check the last time step to be zero padded
+                assert mel_input[0, -1].sum() == 0
+                assert mel_input[0, -2].sum() != 0
+                assert linear_input[0, -1].sum() == 0
+                assert linear_input[0, -2].sum() != 0
+                assert stop_target[0, -1] == 1
+                assert stop_target[0, -2] == 0
+                assert stop_target.sum() == 1
|
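These padding checks encode the collate convention: each utterance gets one extra all-zero frame, and the stop target is 0 on every real frame and 1 only on that final padded step. A toy illustration of the expected layout, assuming a single item with 4 real frames padded to 5:

import numpy as np

num_mels, real_frames, padded_frames = 3, 4, 5
mel = np.zeros((padded_frames, num_mels))
mel[:real_frames] = np.random.rand(real_frames, num_mels) + 0.1  # non-zero real frames
stop = np.zeros(padded_frames)
stop[-1] = 1  # stop flag raised only on the padded final step

assert mel[-1].sum() == 0 and mel[-2].sum() != 0
assert stop[-1] == 1 and stop[:-1].sum() == 0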
@@ -151,14 +308,7 @@ class TestLJSpeechDataset(unittest.TestCase):
                assert mel_lengths[0] == mel_input[0].shape[0]

            # Test for batch size 2
            dataloader = DataLoader(
                dataset,
                batch_size=2,
                shuffle=False,
                collate_fn=dataset.collate_fn,
                drop_last=False,
                num_workers=c.num_loader_workers)

            dataloader, dataset = self._create_dataloader(2, 1, 0)
            for i, data in enumerate(dataloader):
                if i == self.max_loader_iter:
                    break
@@ -178,8 +328,6 @@ class TestLJSpeechDataset(unittest.TestCase):
                # check the first item in the batch
                assert mel_input[idx, -1].sum() == 0
                assert mel_input[idx, -2].sum() != 0, mel_input
                assert linear_input[idx, -1].sum() == 0
                assert linear_input[idx, -2].sum() != 0
                assert stop_target[idx, -1] == 1
                assert stop_target[idx, -2] == 0
                assert stop_target[idx].sum() == 1
@@ -188,151 +336,202 @@ class TestLJSpeechDataset(unittest.TestCase):

                # check the second item in the batch
                assert mel_input[1 - idx, -1].sum() == 0
                assert linear_input[1 - idx, -1].sum() == 0
                assert stop_target[1 - idx, -1] == 1
                assert len(mel_lengths.shape) == 1

                # check batch conditions
                assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
                assert (linear_input * stop_target.unsqueeze(2)).sum() == 0

class TestKusalDataset(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super(TestKusalDataset, self).__init__(*args, **kwargs)
        self.max_loader_iter = 4
        self.ap = AudioProcessor(
            sample_rate=c.sample_rate,
            num_mels=c.num_mels,
            min_level_db=c.min_level_db,
            frame_shift_ms=c.frame_shift_ms,
            frame_length_ms=c.frame_length_ms,
            ref_level_db=c.ref_level_db,
            num_freq=c.num_freq,
            power=c.power,
            preemphasis=c.preemphasis)
# class TestTTSDatasetMemory(unittest.TestCase):
#     def __init__(self, *args, **kwargs):
#         super(TestTTSDatasetMemory, self).__init__(*args, **kwargs)
#         self.max_loader_iter = 4
#         self.c = load_config(os.path.join(c.data_path_cache, 'config.json'))
#         self.ap = AudioProcessor(**c.audio)

    def test_loader(self):
        if ok_kusal:
            dataset = Kusal.MyDataset(
                os.path.join(c.data_path_Kusal),
                os.path.join(c.data_path_Kusal, 'prompts.txt'),
                c.r,
                c.text_cleaner,
                ap=self.ap,
                min_seq_len=c.min_seq_len)
#     def test_loader(self):
#         if ok_ljspeech:
#             dataset = TTSDatasetMemory.MyDataset(
#                 c.data_path_cache,
#                 'tts_metadata.csv',
#                 c.r,
#                 c.text_cleaner,
#                 preprocessor=tts_cache,
#                 ap=self.ap,
#                 min_seq_len=c.min_seq_len)

            dataloader = DataLoader(
                dataset,
                batch_size=2,
                shuffle=True,
                collate_fn=dataset.collate_fn,
                drop_last=True,
                num_workers=c.num_loader_workers)
#             dataloader = DataLoader(
#                 dataset,
#                 batch_size=2,
#                 shuffle=True,
#                 collate_fn=dataset.collate_fn,
#                 drop_last=True,
#                 num_workers=c.num_loader_workers)

            for i, data in enumerate(dataloader):
                if i == self.max_loader_iter:
                    break
                text_input = data[0]
                text_lengths = data[1]
                linear_input = data[2]
                mel_input = data[3]
                mel_lengths = data[4]
                stop_target = data[5]
                item_idx = data[6]
#             for i, data in enumerate(dataloader):
#                 if i == self.max_loader_iter:
#                     break
#                 text_input = data[0]
#                 text_lengths = data[1]
#                 linear_input = data[2]
#                 mel_input = data[3]
#                 mel_lengths = data[4]
#                 stop_target = data[5]
#                 item_idx = data[6]

                neg_values = text_input[text_input < 0]
                check_count = len(neg_values)
                assert check_count == 0, \
                    " !! Negative values in text_input: {}".format(check_count)
                # TODO: more assertions here
                assert linear_input.shape[0] == c.batch_size
                assert mel_input.shape[0] == c.batch_size
                assert mel_input.shape[2] == c.num_mels
#                 neg_values = text_input[text_input < 0]
#                 check_count = len(neg_values)
#                 assert check_count == 0, \
#                     " !! Negative values in text_input: {}".format(check_count)
#                 # check mel-spec shape
#                 assert mel_input.shape[0] == c.batch_size
#                 assert mel_input.shape[2] == c.audio['num_mels']
#                 assert mel_input.max() <= self.ap.max_norm
#                 # check data range
#                 if self.ap.symmetric_norm:
#                     assert mel_input.max() <= self.ap.max_norm
#                     assert mel_input.min() >= -self.ap.max_norm
#                     assert mel_input.min() < 0
#                 else:
#                     assert mel_input.max() <= self.ap.max_norm
#                     assert mel_input.min() >= 0

    def test_padding(self):
        if ok_kusal:
            dataset = Kusal.MyDataset(
                os.path.join(c.data_path_Kusal),
                os.path.join(c.data_path_Kusal, 'prompts.txt'),
                1,
                c.text_cleaner,
                ap=self.ap,
                min_seq_len=c.min_seq_len)
#     def test_batch_group_shuffle(self):
#         if ok_ljspeech:
#             dataset = TTSDatasetMemory.MyDataset(
#                 c.data_path_cache,
#                 'tts_metadata.csv',
#                 c.r,
#                 c.text_cleaner,
#                 preprocessor=ljspeech,
#                 ap=self.ap,
#                 batch_group_size=16,
#                 min_seq_len=c.min_seq_len)

            # Test for batch size 1
            dataloader = DataLoader(
                dataset,
                batch_size=1,
                shuffle=False,
                collate_fn=dataset.collate_fn,
                drop_last=True,
                num_workers=c.num_loader_workers)
#             dataloader = DataLoader(
#                 dataset,
#                 batch_size=2,
#                 shuffle=True,
#                 collate_fn=dataset.collate_fn,
#                 drop_last=True,
#                 num_workers=c.num_loader_workers)

            for i, data in enumerate(dataloader):
                if i == self.max_loader_iter:
                    break
                text_input = data[0]
                text_lengths = data[1]
                linear_input = data[2]
                mel_input = data[3]
                mel_lengths = data[4]
                stop_target = data[5]
                item_idx = data[6]
#             frames = dataset.items
#             for i, data in enumerate(dataloader):
#                 if i == self.max_loader_iter:
#                     break
#                 text_input = data[0]
#                 text_lengths = data[1]
#                 linear_input = data[2]
#                 mel_input = data[3]
#                 mel_lengths = data[4]
#                 stop_target = data[5]
#                 item_idx = data[6]

                # check the last time step to be zero padded
                assert mel_input[0, -1].sum() == 0
                # assert mel_input[0, -2].sum() != 0
                assert linear_input[0, -1].sum() == 0
                # assert linear_input[0, -2].sum() != 0
                assert stop_target[0, -1] == 1
                assert stop_target[0, -2] == 0
                assert stop_target.sum() == 1
                assert len(mel_lengths.shape) == 1
                assert mel_lengths[0] == mel_input[0].shape[0]
#                 neg_values = text_input[text_input < 0]
#                 check_count = len(neg_values)
#                 assert check_count == 0, \
#                     " !! Negative values in text_input: {}".format(check_count)
#                 assert mel_input.shape[0] == c.batch_size
#                 assert mel_input.shape[2] == c.audio['num_mels']
#                 dataloader.dataset.sort_items()
#                 assert frames[0] != dataloader.dataset.items[0]

            # Test for batch size 2
            dataloader = DataLoader(
                dataset,
                batch_size=2,
                shuffle=False,
                collate_fn=dataset.collate_fn,
                drop_last=False,
                num_workers=c.num_loader_workers)
#     def test_padding_and_spec(self):
#         if ok_ljspeech:
#             dataset = TTSDatasetMemory.MyDataset(
#                 c.data_path_cache,
#                 'tts_meta_data.csv',
#                 1,
#                 c.text_cleaner,
#                 preprocessor=ljspeech,
#                 ap=self.ap,
#                 min_seq_len=c.min_seq_len)

            for i, data in enumerate(dataloader):
                if i == self.max_loader_iter:
                    break
                text_input = data[0]
                text_lengths = data[1]
                linear_input = data[2]
                mel_input = data[3]
                mel_lengths = data[4]
                stop_target = data[5]
                item_idx = data[6]
#             # Test for batch size 1
#             dataloader = DataLoader(
#                 dataset,
#                 batch_size=1,
#                 shuffle=False,
#                 collate_fn=dataset.collate_fn,
#                 drop_last=True,
#                 num_workers=c.num_loader_workers)

                if mel_lengths[0] > mel_lengths[1]:
                    idx = 0
                else:
                    idx = 1
#             for i, data in enumerate(dataloader):
#                 if i == self.max_loader_iter:
#                     break
#                 text_input = data[0]
#                 text_lengths = data[1]
#                 linear_input = data[2]
#                 mel_input = data[3]
#                 mel_lengths = data[4]
#                 stop_target = data[5]
#                 item_idx = data[6]

                # check the first item in the batch
                assert mel_input[idx, -1].sum() == 0
                assert mel_input[idx, -2].sum() != 0, mel_input
                assert linear_input[idx, -1].sum() == 0
                assert linear_input[idx, -2].sum() != 0
                assert stop_target[idx, -1] == 1
                assert stop_target[idx, -2] == 0
                assert stop_target[idx].sum() == 1
                assert len(mel_lengths.shape) == 1
                assert mel_lengths[idx] == mel_input[idx].shape[0]
#                 # check mel_spec consistency
#                 if item_idx[0].split('.')[-1] == 'npy':
#                     wav = np.load(item_idx[0])
#                 else:
#                     wav = self.ap.load_wav(item_idx[0])
#                 mel = self.ap.melspectrogram(wav)
#                 mel_dl = mel_input[0].cpu().numpy()
#                 assert (
#                     abs(mel.T).astype("float32") - abs(mel_dl[:-1])).sum() == 0

                # check the second item in the batch
                assert mel_input[1 - idx, -1].sum() == 0
                assert linear_input[1 - idx, -1].sum() == 0
                assert stop_target[1 - idx, -1] == 1
                assert len(mel_lengths.shape) == 1
#                 # check mel-spec correctness
#                 mel_spec = mel_input[0].cpu().numpy()
#                 wav = self.ap.inv_mel_spectrogram(mel_spec.T)
#                 self.ap.save_wav(wav, OUTPATH + '/mel_inv_dataloader_memo.wav')
#                 shutil.copy(item_idx[0], OUTPATH + '/mel_target_dataloader_memo.wav')

                # check batch conditions
                assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
                assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
#                 # check the last time step to be zero padded
#                 assert mel_input[0, -1].sum() == 0
#                 assert mel_input[0, -2].sum() != 0
#                 assert stop_target[0, -1] == 1
#                 assert stop_target[0, -2] == 0
#                 assert stop_target.sum() == 1
#                 assert len(mel_lengths.shape) == 1
#                 assert mel_lengths[0] == mel_input[0].shape[0]

#                 # Test for batch size 2
#                 dataloader = DataLoader(
#                     dataset,
#                     batch_size=2,
#                     shuffle=False,
#                     collate_fn=dataset.collate_fn,
#                     drop_last=False,
#                     num_workers=c.num_loader_workers)

#                 for i, data in enumerate(dataloader):
#                     if i == self.max_loader_iter:
#                         break
#                     text_input = data[0]
#                     text_lengths = data[1]
#                     linear_input = data[2]
#                     mel_input = data[3]
#                     mel_lengths = data[4]
#                     stop_target = data[5]
#                     item_idx = data[6]

#                 if mel_lengths[0] > mel_lengths[1]:
#                     idx = 0
#                 else:
#                     idx = 1

#                 # check the first item in the batch
#                 assert mel_input[idx, -1].sum() == 0
#                 assert mel_input[idx, -2].sum() != 0, mel_input
#                 assert stop_target[idx, -1] == 1
#                 assert stop_target[idx, -2] == 0
#                 assert stop_target[idx].sum() == 1
#                 assert len(mel_lengths.shape) == 1
#                 assert mel_lengths[idx] == mel_input[idx].shape[0]

#                 # check the second item in the batch
#                 assert mel_input[1 - idx, -1].sum() == 0
#                 assert stop_target[1 - idx, -1] == 1
#                 assert len(mel_lengths.shape) == 1

#                 # check batch conditions
#                 assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
@@ -21,8 +21,8 @@ c = load_config(os.path.join(file_path, 'test_config.json'))
class TacotronTrainTest(unittest.TestCase):
    def test_train_step(self):
        input = torch.randint(0, 24, (8, 128)).long().to(device)
        mel_spec = torch.rand(8, 30, c.num_mels).to(device)
        linear_spec = torch.rand(8, 30, c.num_freq).to(device)
        mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
        linear_spec = torch.rand(8, 30, c.audio['num_freq']).to(device)
        mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
        stop_targets = torch.zeros(8, 30, 1).float().to(device)

@@ -35,7 +35,7 @@ class TacotronTrainTest(unittest.TestCase):

        criterion = L1LossMasked().to(device)
        criterion_st = nn.BCELoss().to(device)
        model = Tacotron(c.embedding_size, c.num_freq, c.num_mels,
        model = Tacotron(c.embedding_size, c.audio['num_freq'], c.audio['num_mels'],
                         c.r).to(device)
        model.train()
        model_ref = copy.deepcopy(model)
@@ -1,16 +1,25 @@
{
    "num_mels": 80,
    "num_freq": 1025,
    "sample_rate": 22050,
    "frame_length_ms": 50,
    "frame_shift_ms": 12.5,
    "preemphasis": 0.97,
    "min_level_db": -100,
    "ref_level_db": 20,
    "audio":{
        "audio_processor": "audio", // to dictate different audio processors, if available.
        "num_mels": 80,             // size of the mel spec frame.
        "num_freq": 1025,           // number of stft frequency levels. Size of the linear spectrogram frame.
        "sample_rate": 22050,       // wav sample-rate. If different than the original data, it is resampled.
        "frame_length_ms": 50,      // stft window length in ms.
        "frame_shift_ms": 12.5,     // stft window hop-length in ms.
        "preemphasis": 0.97,        // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis.
        "min_level_db": -100,       // normalization range
        "ref_level_db": 20,         // reference level db, theoretically 20db is the sound of air.
        "power": 1.5,               // value to sharpen wav signals after GL algorithm.
        "griffin_lim_iters": 30,    // #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
        "signal_norm": true,        // normalize the spec values in range [0, 1]
        "symmetric_norm": true,     // move normalization to range [-1, 1]
        "clip_norm": true,          // clip normalized values into the range.
        "max_norm": 4,              // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
        "mel_fmin": 95,             // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
        "mel_fmax": 7600            // maximum freq level for mel-spec. Tune for dataset!!
    },
    "hidden_size": 128,
    "embedding_size": 256,
    "min_mel_freq": null,
    "max_mel_freq": null,
    "text_cleaner": "english_cleaners",

    "epochs": 2000,
@@ -21,16 +30,11 @@
    "r": 5,
    "mk": 1.0,
    "priority_freq": false,

    "griffin_lim_iters": 60,
    "power": 1.5,

    "num_loader_workers": 4,

    "save_step": 200,
    "data_path_LJSpeech": "/home/erogol/Data/LJSpeech-1.1",
    "data_path_Kusal": "/home/erogol/Data/Kusal",
    "data_path": "/home/erogol/Data/LJSpeech-1.1/",
    "data_path_cache": "/home/erogol/Data/LJSpeech-1.1/tts_cache/",
    "output_path": "result",
    "min_seq_len": 0,
    "log_dir": "/home/erogol/projects/TTS/logs/"
train.py (54 changed lines)
@@ -14,16 +14,16 @@ from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter

from utils.generic_utils import (
    synthesis, remove_experiment_folder, create_experiment_folder,
    remove_experiment_folder, create_experiment_folder,
    save_checkpoint, save_best_model, load_config, lr_decay, count_parameters,
    check_update, get_commit_hash, sequence_mask, AnnealLR)
from utils.visual import plot_alignment, plot_spectrogram
from models.tacotron import Tacotron
from layers.losses import L1LossMasked
from utils.audio import AudioProcessor
from utils.synthesis import synthesis

torch.manual_seed(1)
# torch.set_num_threads(4)
use_cuda = torch.cuda.is_available()

@@ -36,7 +36,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
    avg_stop_loss = 0
    avg_step_time = 0
    print(" | > Epoch {}/{}".format(epoch, c.epochs), flush=True)
    n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
    n_priority_freq = int(3000 / (c.audio['sample_rate'] * 0.5) * c.audio['num_freq'])
    batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()
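`n_priority_freq` maps the 0-3 kHz band onto spectrogram bins: the Nyquist frequency `sample_rate / 2` spans `num_freq` bins, so the bin count scales proportionally. A quick check with the config values used here (the band is what the priority-frequency loss can up-weight):

sample_rate, num_freq = 22050, 1025
n_priority_freq = int(3000 / (sample_rate * 0.5) * num_freq)
print(n_priority_freq)  # 278 -> the lowest ~278 bins cover 0-3 kHz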
@@ -215,7 +215,7 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
        "I'm sorry Dave. I'm afraid I can't do that.",
        "This cake is great. It's so delicious and moist."
    ]
    n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
    n_priority_freq = int(3000 / (c.audio['sample_rate'] * 0.5) * c.audio['num_freq'])
    with torch.no_grad():
        if data_loader is not None:
            for num_iter, data in enumerate(data_loader):
@@ -277,6 +277,7 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
            const_spec = linear_output[idx].data.cpu().numpy()
            gt_spec = linear_input[idx].data.cpu().numpy()
            align_img = alignments[idx].data.cpu().numpy()

            const_spec = plot_spectrogram(const_spec, ap)
            gt_spec = plot_spectrogram(gt_spec, ap)
            align_img = plot_alignment(align_img)
@@ -319,8 +320,8 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
    ap.griffin_lim_iters = 60
    for idx, test_sentence in enumerate(test_sentences):
        try:
            wav, linear_spec, alignments = synthesis(model, ap, test_sentence,
                                                     use_cuda, c.text_cleaner)
            wav, alignment, linear_spec, stop_tokens = synthesis(model, test_sentence, c,
                                                                 use_cuda, ap)

            file_path = os.path.join(AUDIO_PATH, str(current_step))
            os.makedirs(file_path, exist_ok=True)
@@ -330,7 +331,7 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):

            wav_name = 'TestSentences/{}'.format(idx)
            tb.add_audio(
                wav_name, wav, current_step, sample_rate=c.sample_rate)
                wav_name, wav, current_step, sample_rate=c.audio['sample_rate'])
            align_img = alignments[0].data.cpu().numpy()
            linear_spec = plot_spectrogram(linear_spec, ap)
            align_img = plot_alignment(align_img)
@@ -345,28 +346,22 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):


def main(args):
    dataset = importlib.import_module('datasets.' + c.dataset)
    Dataset = getattr(dataset, 'MyDataset')
    audio = importlib.import_module('utils.' + c.audio_processor)
    preprocessor = importlib.import_module('datasets.preprocess')
    preprocessor = getattr(preprocessor, c.dataset.lower())
    MyDataset = importlib.import_module('datasets.' + c.data_loader)
    MyDataset = getattr(MyDataset, "MyDataset")
    audio = importlib.import_module('utils.' + c.audio['audio_processor'])
    AudioProcessor = getattr(audio, 'AudioProcessor')

    ap = AudioProcessor(
        sample_rate=c.sample_rate,
        num_mels=c.num_mels,
        min_level_db=c.min_level_db,
        frame_shift_ms=c.frame_shift_ms,
        frame_length_ms=c.frame_length_ms,
        ref_level_db=c.ref_level_db,
        num_freq=c.num_freq,
        power=c.power,
        preemphasis=c.preemphasis)
    ap = AudioProcessor(**c.audio)

    # Setup the dataset
    train_dataset = Dataset(
    train_dataset = MyDataset(
        c.data_path,
        c.meta_file_train,
        c.r,
        c.text_cleaner,
        preprocessor=preprocessor,
        ap=ap,
        batch_group_size=8*c.batch_size,
        min_seq_len=c.min_seq_len)
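This refactor makes both the data loader class and the preprocessor config-driven: `c.data_loader` names a module under `datasets/` and `c.dataset` names a function in `datasets.preprocess`. A minimal sketch of the same pattern, with a hypothetical config key standing in for the repo layout:

import importlib

def load_class(module_path, class_name):
    # Resolve "package.module" at runtime and pull a named attribute,
    # so the concrete implementation is chosen by config, not by imports.
    module = importlib.import_module(module_path)
    return getattr(module, class_name)

# e.g. config {"data_loader": "TTSDataset"} selects datasets/TTSDataset.py:
# MyDataset = load_class('datasets.' + config['data_loader'], 'MyDataset')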
@@ -381,8 +376,8 @@ def main(args):
        pin_memory=True)

    if c.run_eval:
        val_dataset = Dataset(
            c.data_path, c.meta_file_val, c.r, c.text_cleaner, ap=ap, batch_group_size=0)
        val_dataset = MyDataset(
            c.data_path, c.meta_file_val, c.r, c.text_cleaner, preprocessor=preprocessor, ap=ap, batch_group_size=0)

        val_loader = DataLoader(
            val_dataset,
@@ -395,7 +390,7 @@ def main(args):
    else:
        val_loader = None

    model = Tacotron(c.embedding_size, ap.num_freq, c.num_mels, c.r)
    model = Tacotron(c.embedding_size, ap.num_freq, ap.num_mels, c.r)
    print(" | > Num output units : {}".format(ap.num_freq), flush=True)

    optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0)
@@ -431,7 +426,7 @@ def main(args):
        criterion.cuda()
        criterion_st.cuda()

    scheduler = AnnealLR(optimizer, warmup_steps=c.warmup_steps, last_epoch= (args.restore_step-1))
    scheduler = AnnealLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1)
    num_params = count_parameters(model)
    print(" | > Model has {} parameters".format(num_params), flush=True)

@@ -454,7 +449,7 @@ def main(args):
            best_loss = save_best_model(model, optimizer, train_loss, best_loss,
                                        OUT_PATH, current_step, epoch)
        # shuffle batch groups
        train_loader.dataset.sort_frames()
        train_loader.dataset.sort_items()


if __name__ == '__main__':
@@ -477,8 +472,9 @@ if __name__ == '__main__':
    parser.add_argument(
        '--data_path',
        type=str,
        default='',
        help='data path to overrite config.json')
        help='dataset path.',
        default=''
    )
    args = parser.parse_args()

    # setup output paths and read configs
@@ -491,7 +487,7 @@ if __name__ == '__main__':
    os.makedirs(AUDIO_PATH, exist_ok=True)
    shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))

    if args.data_path != "":
    if args.data_path != '':
        c.data_path = args.data_path

    # setup tensorboard
utils/audio.py (169 changed lines)
@@ -3,26 +3,34 @@ import librosa
import pickle
import copy
import numpy as np
import scipy
from scipy import signal

_mel_basis = None
from pprint import pprint
from scipy import signal, io


class AudioProcessor(object):
    def __init__(self,
                 sample_rate,
                 num_mels,
                 min_level_db,
                 frame_shift_ms,
                 frame_length_ms,
                 ref_level_db,
                 num_freq,
                 power,
                 preemphasis,
                 griffin_lim_iters=None):
                 bits=None,
                 sample_rate=None,
                 num_mels=None,
                 min_level_db=None,
                 frame_shift_ms=None,
                 frame_length_ms=None,
                 ref_level_db=None,
                 num_freq=None,
                 power=None,
                 preemphasis=None,
                 signal_norm=None,
                 symmetric_norm=None,
                 max_norm=None,
                 mel_fmin=None,
                 mel_fmax=None,
                 clip_norm=True,
                 griffin_lim_iters=None,
                 **kwargs):

        print(" > Setting up Audio Processor...")

        self.bits = bits
        self.sample_rate = sample_rate
        self.num_mels = num_mels
        self.min_level_db = min_level_db
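Every parameter now defaults to `None` and the trailing `**kwargs` swallows config keys the processor does not use, so the whole `audio` section of the config can be unpacked straight into the constructor. A sketch of why that matters, with a hypothetical config dict:

audio_config = {
    "audio_processor": "audio",  # consumed by train.py, not by AudioProcessor
    "sample_rate": 22050,
    "num_mels": 80,
    "num_freq": 1025,
}

# Without **kwargs this would raise TypeError because of the extra
# "audio_processor" key; with it, unknown keys are silently ignored
# (a real call also needs the frame/level parameters):
# ap = AudioProcessor(**audio_config)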
@@ -33,41 +41,79 @@ class AudioProcessor(object):
        self.power = power
        self.preemphasis = preemphasis
        self.griffin_lim_iters = griffin_lim_iters
        self.signal_norm = signal_norm
        self.symmetric_norm = symmetric_norm
        self.mel_fmin = 0 if mel_fmin is None else mel_fmin
        self.mel_fmax = mel_fmax
        self.max_norm = 1.0 if max_norm is None else float(max_norm)
        self.clip_norm = clip_norm
        self.n_fft, self.hop_length, self.win_length = self._stft_parameters()
        if preemphasis == 0:
            print(" | > Preemphasis is disabled.")
        print(" | > Audio Processor attributes.")
        members = vars(self)
        pprint(members)

    def save_wav(self, wav, path):
        wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
        # librosa.output.write_wav(path, wav_norm.astype(np.int16), self.sample_rate)
        scipy.io.wavfile.write(path, self.sample_rate, wav_norm.astype(np.int16))
        io.wavfile.write(path, self.sample_rate, wav_norm.astype(np.int16))

    def _linear_to_mel(self, spectrogram):
        global _mel_basis
        if _mel_basis is None:
            _mel_basis = self._build_mel_basis()
        _mel_basis = self._build_mel_basis()
        return np.dot(_mel_basis, spectrogram)

    def _mel_to_linear(self, mel_spec):
        inv_mel_basis = np.linalg.pinv(self._build_mel_basis())
        return np.maximum(1e-10, np.dot(inv_mel_basis, mel_spec))

    def _build_mel_basis(self, ):
        n_fft = (self.num_freq - 1) * 2
        if self.mel_fmax is not None:
            assert self.mel_fmax <= self.sample_rate // 2
        return librosa.filters.mel(
            self.sample_rate, n_fft, n_mels=self.num_mels)
            self.sample_rate,
            n_fft,
            n_mels=self.num_mels,
            fmin=self.mel_fmin,
            fmax=self.mel_fmax)

    def _normalize(self, S):
        return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
        """Put values in [0, self.max_norm] or [-self.max_norm, self.max_norm]"""
        if self.signal_norm:
            S_norm = ((S - self.min_level_db) / -self.min_level_db)
            if self.symmetric_norm:
                S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm
                if self.clip_norm:
                    S_norm = np.clip(S_norm, -self.max_norm, self.max_norm)
                return S_norm
            else:
                S_norm = self.max_norm * S_norm
                if self.clip_norm:
                    S_norm = np.clip(S_norm, 0, self.max_norm)
                return S_norm
        else:
            return S

    def _denormalize(self, S):
        return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db
        """denormalize values"""
        S_denorm = S
        if self.signal_norm:
            if self.symmetric_norm:
                if self.clip_norm:
                    S_denorm = np.clip(S_denorm, -self.max_norm, self.max_norm)
                S_denorm = ((S_denorm + self.max_norm) * -self.min_level_db / (2 * self.max_norm)) + self.min_level_db
                return S_denorm
            else:
                if self.clip_norm:
                    S_denorm = np.clip(S_denorm, 0, self.max_norm)
                S_denorm = (S_denorm * -self.min_level_db /
                            self.max_norm) + self.min_level_db
                return S_denorm
        else:
            return S

    def _stft_parameters(self, ):
        """Compute necessary stft parameters with given time values"""
        n_fft = (self.num_freq - 1) * 2
        hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
        win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate)
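Two things are worth pinning down with concrete numbers. First, the STFT parameters follow directly from the config: with `num_freq=1025`, `sample_rate=22050`, `frame_shift_ms=12.5`, `frame_length_ms=50` you get `n_fft=2048`, `hop_length=275`, `win_length=1102`. Second, symmetric normalization is an affine map of dB values from `[min_level_db, 0]` into `[-max_norm, max_norm]`, and `_denormalize` applies the exact inverse. A standalone check of both, using the same formulas as above:

import numpy as np

# STFT parameters, as in _stft_parameters()
num_freq, sample_rate, frame_shift_ms, frame_length_ms = 1025, 22050, 12.5, 50
n_fft = (num_freq - 1) * 2                                # 2048
hop_length = int(frame_shift_ms / 1000.0 * sample_rate)   # 275
win_length = int(frame_length_ms / 1000.0 * sample_rate)  # 1102

# Symmetric normalization round-trip, as in _normalize()/_denormalize()
min_level_db, max_norm = -100, 4
S = np.linspace(min_level_db, 0, 5)           # dB values in [-100, 0]
S_norm = (S - min_level_db) / -min_level_db   # -> [0, 1]
S_norm = 2 * max_norm * S_norm - max_norm     # -> [-4, 4]
S_back = (S_norm + max_norm) * -min_level_db / (2 * max_norm) + min_level_db
assert np.allclose(S, S_back)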
@@ -92,8 +146,16 @@ class AudioProcessor(object):
        S = self._amp_to_db(np.abs(D)) - self.ref_level_db
        return self._normalize(S)

    def melspectrogram(self, y):
        if self.preemphasis != 0:
            D = self._stft(self.apply_preemphasis(y))
        else:
            D = self._stft(y)
        S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
        return self._normalize(S)

    def inv_spectrogram(self, spectrogram):
        '''Converts spectrogram to waveform using librosa'''
        """Converts spectrogram to waveform using librosa"""
        S = self._denormalize(spectrogram)
        S = self._db_to_amp(S + self.ref_level_db)  # Convert back to linear
        # Reconstruct phase
@@ -102,6 +164,16 @@ class AudioProcessor(object):
        else:
            return self._griffin_lim(S**self.power)

    def inv_mel_spectrogram(self, mel_spectrogram):
        '''Converts mel spectrogram to waveform using librosa'''
        D = self._denormalize(mel_spectrogram)
        S = self._db_to_amp(D + self.ref_level_db)
        S = self._mel_to_linear(S)  # Convert back to linear
        if self.preemphasis != 0:
            return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
        else:
            return self._griffin_lim(S**self.power)

    def _griffin_lim(self, S):
        angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
        S_complex = np.abs(S).astype(np.complex)
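`_mel_to_linear` is what makes `inv_mel_spectrogram` possible: the mel filterbank is a linear map (a `(num_mels, num_freq)` matrix), so its Moore-Penrose pseudo-inverse gives a least-squares reconstruction of the linear spectrogram, clamped to stay positive. A shape-level sketch mirroring `_build_mel_basis` above:

import numpy as np
import librosa

sample_rate, num_freq, num_mels = 22050, 1025, 80
n_fft = (num_freq - 1) * 2
mel_basis = librosa.filters.mel(sample_rate, n_fft, n_mels=num_mels)  # (80, 1025)

linear = np.random.rand(num_freq, 10)           # stand-in linear spectrogram
mel = np.dot(mel_basis, linear)                 # (80, 10) mel projection
linear_rec = np.maximum(1e-10, np.dot(np.linalg.pinv(mel_basis), mel))  # (1025, 10)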
@@ -111,20 +183,17 @@ class AudioProcessor(object):
            y = self._istft(S_complex * angles)
        return y

    def melspectrogram(self, y):
        if self.preemphasis != 0:
            D = self._stft(self.apply_preemphasis(y))
        else:
            D = self._stft(y)
        S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
        return self._normalize(S)

    def _stft(self, y):
        return librosa.stft(
            y=y, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length)
            y=y,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
        )

    def _istft(self, y):
        return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length)
        return librosa.istft(
            y, hop_length=self.hop_length, win_length=self.win_length)

    def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
        window_length = int(self.sample_rate * min_silence_sec)
@@ -134,3 +203,37 @@ class AudioProcessor(object):
            if np.max(wav[x:x + window_length]) < threshold:
                return x + hop_length
        return len(wav)

    # WaveRNN repo specific functions
    # def mulaw_encode(self, wav, qc):
    #     mu = qc - 1
    #     wav_abs = np.minimum(np.abs(wav), 1.0)
    #     magnitude = np.log(1 + mu * wav_abs) / np.log(1. + mu)
    #     signal = np.sign(wav) * magnitude
    #     # Quantize signal to the specified number of levels.
    #     signal = (signal + 1) / 2 * mu + 0.5
    #     return signal.astype(np.int32)

    # def mulaw_decode(self, wav, qc):
    #     """Recovers waveform from quantized values."""
    #     mu = qc - 1
    #     # Map values back to [-1, 1].
    #     casted = wav.astype(np.float32)
    #     signal = 2 * (casted / mu) - 1
    #     # Perform inverse of mu-law transformation.
    #     magnitude = (1 / mu) * ((1 + mu) ** abs(signal) - 1)
    #     return np.sign(signal) * magnitude

    def load_wav(self, filename, encode=False):
        x, sr = librosa.load(filename, sr=self.sample_rate)
        assert self.sample_rate == sr
        return x

    def encode_16bits(self, x):
        return np.clip(x * 2**15, -2**15, 2**15 - 1).astype(np.int16)

    def quantize(self, x):
        return (x + 1.) * (2**self.bits - 1) / 2

    def dequantize(self, x):
        return 2 * x / (2**self.bits - 1) - 1
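`quantize`/`dequantize` are WaveRNN-style helpers: they map a float signal in `[-1, 1]` onto `2**bits` evenly spaced levels and back. The round-trip is exact only up to the level spacing, as a quick worked example shows (assuming `bits=10`):

bits = 10
levels = 2**bits - 1               # 1023

def quantize(x):
    return (x + 1.) * levels / 2   # [-1, 1] -> [0, 1023]

def dequantize(q):
    return 2 * q / levels - 1      # [0, 1023] -> [-1, 1]

x = 0.25
q = round(quantize(x))             # 639, stored as an integer level
print(dequantize(q))               # ~0.2493, off by at most 1/1023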
@@ -1,4 +1,5 @@
import os
import re
import sys
import glob
import time
@@ -21,18 +22,23 @@ class AttrDict(dict):

def load_config(config_path):
    config = AttrDict()
    config.update(json.load(open(config_path, "r")))
    with open(config_path, "r") as f:
        input_str = f.read()
    input_str = re.sub(r'\\\n', '', input_str)
    input_str = re.sub(r'//.*\n', '\n', input_str)
    data = json.loads(input_str)
    config.update(data)
    return config

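The new `load_config` is what makes the commented config files above legal: escaped line breaks are joined and `//` comments are stripped before the text reaches `json.loads`. A self-contained check of the same two regex passes:

import json
import re

input_str = '{\n  "sample_rate": 22050,  // wav sample-rate.\n  "num_mels": 80\n}\n'
input_str = re.sub(r'\\\n', '', input_str)      # join escaped line breaks
input_str = re.sub(r'//.*\n', '\n', input_str)  # drop // comments
print(json.loads(input_str))  # {'sample_rate': 22050, 'num_mels': 80}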
def get_commit_hash():
    """https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script"""
    try:
        subprocess.check_output(['git', 'diff-index', '--quiet',
                                 'HEAD'])  # Verify client is clean
    except:
        raise RuntimeError(
            " !! Commit before training to get the commit hash.")
    # try:
    #     subprocess.check_output(['git', 'diff-index', '--quiet',
    #                              'HEAD'])  # Verify client is clean
    # except:
    #     raise RuntimeError(
    #         " !! Commit before training to get the commit hash.")
    commit = subprocess.check_output(['git', 'rev-parse', '--short',
                                      'HEAD']).decode().strip()
    print(' > Git Hash: {}'.format(commit))
@@ -177,15 +183,3 @@ def sequence_mask(sequence_length, max_len=None):
    seq_length_expand = (sequence_length.unsqueeze(1)
                         .expand_as(seq_range_expand))
    return seq_range_expand < seq_length_expand


def synthesis(model, ap, text, use_cuda, text_cleaner):
    text_cleaner = [text_cleaner]
    seq = np.array(text_to_sequence(text, text_cleaner))
    chars_var = torch.from_numpy(seq).unsqueeze(0)
    if use_cuda:
        chars_var = chars_var.cuda().long()
    _, linear_out, alignments, _ = model.forward(chars_var)
    linear_out = linear_out[0].data.cpu().numpy()
    wav = ap.inv_spectrogram(linear_out.T)
    return wav, linear_out, alignments
@@ -0,0 +1,23 @@
import io
import time
import librosa
import torch
import numpy as np
from .text import text_to_sequence
from .visual import visualize
from matplotlib import pylab as plt


def synthesis(m, s, CONFIG, use_cuda, ap):
    """Given the text, synthesize the audio."""
    text_cleaner = [CONFIG.text_cleaner]
    seq = np.array(text_to_sequence(s, text_cleaner))
    chars_var = torch.from_numpy(seq).unsqueeze(0)
    if use_cuda:
        chars_var = chars_var.cuda()
    mel_spec, linear_spec, alignments, stop_tokens = m.forward(chars_var.long())
    linear_spec = linear_spec[0].data.cpu().numpy()
    alignment = alignments[0].cpu().data.numpy()
    wav = ap.inv_spectrogram(linear_spec.T)
    # wav = wav[:ap.find_endpoint(wav)]
    return wav, alignment, linear_spec, stop_tokens
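The new `utils/synthesis.py` replaces the old helper removed from `generic_utils`: the model now also returns stop tokens, and the caller passes the whole config rather than a cleaner name. A minimal usage sketch, assuming a trained Tacotron and the objects already built in `train.py` (`model`, `ap`, config `c`, `use_cuda`):

# Hypothetical inference snippet; names mirror train.py above.
from utils.synthesis import synthesis

sentence = "This cake is great. It's so delicious and moist."
wav, alignment, linear_spec, stop_tokens = synthesis(
    model, sentence, c, use_cuda, ap)
ap.save_wav(wav, 'test_sentence.wav')  # write the Griffin-Lim reconstruction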
@@ -1,4 +1,5 @@
import numpy as np
import librosa
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
@@ -14,6 +15,7 @@ def plot_alignment(alignment, info=None):
        xlabel += '\n\n' + info
    plt.xlabel(xlabel)
    plt.ylabel('Encoder timestep')
    # plt.yticks(range(len(text)), list(text))
    plt.tight_layout()
    return fig

@@ -25,3 +27,28 @@ def plot_spectrogram(linear_output, audio):
    plt.colorbar()
    plt.tight_layout()
    return fig


def visualize(alignment, spectrogram, stop_tokens, text, hop_length, CONFIG):
    label_fontsize = 16
    plt.figure(figsize=(16, 32))

    plt.subplot(3, 1, 1)
    plt.imshow(alignment.T, aspect="auto", origin="lower", interpolation=None)
    plt.xlabel("Decoder timestamp", fontsize=label_fontsize)
    plt.ylabel("Encoder timestamp", fontsize=label_fontsize)
    plt.yticks(range(len(text)), list(text))
    plt.colorbar()

    stop_tokens = stop_tokens.squeeze().detach().to('cpu').numpy()
    plt.subplot(3, 1, 2)
    plt.plot(range(len(stop_tokens)), list(stop_tokens))

    plt.subplot(3, 1, 3)
    librosa.display.specshow(spectrogram.T, sr=CONFIG.audio['sample_rate'],
                             hop_length=hop_length, x_axis="time", y_axis="linear")
    plt.xlabel("Time", fontsize=label_fontsize)
    plt.ylabel("Hz", fontsize=label_fontsize)

    plt.tight_layout()
    plt.colorbar()