TTS/datasets/TTSDataset.py

244 lines
9.7 KiB
Python
Raw Normal View History

2018-01-22 09:48:59 +00:00
import os
import numpy as np
import collections
import librosa
import torch
2018-09-20 09:08:12 +00:00
import random
2018-01-22 09:48:59 +00:00
from torch.utils.data import Dataset
from utils.text import text_to_sequence, phoneme_to_sequence
2018-08-02 14:34:17 +00:00
from utils.data import (prepare_data, pad_per_step, prepare_tensor,
prepare_stop_target)
2018-01-22 09:48:59 +00:00
2018-07-25 17:14:07 +00:00
class MyDataset(Dataset):
    def __init__(self,
                 root_path,
                 meta_file,
                 outputs_per_step,
                 text_cleaner,
                 ap,
                 preprocessor,
                 batch_group_size=0,
                 min_seq_len=0,
                 max_seq_len=float("inf"),
                 use_phonemes=True,
                 phoneme_cache_path=None,
                 phoneme_language="en-us",
                 enable_eos_bos=False,
                 verbose=False):
        """
        Args:
            root_path (str): root path for the data folder.
            meta_file (str): name for dataset file including audio transcripts
                and file names.
            outputs_per_step (int): number of time frames predicted per step.
            text_cleaner (str): text cleaner used for the dataset.
            ap (TTS.utils.AudioProcessor): audio processor object. Must provide
                ``sample_rate``, ``load_wav``, ``melspectrogram`` and
                ``spectrogram``.
            preprocessor (dataset.preprocess.Class): preprocessor for the dataset,
                called as ``preprocessor(root_path, meta_file)``. It must return a
                list of ``(text, wav_file)`` items. Create your own if you need to
                run a new dataset.
            batch_group_size (int): (0) range of batch randomization after sorting
                sequences by length.
            min_seq_len (int): (0) minimum text length to be processed
                by the loader.
            max_seq_len (int): (float("inf")) maximum text length.
            use_phonemes (bool): (true) if true, text converted to phonemes.
            phoneme_cache_path (str): directory where phoneme sequences are cached
                as ``<wav basename without extension>_phoneme.npy``. It is created
                if missing. If None, phonemes are recomputed on every load and
                nothing is cached.
            phoneme_language (str): one the languages from
                https://github.com/bootphon/phonemizer#languages
            enable_eos_bos (bool): enable end of sentence and beginning of
                sentences characters.
            verbose (bool): print diagnostic information.
        """
        self.root_path = root_path
        self.batch_group_size = batch_group_size
        self.items = preprocessor(root_path, meta_file)
        self.outputs_per_step = outputs_per_step
        self.sample_rate = ap.sample_rate
        self.cleaners = text_cleaner
        self.min_seq_len = min_seq_len
        self.max_seq_len = max_seq_len
        self.ap = ap
        self.use_phonemes = use_phonemes
        self.phoneme_cache_path = phoneme_cache_path
        self.phoneme_language = phoneme_language
        self.enable_eos_bos = enable_eos_bos
        self.verbose = verbose
        # Only create the cache directory when caching is actually requested;
        # os.path.isdir(None) would raise TypeError.
        if use_phonemes and phoneme_cache_path is not None:
            os.makedirs(phoneme_cache_path, exist_ok=True)
        if self.verbose:
            print("\n > DataLoader initialization")
            print(" | > Data path: {}".format(root_path))
            print(" | > Use phonemes: {}".format(self.use_phonemes))
            if use_phonemes:
                print(" | > phoneme language: {}".format(phoneme_language))
            print(" | > Number of instances : {}".format(len(self.items)))
        self.sort_items()

    def load_wav(self, filename):
        """Load an audio file through ``self.ap.load_wav``.

        Raises:
            Exception: whatever ``self.ap.load_wav`` raises, re-raised after the
                offending path is printed. Previously ``None`` was returned, which
                silently became a NaN sample downstream.
        """
        try:
            return self.ap.load_wav(filename)
        except Exception:
            print(" !! Cannot read file : {}".format(filename))
            raise

    def load_np(self, filename):
        """Load a ``.npy`` file and cast it to float32."""
        data = np.load(filename).astype('float32')
        return data

    def _compute_phoneme_sequence(self, text):
        """Convert raw text to an int32 array of phoneme ids (no caching)."""
        return np.asarray(
            phoneme_to_sequence(
                text, [self.cleaners],
                language=self.phoneme_language,
                enable_eos_bos=self.enable_eos_bos),
            dtype=np.int32)

    def load_phoneme_sequence(self, wav_file, text):
        """Return the phoneme-id sequence for ``text``, using the on-disk cache.

        The cache key is the basename of ``wav_file`` without its final
        extension, so ``spk.001.wav`` and ``spk.002.wav`` get distinct entries.
        An unreadable cache file is recomputed and overwritten.
        """
        if self.phoneme_cache_path is None:
            return self._compute_phoneme_sequence(text)
        file_name = os.path.splitext(os.path.basename(wav_file))[0]
        cache_path = os.path.join(self.phoneme_cache_path,
                                  file_name + '_phoneme.npy')
        if os.path.isfile(cache_path):
            try:
                return np.load(cache_path)
            except Exception:
                # Corrupt or partially written cache entry: fall through and
                # rebuild it.
                print(" > ERROR: phoneme cannot be loaded for {}. Recomputing.".format(wav_file))
        seq = self._compute_phoneme_sequence(text)
        np.save(cache_path, seq)
        return seq

    def load_data(self, idx):
        """Load one item as a dict with text ids and raw audio.

        ``mel`` and ``linear`` are left as None; ``collate_fn`` computes them.
        """
        text, wav_file = self.items[idx]
        wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
        if self.use_phonemes:
            text = self.load_phoneme_sequence(wav_file, text)
        else:
            text = np.asarray(
                text_to_sequence(text, [self.cleaners]), dtype=np.int32)
        assert text.size > 0, wav_file
        assert wav.size > 0, wav_file
        sample = {
            'text': text,
            'wav': wav,
            'item_idx': wav_file,
            'mel': None,
            'linear': None
        }
        return sample

    def sort_items(self):
        r"""Sort instances by text length (ascending) and filter them.

        Items whose text length falls outside ``[min_seq_len, max_seq_len]``
        are dropped. If ``batch_group_size > 0``, items are then shuffled
        within each consecutive full group of that size. A trailing partial
        group is left in sorted order.
        """
        lengths = np.array([len(ins[0]) for ins in self.items])
        # A stable sort keeps equal-length items in their original order.
        idxs = np.argsort(lengths, kind="stable")
        new_items = []
        ignored = []
        for idx in idxs:
            length = lengths[idx]
            if length < self.min_seq_len or length > self.max_seq_len:
                ignored.append(idx)
            else:
                new_items.append(self.items[idx])
        # shuffle batch groups
        if self.batch_group_size > 0:
            for i in range(len(new_items) // self.batch_group_size):
                offset = i * self.batch_group_size
                end_offset = offset + self.batch_group_size
                temp_items = new_items[offset:end_offset]
                random.shuffle(temp_items)
                new_items[offset:end_offset] = temp_items
        self.items = new_items
        if self.verbose:
            # np.max/np.min raise on an empty array, so only report stats when
            # there is data.
            if lengths.size > 0:
                print(" | > Max length sequence: {}".format(np.max(lengths)))
                print(" | > Min length sequence: {}".format(np.min(lengths)))
                print(" | > Avg length sequence: {}".format(np.mean(lengths)))
            print(" | > Num. instances discarded by max-min seq limits: {}".format(
                len(ignored)))
            print(" | > Batch group size: {}.".format(self.batch_group_size))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.load_data(idx)

    def collate_fn(self, batch):
        r"""
        Perform preprocessing and create a final data batch:
        1. PAD sequences with the longest sequence in the batch
        2. Convert Audio signal to Spectrograms.
        3. PAD sequences that can be divided by r.
        4. Convert Numpy to Torch tensors.

        Items are reordered by decreasing text length. Every returned field,
        including precomputed ``mel``/``linear``, follows that same order.

        Returns:
            tuple: ``(text, text_lengths, linear, mel, mel_lengths,
            stop_targets, item_idxs)``. ``linear`` and ``mel`` are transposed to
            B x T x D before conversion to tensors.

        Raises:
            TypeError: if batch elements are not mappings.
        """
        from collections.abc import Mapping  # collections.Mapping was removed in 3.10

        # Puts each data field into a tensor with outer dimension batch size
        if isinstance(batch[0], Mapping):
            text_lengths = np.array([len(d["text"]) for d in batch])
            text_lengths, ids_sorted_decreasing = torch.sort(
                torch.LongTensor(text_lengths), dim=0, descending=True)
            sorted_batch = [batch[idx] for idx in ids_sorted_decreasing.tolist()]
            wav = [d['wav'] for d in sorted_batch]
            item_idxs = [d['item_idx'] for d in sorted_batch]
            text = [d['text'] for d in sorted_batch]

            # if specs are not computed, compute them.
            if batch[0]['mel'] is None and batch[0]['linear'] is None:
                mel = [
                    self.ap.melspectrogram(w).astype('float32') for w in wav
                ]
                linear = [
                    self.ap.spectrogram(w).astype('float32') for w in wav
                ]
            else:
                # Use the same sorted order as text/wav so the features stay
                # aligned with their transcripts.
                mel = [d['mel'] for d in sorted_batch]
                linear = [d['linear'] for d in sorted_batch]

            mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame

            # compute 'stop token' targets
            stop_targets = [
                np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
            ]

            # PAD stop targets
            stop_targets = prepare_stop_target(stop_targets,
                                               self.outputs_per_step)

            # PAD sequences with largest length of the batch
            text = prepare_data(text).astype(np.int32)
            wav = prepare_data(wav)
            # PAD features with largest length + a zero frame
            linear = prepare_tensor(linear, self.outputs_per_step)
            mel = prepare_tensor(mel, self.outputs_per_step)
            assert mel.shape[2] == linear.shape[2]

            # B x T x D
            linear = linear.transpose(0, 2, 1)
            mel = mel.transpose(0, 2, 1)

            # convert things to pytorch
            text = torch.LongTensor(text)
            linear = torch.FloatTensor(linear).contiguous()
            mel = torch.FloatTensor(mel).contiguous()
            mel_lengths = torch.LongTensor(mel_lengths)
            stop_targets = torch.FloatTensor(stop_targets)

            return text, text_lengths, linear, mel, mel_lengths, stop_targets, item_idxs

        raise TypeError(
            "batch must contain tensors, numbers, dicts or lists; found {}".format(
                type(batch[0])))