Add configurable phoneme cache path

pull/10/head
Eren Golge 2019-01-15 15:51:13 +01:00
parent 7af1aeaf7a
commit 5733eab90b
2 changed files with 8 additions and 3 deletions

View File

@ -22,7 +22,8 @@ class MyDataset(Dataset):
batch_group_size=0,
min_seq_len=0,
max_seq_len=float("inf"),
cached=False):
cached=False,
phoneme_cache_path=None):
"""
Args:
root_path (str): root path for the data folder.
@ -40,6 +41,7 @@ class MyDataset(Dataset):
max_seq_len (int): (float("inf")) maximum sequence length.
cached (bool): (false) true if the given data path is created
by extract_features.py.
phoneme_cache_path (str): path to cache phoneme features.
"""
self.root_path = root_path
self.batch_group_size = batch_group_size
@ -51,6 +53,7 @@ class MyDataset(Dataset):
self.max_seq_len = max_seq_len
self.ap = ap
self.cached = cached
self.phoneme_cache_path = phoneme_cache_path
print(" > DataLoader initialization")
print(" | > Data path: {}".format(root_path))
print(" | > Cached dataset: {}".format(self.cached))
@ -87,7 +90,7 @@ class MyDataset(Dataset):
else:
text, wav_file = self.items[idx]
file_name = os.path.basename(wav_file).split('.')[0]
tmp_path = os.path.join("tmp/",file_name+'_phoneme.npy')
tmp_path = os.path.join(self.phoneme_cache_path, file_name+'_phoneme.npy')
if os.path.isfile(tmp_path):
text = np.load(tmp_path)
else:

View File

@ -46,7 +46,9 @@ def setup_loader(is_val=False):
batch_group_size=0 if is_val else 8 * c.batch_size,
min_seq_len=0 if is_val else c.min_seq_len,
max_seq_len=float("inf") if is_val else c.max_seq_len,
cached=False if c.dataset != "tts_cache" else True)
cached=False if c.dataset != "tts_cache" else True,
phoneme_cache_path=c.phoneme_cache_path
)
loader = DataLoader(
dataset,
batch_size=c.eval_batch_size if is_val else c.batch_size,