Add configurable phoneme cache path

pull/10/head
Eren Golge 2019-01-15 15:51:13 +01:00
parent 7af1aeaf7a
commit 5733eab90b
2 changed files with 8 additions and 3 deletions

View File

@ -22,7 +22,8 @@ class MyDataset(Dataset):
batch_group_size=0, batch_group_size=0,
min_seq_len=0, min_seq_len=0,
max_seq_len=float("inf"), max_seq_len=float("inf"),
cached=False): cached=False,
phoneme_cache_path=None):
""" """
Args: Args:
root_path (str): root path for the data folder. root_path (str): root path for the data folder.
@ -40,6 +41,7 @@ class MyDataset(Dataset):
max_seq_len (int): (float("inf")) maximum sequence length. max_seq_len (int): (float("inf")) maximum sequence length.
cached (bool): (false) true if the given data path is created cached (bool): (false) true if the given data path is created
by extract_features.py. by extract_features.py.
phoneme_cache_path (str): path to cache phoneme features.
""" """
self.root_path = root_path self.root_path = root_path
self.batch_group_size = batch_group_size self.batch_group_size = batch_group_size
@ -51,6 +53,7 @@ class MyDataset(Dataset):
self.max_seq_len = max_seq_len self.max_seq_len = max_seq_len
self.ap = ap self.ap = ap
self.cached = cached self.cached = cached
self.phoneme_cache_path = phoneme_cache_path
print(" > DataLoader initialization") print(" > DataLoader initialization")
print(" | > Data path: {}".format(root_path)) print(" | > Data path: {}".format(root_path))
print(" | > Cached dataset: {}".format(self.cached)) print(" | > Cached dataset: {}".format(self.cached))
@ -87,7 +90,7 @@ class MyDataset(Dataset):
else: else:
text, wav_file = self.items[idx] text, wav_file = self.items[idx]
file_name = os.path.basename(wav_file).split('.')[0] file_name = os.path.basename(wav_file).split('.')[0]
tmp_path = os.path.join("tmp/",file_name+'_phoneme.npy') tmp_path = os.path.join(self.phoneme_cache_path, file_name+'_phoneme.npy')
if os.path.isfile(tmp_path): if os.path.isfile(tmp_path):
text = np.load(tmp_path) text = np.load(tmp_path)
else: else:

View File

@ -46,7 +46,9 @@ def setup_loader(is_val=False):
batch_group_size=0 if is_val else 8 * c.batch_size, batch_group_size=0 if is_val else 8 * c.batch_size,
min_seq_len=0 if is_val else c.min_seq_len, min_seq_len=0 if is_val else c.min_seq_len,
max_seq_len=float("inf") if is_val else c.max_seq_len, max_seq_len=float("inf") if is_val else c.max_seq_len,
cached=False if c.dataset != "tts_cache" else True) cached=False if c.dataset != "tts_cache" else True,
phoneme_cache_path=c.phoneme_cache_path
)
loader = DataLoader( loader = DataLoader(
dataset, dataset,
batch_size=c.eval_batch_size if is_val else c.batch_size, batch_size=c.eval_batch_size if is_val else c.batch_size,