diff --git a/datasets/TTSDataset.py b/datasets/TTSDataset.py index 6722b510..36e918d1 100644 --- a/datasets/TTSDataset.py +++ b/datasets/TTSDataset.py @@ -22,7 +22,8 @@ class MyDataset(Dataset): batch_group_size=0, min_seq_len=0, max_seq_len=float("inf"), - cached=False): + cached=False, + phoneme_cache_path=None): """ Args: root_path (str): root path for the data folder. @@ -40,6 +41,7 @@ class MyDataset(Dataset): max_seq_len (int): (float("inf")) maximum sequence length. cached (bool): (false) true if the given data path is created by extract_features.py. + phoneme_cache_path (str): path to cache phoneme features. """ self.root_path = root_path self.batch_group_size = batch_group_size @@ -51,6 +53,7 @@ class MyDataset(Dataset): self.max_seq_len = max_seq_len self.ap = ap self.cached = cached + self.phoneme_cache_path = phoneme_cache_path print(" > DataLoader initialization") print(" | > Data path: {}".format(root_path)) print(" | > Cached dataset: {}".format(self.cached)) @@ -87,7 +90,7 @@ class MyDataset(Dataset): else: text, wav_file = self.items[idx] file_name = os.path.basename(wav_file).split('.')[0] - tmp_path = os.path.join("tmp/",file_name+'_phoneme.npy') + tmp_path = os.path.join(self.phoneme_cache_path, file_name+'_phoneme.npy') if os.path.isfile(tmp_path): text = np.load(tmp_path) else: diff --git a/train.py b/train.py index 2be8071f..3e7ad0ad 100644 --- a/train.py +++ b/train.py @@ -46,7 +46,9 @@ def setup_loader(is_val=False): batch_group_size=0 if is_val else 8 * c.batch_size, min_seq_len=0 if is_val else c.min_seq_len, max_seq_len=float("inf") if is_val else c.max_seq_len, - cached=False if c.dataset != "tts_cache" else True) + cached=False if c.dataset != "tts_cache" else True, + phoneme_cache_path=c.phoneme_cache_path + ) loader = DataLoader( dataset, batch_size=c.eval_batch_size if is_val else c.batch_size,