Add cached option to TTSDataset.py, deprecating TTSDatasetCached

pull/10/head
Eren Golge 2018-12-17 16:32:45 +01:00
parent 6651b2ccb1
commit 2f0e9545a3
2 changed files with 69 additions and 18 deletions

@@ -20,7 +20,27 @@ class MyDataset(Dataset):
ap,
preprocessor,
batch_group_size=0,
-min_seq_len=0):
+min_seq_len=0,
+max_seq_len=float("inf"),
+cached=False):
"""
Args:
root_path (str): root path for the data folder.
meta_file (str): name for dataset file including audio transcripts
and file names (or paths in cached mode).
outputs_per_step (int): number of time frames predicted per step.
text_cleaner (str): text cleaner used for the dataset.
ap (TTS.utils.AudioProcessor): audio processor object.
preprocessor (dataset.preprocess.Class): preprocessor for the dataset.
Create your own if you need to run a new dataset.
batch_group_size (int): (0) range of batch randomization after sorting
sequences by length.
min_seq_len (int): (0) minimum sequence length to be processed
by the loader.
max_seq_len (int): (float("inf")) maximum sequence length.
cached (bool): (false) true if the given data path is created
by extract_features.py.
"""
self.root_path = root_path
self.batch_group_size = batch_group_size
self.items = preprocessor(root_path, meta_file)
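For orientation, the constructor this hunk extends might be driven as below. This is a hypothetical usage sketch: the import path, data paths, and helper objects are assumptions for illustration, not values taken from this commit.

from datasets.TTSDataset import MyDataset  # assumed module path

ap = ...            # a TTS.utils.AudioProcessor instance (see docstring above)
preprocessor = ...  # a dataset.preprocess callable (see docstring above)

dataset = MyDataset(
    "/data/LJSpeech-1.1",    # root_path (placeholder)
    "metadata.csv",          # meta_file
    5,                       # outputs_per_step (illustrative)
    "english_cleaners",      # text_cleaner
    ap,
    preprocessor,
    batch_group_size=8,
    min_seq_len=6,
    max_seq_len=150,
    cached=False)            # True when root_path holds extract_features.py output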
@@ -28,9 +48,14 @@ class MyDataset(Dataset):
self.sample_rate = ap.sample_rate
self.cleaners = text_cleaner
self.min_seq_len = min_seq_len
+self.max_seq_len = max_seq_len
self.ap = ap
print(" > Reading LJSpeech from - {}".format(root_path))
self.cached = cached
print(" > DataLoader initialization")
print(" | > Data path: {}".format(root_path))
print(" | > Cached dataset: {}".format(self.cached))
print(" | > Number of instances : {}".format(len(self.items)))
self.sort_items()
def load_wav(self, filename):
@@ -40,24 +65,51 @@ class MyDataset(Dataset):
except RuntimeError as e:
print(" !! Cannot read file : {}".format(filename))
+def load_np(self, filename):
+data = np.load(filename).astype('float32')
+return data
+def load_data(self, idx):
+if self.cached:
+wav_name = self.items[idx][1]
+mel_name = self.items[idx][2]
+linear_name = self.items[idx][3]
+text = self.items[idx][0]
+text = np.asarray(
+text_to_sequence(text, [self.cleaners]), dtype=np.int32)
+if wav_name.split('.')[-1] == 'npy':
+wav = self.load_np(wav_name)
+else:
+wav = np.asarray(self.load_wav(wav_name), dtype=np.float32)
+mel = self.load_np(mel_name)
+linear = self.load_np(linear_name)
+sample = {'text': text, 'wav': wav, 'item_idx': self.items[idx][1], 'mel':mel, 'linear': linear}
+else:
+text, wav_file = self.items[idx]
+text = np.asarray(
+text_to_sequence(text, [self.cleaners]), dtype=np.int32)
+wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
+sample = {'text': text, 'wav': wav, 'item_idx': self.items[idx][1]}
+return sample
def sort_items(self):
r"""Sort text sequences in ascending order"""
r"""Sort instances based on text length in ascending order"""
lengths = np.array([len(ins[0]) for ins in self.items])
print(" | > Max length sequence {}".format(np.max(lengths)))
print(" | > Min length sequence {}".format(np.min(lengths)))
print(" | > Avg length sequence {}".format(np.mean(lengths)))
print(" | > Max length sequence: {}".format(np.max(lengths)))
print(" | > Min length sequence: {}".format(np.min(lengths)))
print(" | > Avg length sequence: {}".format(np.mean(lengths)))
idxs = np.argsort(lengths)
new_items = []
ignored = []
for i, idx in enumerate(idxs):
length = lengths[idx]
-if length < self.min_seq_len:
+if length < self.min_seq_len or length > self.max_seq_len:
ignored.append(idx)
else:
new_items.append(self.items[idx])
print(" | > {} instances are ignored by min_seq_len ({})".format(
print(" | > {} instances are ignored ({})".format(
len(ignored), self.min_seq_len))
# shuffle batch groups
if self.batch_group_size > 0:
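The two branches of load_data above consume differently shaped metadata rows. The rows below are purely illustrative (file names invented) to show the layout each branch expects:

# Default mode: the preprocessor yields (text, wav_path) pairs,
# as consumed by the else-branch of load_data.
item = ("Hello world.", "/data/LJSpeech-1.1/wavs/LJ001-0001.wav")

# Cached mode: indices match load_data above
# (0=text, 1=wav, 2=mel, 3=linear); the wav entry may itself be a
# .npy dump, which the extension check in load_data handles.
cached_item = ("Hello world.",
               "/data/cache/LJ001-0001.npy",         # wav
               "/data/cache/LJ001-0001_mel.npy",     # precomputed mel
               "/data/cache/LJ001-0001_linear.npy")  # precomputed linear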
@@ -74,12 +126,7 @@ class MyDataset(Dataset):
return len(self.items)
def __getitem__(self, idx):
-text, wav_file = self.items[idx]
-text = np.asarray(
-text_to_sequence(text, [self.cleaners]), dtype=np.int32)
-wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
-sample = {'text': text, 'wav': wav, 'item_idx': self.items[idx][1]}
-return sample
+return self.load_data(idx)
def collate_fn(self, batch):
r"""
@@ -101,8 +148,12 @@ class MyDataset(Dataset):
text_lenghts = np.array([len(x) for x in text])
max_text_len = np.max(text_lenghts)
-linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
-mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
+if self.cached:
+mel = [d['mel'] for d in batch]
+linear = [d['linear'] for d in batch]
+else:
+mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
+linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame
# compute 'stop token' targets
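The cached branch of collate_fn assumes mel and linear spectrograms were computed offline. A minimal sketch of that precomputation, reusing the same AudioProcessor calls as the else-branch; the file-naming scheme and the ap.load_wav helper are assumptions, not the actual extract_features.py behavior:

import os
import numpy as np

def cache_features(items, out_path, ap):
    # Dump per-utterance features so the dataset can run with cached=True.
    os.makedirs(out_path, exist_ok=True)
    for text, wav_path in items:
        base = os.path.splitext(os.path.basename(wav_path))[0]
        wav = ap.load_wav(wav_path)  # assumed AudioProcessor helper
        np.save(os.path.join(out_path, base + "_wav.npy"), wav)
        np.save(os.path.join(out_path, base + "_mel.npy"),
                ap.melspectrogram(wav).astype("float32"))
        np.save(os.path.join(out_path, base + "_linear.npy"),
                ap.spectrogram(wav).astype("float32"))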

@@ -151,8 +151,8 @@ class MyDataset(Dataset):
# convert things to pytorch
text_lenghts = torch.LongTensor(text_lenghts)
text = torch.LongTensor(text)
-linear = torch.FloatTensor(linear)
-mel = torch.FloatTensor(mel)
+linear = torch.FloatTensor(linear).contiguous()
+mel = torch.FloatTensor(mel).contiguous()
mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets)
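The added .contiguous() calls ensure each tensor owns a single contiguous memory block, which downstream operations such as view() require. For context, collate_fn is meant to be handed to a PyTorch DataLoader; a typical hookup (batch size and worker count are arbitrary choices, not from this commit) might look like:

from torch.utils.data import DataLoader

loader = DataLoader(dataset,
                    batch_size=32,
                    shuffle=False,                  # sort_items() fixes the order;
                                                    # batch_group_size adds local shuffling
                    collate_fn=dataset.collate_fn,  # builds the padded batches above
                    num_workers=4)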