diff --git a/datasets/TTSDataset.py b/datasets/TTSDataset.py index a9d111b0..e97b38af 100644 --- a/datasets/TTSDataset.py +++ b/datasets/TTSDataset.py @@ -20,7 +20,27 @@ class MyDataset(Dataset): ap, preprocessor, batch_group_size=0, - min_seq_len=0): + min_seq_len=0, + max_seq_len=float("inf"), + cached=False): + """ + Args: + root_path (str): root path for the data folder. + meta_file (str): name for dataset file including audio transcripts + and file names (or paths in cached mode). + outputs_per_step (int): number of time frames predicted per step. + text_cleaner (str): text cleaner used for the dataset. + ap (TTS.utils.AudioProcessor): audio processor object. + preprocessor (dataset.preprocess.Class): preprocessor for the dataset. + Create your own if you need to run a new dataset. + batch_group_size (int): (0) range of batch randomization after sorting + sequences by length. + min_seq_len (int): (0) minimum sequence length to be processed + by the loader. + max_seq_len (int): (float("inf")) maximum sequence length. + cached (bool): (false) true if the given data path is created + by extract_features.py. + """ self.root_path = root_path self.batch_group_size = batch_group_size self.items = preprocessor(root_path, meta_file) @@ -28,9 +48,14 @@ class MyDataset(Dataset): self.sample_rate = ap.sample_rate self.cleaners = text_cleaner self.min_seq_len = min_seq_len + self.max_seq_len = max_seq_len self.ap = ap - print(" > Reading LJSpeech from - {}".format(root_path)) + self.cached = cached + print(" > DataLoader initialization") + print(" | > Data path: {}".format(root_path)) + print(" | > Cached dataset: {}".format(self.cached)) print(" | > Number of instances : {}".format(len(self.items))) + self.sort_items() def load_wav(self, filename): @@ -40,24 +65,51 @@ class MyDataset(Dataset): except RuntimeError as e: print(" !! Cannot read file : {}".format(filename)) + def load_np(self, filename): + data = np.load(filename).astype('float32') + return data + + def load_data(self, idx): + if self.cached: + wav_name = self.items[idx][1] + mel_name = self.items[idx][2] + linear_name = self.items[idx][3] + text = self.items[idx][0] + text = np.asarray( + text_to_sequence(text, [self.cleaners]), dtype=np.int32) + if wav_name.split('.')[-1] == 'npy': + wav = self.load_np(wav_name) + else: + wav = np.asarray(self.load_wav(wav_name), dtype=np.float32) + mel = self.load_np(mel_name) + linear = self.load_np(linear_name) + sample = {'text': text, 'wav': wav, 'item_idx': self.items[idx][1], 'mel':mel, 'linear': linear} + else: + text, wav_file = self.items[idx] + text = np.asarray( + text_to_sequence(text, [self.cleaners]), dtype=np.int32) + wav = np.asarray(self.load_wav(wav_file), dtype=np.float32) + sample = {'text': text, 'wav': wav, 'item_idx': self.items[idx][1]} + return sample + def sort_items(self): - r"""Sort text sequences in ascending order""" + r"""Sort instances based on text length in ascending order""" lengths = np.array([len(ins[0]) for ins in self.items]) - print(" | > Max length sequence {}".format(np.max(lengths))) - print(" | > Min length sequence {}".format(np.min(lengths))) - print(" | > Avg length sequence {}".format(np.mean(lengths))) + print(" | > Max length sequence: {}".format(np.max(lengths))) + print(" | > Min length sequence: {}".format(np.min(lengths))) + print(" | > Avg length sequence: {}".format(np.mean(lengths))) idxs = np.argsort(lengths) new_items = [] ignored = [] for i, idx in enumerate(idxs): length = lengths[idx] - if length < self.min_seq_len: + if length < self.min_seq_len or length > self.max_seq_len: ignored.append(idx) else: new_items.append(self.items[idx]) - print(" | > {} instances are ignored by min_seq_len ({})".format( + print(" | > {} instances are ignored ({})".format( len(ignored), self.min_seq_len)) # shuffle batch groups if self.batch_group_size > 0: @@ -74,12 +126,7 @@ class MyDataset(Dataset): return len(self.items) def __getitem__(self, idx): - text, wav_file = self.items[idx] - text = np.asarray( - text_to_sequence(text, [self.cleaners]), dtype=np.int32) - wav = np.asarray(self.load_wav(wav_file), dtype=np.float32) - sample = {'text': text, 'wav': wav, 'item_idx': self.items[idx][1]} - return sample + return self.load_data(idx) def collate_fn(self, batch): r""" @@ -101,8 +148,12 @@ class MyDataset(Dataset): text_lenghts = np.array([len(x) for x in text]) max_text_len = np.max(text_lenghts) - linear = [self.ap.spectrogram(w).astype('float32') for w in wav] - mel = [self.ap.melspectrogram(w).astype('float32') for w in wav] + if self.cached: + mel = [d['mel'] for d in batch] + linear = [d['linear'] for d in batch] + else: + mel = [self.ap.melspectrogram(w).astype('float32') for w in wav] + linear = [self.ap.spectrogram(w).astype('float32') for w in wav] mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame # compute 'stop token' targets diff --git a/datasets/TTSDatasetCached.py b/datasets/TTSDatasetCached.py index b5c6d4ce..28033a80 100644 --- a/datasets/TTSDatasetCached.py +++ b/datasets/TTSDatasetCached.py @@ -151,8 +151,8 @@ class MyDataset(Dataset): # convert things to pytorch text_lenghts = torch.LongTensor(text_lenghts) text = torch.LongTensor(text) - linear = torch.FloatTensor(linear) - mel = torch.FloatTensor(mel) + linear = torch.FloatTensor(linear).contiguous() + mel = torch.FloatTensor(mel).contiguous() mel_lengths = torch.LongTensor(mel_lengths) stop_targets = torch.FloatTensor(stop_targets)