mirror of https://github.com/coqui-ai/TTS.git
enable loading precomputed vocoder dataset
parent
92ed9b90c3
commit
1c2dc9f739
|
@ -28,6 +28,7 @@ class GANDataset(Dataset):
|
||||||
|
|
||||||
self.ap = ap
|
self.ap = ap
|
||||||
self.item_list = items
|
self.item_list = items
|
||||||
|
self.compute_feat = not isinstance(items[0], (tuple, list))
|
||||||
self.seq_len = seq_len
|
self.seq_len = seq_len
|
||||||
self.hop_len = hop_len
|
self.hop_len = hop_len
|
||||||
self.pad_short = pad_short
|
self.pad_short = pad_short
|
||||||
|
@ -77,6 +78,8 @@ class GANDataset(Dataset):
|
||||||
|
|
||||||
def load_item(self, idx):
|
def load_item(self, idx):
|
||||||
""" load (audio, feat) couple """
|
""" load (audio, feat) couple """
|
||||||
|
if self.compute_feat:
|
||||||
|
# compute features from wav
|
||||||
wavpath = self.item_list[idx]
|
wavpath = self.item_list[idx]
|
||||||
# print(wavpath)
|
# print(wavpath)
|
||||||
|
|
||||||
|
@ -85,6 +88,16 @@ class GANDataset(Dataset):
|
||||||
else:
|
else:
|
||||||
audio = self.ap.load_wav(wavpath)
|
audio = self.ap.load_wav(wavpath)
|
||||||
mel = self.ap.melspectrogram(audio)
|
mel = self.ap.melspectrogram(audio)
|
||||||
|
else:
|
||||||
|
|
||||||
|
# load precomputed features
|
||||||
|
wavpath, feat_path = self.item_list[idx]
|
||||||
|
|
||||||
|
if self.use_cache and self.cache[idx] is not None:
|
||||||
|
audio, mel = self.cache[idx]
|
||||||
|
else:
|
||||||
|
audio = self.ap.load_wav(wavpath)
|
||||||
|
mel = np.load(feat_path)
|
||||||
|
|
||||||
if len(audio) < self.seq_len + self.pad_short:
|
if len(audio) < self.seq_len + self.pad_short:
|
||||||
audio = np.pad(audio, (0, self.seq_len + self.pad_short - len(audio)), \
|
audio = np.pad(audio, (0, self.seq_len + self.pad_short - len(audio)), \
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -9,8 +10,28 @@ def find_wav_files(data_path):
|
||||||
return wav_paths
|
return wav_paths
|
||||||
|
|
||||||
|
|
||||||
|
def find_feat_files(data_path):
|
||||||
|
feat_paths = glob.glob(os.path.join(data_path, '**', '*.npy'), recursive=True)
|
||||||
|
return feat_paths
|
||||||
|
|
||||||
|
|
||||||
def load_wav_data(data_path, eval_split_size):
|
def load_wav_data(data_path, eval_split_size):
|
||||||
wav_paths = find_wav_files(data_path)
|
wav_paths = find_wav_files(data_path)
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
np.random.shuffle(wav_paths)
|
np.random.shuffle(wav_paths)
|
||||||
return wav_paths[:eval_split_size], wav_paths[eval_split_size:]
|
return wav_paths[:eval_split_size], wav_paths[eval_split_size:]
|
||||||
|
|
||||||
|
|
||||||
|
def load_wav_feat_data(data_path, feat_path, eval_split_size):
|
||||||
|
wav_paths = sorted(find_wav_files(data_path))
|
||||||
|
feat_paths = sorted(find_feat_files(feat_path))
|
||||||
|
assert len(wav_paths) == len(feat_paths)
|
||||||
|
for wav, feat in zip(wav_paths, feat_paths):
|
||||||
|
wav_name = Path(wav).stem
|
||||||
|
feat_name = Path(feat).stem
|
||||||
|
assert wav_name == feat_name
|
||||||
|
|
||||||
|
items = list(zip(wav_paths, feat_paths))
|
||||||
|
np.random.seed(0)
|
||||||
|
np.random.shuffle(items)
|
||||||
|
return items[:eval_split_size], items[eval_split_size:]
|
||||||
|
|
|
@ -19,7 +19,7 @@ from TTS.utils.radam import RAdam
|
||||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||||
from TTS.utils.training import setup_torch_training_env
|
from TTS.utils.training import setup_torch_training_env
|
||||||
from TTS.vocoder.datasets.gan_dataset import GANDataset
|
from TTS.vocoder.datasets.gan_dataset import GANDataset
|
||||||
from TTS.vocoder.datasets.preprocess import load_wav_data
|
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
||||||
# from distribute import (DistributedSampler, apply_gradient_allreduce,
|
# from distribute import (DistributedSampler, apply_gradient_allreduce,
|
||||||
# init_distributed, reduce_tensor)
|
# init_distributed, reduce_tensor)
|
||||||
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
|
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
|
||||||
|
@ -543,8 +543,8 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
best_loss = save_best_model(target_loss,
|
best_loss = save_best_model(target_loss,
|
||||||
best_loss,
|
best_loss,
|
||||||
model_gen,
|
model_gen,
|
||||||
scheduler_gen,
|
|
||||||
optimizer_gen,
|
optimizer_gen,
|
||||||
|
scheduler_gen,
|
||||||
model_disc,
|
model_disc,
|
||||||
optimizer_disc,
|
optimizer_disc,
|
||||||
scheduler_disc,
|
scheduler_disc,
|
||||||
|
|
Loading…
Reference in New Issue