enable loading precomputed vocoder dataset

pull/1/head
erogol 2020-06-12 11:12:57 +02:00
parent 92ed9b90c3
commit 1c2dc9f739
3 changed files with 42 additions and 8 deletions

View File

@ -28,6 +28,7 @@ class GANDataset(Dataset):
self.ap = ap
self.item_list = items
self.compute_feat = not isinstance(items[0], (tuple, list))
self.seq_len = seq_len
self.hop_len = hop_len
self.pad_short = pad_short
@ -77,14 +78,26 @@ class GANDataset(Dataset):
def load_item(self, idx):
""" load (audio, feat) couple """
wavpath = self.item_list[idx]
# print(wavpath)
if self.compute_feat:
# compute features from wav
wavpath = self.item_list[idx]
# print(wavpath)
if self.use_cache and self.cache[idx] is not None:
audio, mel = self.cache[idx]
if self.use_cache and self.cache[idx] is not None:
audio, mel = self.cache[idx]
else:
audio = self.ap.load_wav(wavpath)
mel = self.ap.melspectrogram(audio)
else:
audio = self.ap.load_wav(wavpath)
mel = self.ap.melspectrogram(audio)
# load precomputed features
wavpath, feat_path = self.item_list[idx]
if self.use_cache and self.cache[idx] is not None:
audio, mel = self.cache[idx]
else:
audio = self.ap.load_wav(wavpath)
mel = np.load(feat_path)
if len(audio) < self.seq_len + self.pad_short:
audio = np.pad(audio, (0, self.seq_len + self.pad_short - len(audio)), \

View File

@ -1,5 +1,6 @@
import glob
import os
from pathlib import Path
import numpy as np
@ -9,8 +10,28 @@ def find_wav_files(data_path):
return wav_paths
def find_feat_files(data_path):
    """Recursively collect every ``.npy`` file found under ``data_path``.

    Args:
        data_path: root directory to search.

    Returns:
        List of matching path strings, in whatever order ``glob`` yields them.
    """
    # '**' combined with recursive=True also matches files directly in the root.
    npy_pattern = os.path.join(data_path, '**', '*.npy')
    return glob.glob(npy_pattern, recursive=True)
def load_wav_data(data_path, eval_split_size):
    """Split the wav files under ``data_path`` into eval and train lists.

    Args:
        data_path: root directory handed to ``find_wav_files``.
        eval_split_size: number of files placed in the eval split.

    Returns:
        Tuple ``(eval_paths, train_paths)``: the first ``eval_split_size``
        shuffled paths, then the remainder.
    """
    # Sort first: the fixed seed only yields the same split on every run if
    # the list starts in the same order, and file discovery order is not
    # guaranteed to be stable across machines or filesystems.
    wav_paths = sorted(find_wav_files(data_path))
    # NOTE: this reseeds numpy's global RNG, exactly as the original did.
    np.random.seed(0)
    np.random.shuffle(wav_paths)
    return wav_paths[:eval_split_size], wav_paths[eval_split_size:]
def load_wav_feat_data(data_path, feat_path, eval_split_size):
    """Pair wav files with precomputed feature files and split eval/train.

    Both file lists are sorted and then matched position by position; each
    pair must share the same file stem (file name without extension).

    Args:
        data_path: root directory handed to ``find_wav_files``.
        feat_path: root directory handed to ``find_feat_files``.
        eval_split_size: number of pairs placed in the eval split.

    Returns:
        Tuple ``(eval_items, train_items)`` where every item is a
        ``(wav_path, feat_path)`` tuple.

    Raises:
        AssertionError: if the number of wav and feature files differs, or
            if a positionally paired wav/feature couple has different stems.
    """
    wav_paths = sorted(find_wav_files(data_path))
    feat_paths = sorted(find_feat_files(feat_path))

    # Raised explicitly (instead of using ``assert``) so the validation is
    # not stripped under ``python -O``; the exception type is kept as
    # AssertionError so existing callers observe the same failure type.
    if len(wav_paths) != len(feat_paths):
        raise AssertionError(
            "Number of wav files ({}) and feature files ({}) differ.".format(
                len(wav_paths), len(feat_paths)))
    for wav, feat in zip(wav_paths, feat_paths):
        if Path(wav).stem != Path(feat).stem:
            raise AssertionError(
                "Wav file {} does not match feature file {}.".format(wav, feat))

    items = list(zip(wav_paths, feat_paths))
    # NOTE: this reseeds numpy's global RNG, exactly as the original did.
    np.random.seed(0)
    np.random.shuffle(items)
    return items[:eval_split_size], items[eval_split_size:]

View File

@ -19,7 +19,7 @@ from TTS.utils.radam import RAdam
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.training import setup_torch_training_env
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.datasets.preprocess import load_wav_data
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
# from distribute import (DistributedSampler, apply_gradient_allreduce,
# init_distributed, reduce_tensor)
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
@ -543,8 +543,8 @@ def main(args): # pylint: disable=redefined-outer-name
best_loss = save_best_model(target_loss,
best_loss,
model_gen,
scheduler_gen,
optimizer_gen,
scheduler_gen,
model_disc,
optimizer_disc,
scheduler_disc,