mirror of https://github.com/coqui-ai/TTS.git
Add cached option to TTSDataset.py, deprecating TTSDatasetCached
parent 6651b2ccb1
commit 2f0e9545a3
@@ -20,7 +20,27 @@ class MyDataset(Dataset):
                  ap,
                  preprocessor,
                  batch_group_size=0,
-                 min_seq_len=0):
+                 min_seq_len=0,
+                 max_seq_len=float("inf"),
+                 cached=False):
+        """
+        Args:
+            root_path (str): root path for the data folder.
+            meta_file (str): name for dataset file including audio transcripts
+                and file names (or paths in cached mode).
+            outputs_per_step (int): number of time frames predicted per step.
+            text_cleaner (str): text cleaner used for the dataset.
+            ap (TTS.utils.AudioProcessor): audio processor object.
+            preprocessor (dataset.preprocess.Class): preprocessor for the dataset.
+                Create your own if you need to run a new dataset.
+            batch_group_size (int): (0) range of batch randomization after sorting
+                sequences by length.
+            min_seq_len (int): (0) minimum sequence length to be processed
+                by the loader.
+            max_seq_len (int): (float("inf")) maximum sequence length.
+            cached (bool): (false) true if the given data path is created
+                by extract_features.py.
+        """
         self.root_path = root_path
         self.batch_group_size = batch_group_size
         self.items = preprocessor(root_path, meta_file)
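For orientation, a minimal construction sketch against the new signature. The import paths, the ljspeech preprocessor, the outputs_per_step value, and all file paths below are illustrative assumptions, not part of the commit:

    # Hypothetical wiring of the new `cached` flag; module paths and the
    # preprocessor are placeholders.
    from datasets.TTSDataset import MyDataset
    from datasets.preprocess import ljspeech

    def make_dataset(ap, cached=False):
        # `ap` is an already-constructed TTS.utils.AudioProcessor. With
        # cached=True, root_path must point at a folder produced by
        # extract_features.py.
        root_path = "/data/lj_features" if cached else "/data/LJSpeech-1.1"
        return MyDataset(root_path,
                         "metadata.csv",
                         outputs_per_step=5,
                         text_cleaner="english_cleaners",
                         ap=ap,
                         preprocessor=ljspeech,
                         batch_group_size=0,
                         min_seq_len=0,
                         max_seq_len=float("inf"),
                         cached=cached)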
@@ -28,9 +48,14 @@ class MyDataset(Dataset):
         self.sample_rate = ap.sample_rate
         self.cleaners = text_cleaner
         self.min_seq_len = min_seq_len
+        self.max_seq_len = max_seq_len
         self.ap = ap
-        print(" > Reading LJSpeech from - {}".format(root_path))
+        self.cached = cached
+        print(" > DataLoader initialization")
+        print(" | > Data path: {}".format(root_path))
+        print(" | > Cached dataset: {}".format(self.cached))
         print(" | > Number of instances : {}".format(len(self.items)))

         self.sort_items()

     def load_wav(self, filename):
@@ -40,24 +65,51 @@ class MyDataset(Dataset):
         except RuntimeError as e:
             print(" !! Cannot read file : {}".format(filename))

+    def load_np(self, filename):
+        data = np.load(filename).astype('float32')
+        return data
+
+    def load_data(self, idx):
+        if self.cached:
+            wav_name = self.items[idx][1]
+            mel_name = self.items[idx][2]
+            linear_name = self.items[idx][3]
+            text = self.items[idx][0]
+            text = np.asarray(
+                text_to_sequence(text, [self.cleaners]), dtype=np.int32)
+            if wav_name.split('.')[-1] == 'npy':
+                wav = self.load_np(wav_name)
+            else:
+                wav = np.asarray(self.load_wav(wav_name), dtype=np.float32)
+            mel = self.load_np(mel_name)
+            linear = self.load_np(linear_name)
+            sample = {'text': text, 'wav': wav, 'item_idx': self.items[idx][1], 'mel': mel, 'linear': linear}
+        else:
+            text, wav_file = self.items[idx]
+            text = np.asarray(
+                text_to_sequence(text, [self.cleaners]), dtype=np.int32)
+            wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
+            sample = {'text': text, 'wav': wav, 'item_idx': self.items[idx][1]}
+        return sample
+
     def sort_items(self):
-        r"""Sort text sequences in ascending order"""
+        r"""Sort instances based on text length in ascending order"""
         lengths = np.array([len(ins[0]) for ins in self.items])

-        print(" | > Max length sequence {}".format(np.max(lengths)))
-        print(" | > Min length sequence {}".format(np.min(lengths)))
-        print(" | > Avg length sequence {}".format(np.mean(lengths)))
+        print(" | > Max length sequence: {}".format(np.max(lengths)))
+        print(" | > Min length sequence: {}".format(np.min(lengths)))
+        print(" | > Avg length sequence: {}".format(np.mean(lengths)))

         idxs = np.argsort(lengths)
         new_items = []
         ignored = []
         for i, idx in enumerate(idxs):
             length = lengths[idx]
-            if length < self.min_seq_len:
+            if length < self.min_seq_len or length > self.max_seq_len:
                 ignored.append(idx)
             else:
                 new_items.append(self.items[idx])
-        print(" | > {} instances are ignored by min_seq_len ({})".format(
+        print(" | > {} instances are ignored ({})".format(
             len(ignored), self.min_seq_len))
         # shuffle batch groups
         if self.batch_group_size > 0:
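In cached mode, load_data expects each metadata row to carry the transcript plus paths to precomputed features. A sketch of the assumed per-item layout, inferred from the indexing above; the file names are made up and the real columns are whatever extract_features.py writes out:

    # Assumed shape of self.items[idx] when cached=True.
    item = (
        "in being comparatively modern.",  # [0] transcript
        "LJ001-0001.npy",                  # [1] wav: a .npy is loaded directly,
                                           #     anything else goes through load_wav
        "LJ001-0001_mel.npy",              # [2] precomputed mel spectrogram
        "LJ001-0001_linear.npy",           # [3] precomputed linear spectrogram
    )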
@@ -74,12 +126,7 @@ class MyDataset(Dataset):
         return len(self.items)

     def __getitem__(self, idx):
-        text, wav_file = self.items[idx]
-        text = np.asarray(
-            text_to_sequence(text, [self.cleaners]), dtype=np.int32)
-        wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
-        sample = {'text': text, 'wav': wav, 'item_idx': self.items[idx][1]}
-        return sample
+        return self.load_data(idx)

     def collate_fn(self, batch):
         r"""
@@ -101,8 +148,12 @@ class MyDataset(Dataset):
             text_lenghts = np.array([len(x) for x in text])
             max_text_len = np.max(text_lenghts)

-            linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
-            mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
+            if self.cached:
+                mel = [d['mel'] for d in batch]
+                linear = [d['linear'] for d in batch]
+            else:
+                mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
+                linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
             mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame

             # compute 'stop token' targets
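With collate_fn now handling both modes, batching stays the stock PyTorch pattern. A sketch with arbitrary batch size and worker count; `dataset` is a MyDataset instance as built earlier:

    from torch.utils.data import DataLoader

    # shuffle stays False because sort_items() already orders instances
    # by length and shuffles batch groups internally.
    loader = DataLoader(dataset,
                        batch_size=32,
                        shuffle=False,
                        collate_fn=dataset.collate_fn,
                        num_workers=4)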
@@ -151,8 +151,8 @@ class MyDataset(Dataset):
             # convert things to pytorch
             text_lenghts = torch.LongTensor(text_lenghts)
             text = torch.LongTensor(text)
-            linear = torch.FloatTensor(linear)
-            mel = torch.FloatTensor(mel)
+            linear = torch.FloatTensor(linear).contiguous()
+            mel = torch.FloatTensor(mel).contiguous()
             mel_lengths = torch.LongTensor(mel_lengths)
             stop_targets = torch.FloatTensor(stop_targets)
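The last hunk forces contiguous storage on the spectrogram tensors, presumably a defensive change so that later reshapes such as .view() don't fail on non-contiguously laid-out data. A quick illustration of the distinction:

    import torch

    t = torch.zeros(4, 8).t()               # transposing yields a non-contiguous view
    print(t.is_contiguous())                # False
    print(t.contiguous().is_contiguous())   # True: safe to call .view() on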