TTS/datasets/TTSDataset.py

143 lines
5.3 KiB
Python
Raw Normal View History

2018-01-22 09:48:59 +00:00
import os
import collections
import collections.abc
import random

import librosa
import numpy as np
import torch
from torch.utils.data import Dataset

from utils.text import text_to_sequence
from utils.data import (prepare_data, pad_per_step, prepare_tensor,
                        prepare_stop_target)
2018-07-25 17:14:07 +00:00
class MyDataset(Dataset):
    """Generic TTS dataset.

    Loads ``[text, wav_file]`` items through a dataset-specific
    ``preprocessor``, sorts them by text length (optionally shuffling inside
    fixed-size batch groups) and builds padded, batched torch tensors via
    :meth:`collate_fn`.

    Args:
        root_path (str): root directory of the dataset.
        meta_file (str): metadata file listing the dataset instances.
        outputs_per_step (int): model reduction factor ``r``; padded feature
            lengths are made divisible by it.
        text_cleaner (str): cleaner name forwarded to ``text_to_sequence``.
        ap: audio processor providing ``load_wav``, ``spectrogram``,
            ``melspectrogram`` and ``sample_rate``.
        preprocessor (callable): ``preprocessor(root_path, meta_file)`` ->
            list of ``[text, wav_file]`` items.
        batch_group_size (int): if > 0, shuffle items within consecutive
            groups of this size (keeps similar lengths together while adding
            epoch-to-epoch randomness).
        min_seq_len (int): items whose text is shorter than this are dropped.
    """

    def __init__(self,
                 root_path,
                 meta_file,
                 outputs_per_step,
                 text_cleaner,
                 ap,
                 preprocessor,
                 batch_group_size=0,
                 min_seq_len=0):
        self.root_path = root_path
        self.batch_group_size = batch_group_size
        # list of [text, wav_file] pairs produced by the preprocessor
        self.items = preprocessor(root_path, meta_file)
        self.outputs_per_step = outputs_per_step
        self.sample_rate = ap.sample_rate
        self.cleaners = text_cleaner
        self.min_seq_len = min_seq_len
        self.ap = ap
        print(" > Reading LJSpeech from - {}".format(root_path))
        print(" | > Number of instances : {}".format(len(self.items)))
        self.sort_items()

    def load_wav(self, filename):
        """Load a waveform through the audio processor.

        Returns the audio array, or ``None`` after logging a warning when the
        file cannot be read (the implicit ``None`` makes the failure surface
        later in ``__getitem__``; the print pinpoints the offending file).
        """
        try:
            return self.ap.load_wav(filename)
        except RuntimeError:
            print(" !! Cannot read file : {}".format(filename))

    def sort_items(self):
        r"""Sort text sequences in ascending length order, dropping items
        shorter than ``min_seq_len`` and optionally shuffling batch groups."""
        lengths = np.array([len(ins[0]) for ins in self.items])

        print(" | > Max length sequence {}".format(np.max(lengths)))
        print(" | > Min length sequence {}".format(np.min(lengths)))
        print(" | > Avg length sequence {}".format(np.mean(lengths)))

        idxs = np.argsort(lengths)
        new_items = []
        ignored = []
        for idx in idxs:
            length = lengths[idx]
            if length < self.min_seq_len:
                ignored.append(idx)
            else:
                new_items.append(self.items[idx])
        print(" | > {} instances are ignored by min_seq_len ({})".format(
            len(ignored), self.min_seq_len))

        # shuffle batch groups: randomize order inside each group of
        # batch_group_size consecutive (length-sorted) items
        if self.batch_group_size > 0:
            print(" | > Batch group shuffling is active.")
            for i in range(len(new_items) // self.batch_group_size):
                offset = i * self.batch_group_size
                end_offset = offset + self.batch_group_size
                temp_items = new_items[offset:end_offset]
                random.shuffle(temp_items)
                new_items[offset:end_offset] = temp_items
        self.items = new_items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        text, wav_file = self.items[idx]
        text = np.asarray(
            text_to_sequence(text, [self.cleaners]), dtype=np.int32)
        wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
        sample = {'text': text, 'wav': wav, 'item_idx': self.items[idx][1]}
        return sample

    def collate_fn(self, batch):
        r"""
        Perform preprocessing and create a final data batch:
        1. PAD sequences with the longest sequence in the batch
        2. Convert Audio signal to Spectrograms.
        3. PAD sequences that can be divided by r.
        4. Convert Numpy to Torch tensors.
        """
        # Puts each data field into a tensor with outer dimension batch size.
        # NOTE: collections.Mapping was removed in Python 3.10 -- the
        # collections.abc alias must be used instead.
        if isinstance(batch[0], collections.abc.Mapping):
            wav = [d['wav'] for d in batch]
            item_idxs = [d['item_idx'] for d in batch]
            text = [d['text'] for d in batch]
            text_lengths = np.array([len(x) for x in text])

            # compute features from the raw waveforms
            linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
            mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
            mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame

            # compute 'stop token' targets: 0.0 per decoder frame
            # (presumably prepare_stop_target pads/appends the terminal 1 --
            # confirm against utils.data)
            stop_targets = [
                np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
            ]

            # PAD stop targets
            stop_targets = prepare_stop_target(stop_targets,
                                               self.outputs_per_step)

            # PAD sequences with largest length of the batch
            text = prepare_data(text).astype(np.int32)
            wav = prepare_data(wav)

            # PAD features with largest length + a zero frame
            linear = prepare_tensor(linear, self.outputs_per_step)
            mel = prepare_tensor(mel, self.outputs_per_step)
            assert mel.shape[2] == linear.shape[2]

            # B x T x D
            linear = linear.transpose(0, 2, 1)
            mel = mel.transpose(0, 2, 1)

            # convert things to pytorch
            text_lengths = torch.LongTensor(text_lengths)
            text = torch.LongTensor(text)
            linear = torch.FloatTensor(linear).contiguous()
            mel = torch.FloatTensor(mel).contiguous()
            mel_lengths = torch.LongTensor(mel_lengths)
            stop_targets = torch.FloatTensor(stop_targets)

            return (text, text_lengths, linear, mel, mel_lengths,
                    stop_targets, item_idxs)

        raise TypeError(("batch must contain tensors, numbers, dicts or"
                         " lists; found {}".format(type(batch[0]))))