mirror of https://github.com/coqui-ai/TTS.git
enable loading precomputed vocoder dataset
parent
92ed9b90c3
commit
1c2dc9f739
|
@ -28,6 +28,7 @@ class GANDataset(Dataset):
|
|||
|
||||
self.ap = ap
|
||||
self.item_list = items
|
||||
self.compute_feat = not isinstance(items[0], (tuple, list))
|
||||
self.seq_len = seq_len
|
||||
self.hop_len = hop_len
|
||||
self.pad_short = pad_short
|
||||
|
@ -77,14 +78,26 @@ class GANDataset(Dataset):
|
|||
|
||||
def load_item(self, idx):
|
||||
""" load (audio, feat) couple """
|
||||
wavpath = self.item_list[idx]
|
||||
# print(wavpath)
|
||||
if self.compute_feat:
|
||||
# compute features from wav
|
||||
wavpath = self.item_list[idx]
|
||||
# print(wavpath)
|
||||
|
||||
if self.use_cache and self.cache[idx] is not None:
|
||||
audio, mel = self.cache[idx]
|
||||
if self.use_cache and self.cache[idx] is not None:
|
||||
audio, mel = self.cache[idx]
|
||||
else:
|
||||
audio = self.ap.load_wav(wavpath)
|
||||
mel = self.ap.melspectrogram(audio)
|
||||
else:
|
||||
audio = self.ap.load_wav(wavpath)
|
||||
mel = self.ap.melspectrogram(audio)
|
||||
|
||||
# load precomputed features
|
||||
wavpath, feat_path = self.item_list[idx]
|
||||
|
||||
if self.use_cache and self.cache[idx] is not None:
|
||||
audio, mel = self.cache[idx]
|
||||
else:
|
||||
audio = self.ap.load_wav(wavpath)
|
||||
mel = np.load(feat_path)
|
||||
|
||||
if len(audio) < self.seq_len + self.pad_short:
|
||||
audio = np.pad(audio, (0, self.seq_len + self.pad_short - len(audio)), \
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import glob
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -9,8 +10,28 @@ def find_wav_files(data_path):
|
|||
return wav_paths
|
||||
|
||||
|
||||
def find_feat_files(data_path):
    """Recursively collect the precomputed feature files (``*.npy``) under ``data_path``.

    Args:
        data_path: root directory to search.

    Returns:
        List of file paths (strings) for every ``.npy`` file found.
    """
    pattern = os.path.join(data_path, '**', '*.npy')
    return glob.glob(pattern, recursive=True)
|
||||
|
||||
|
||||
def load_wav_data(data_path, eval_split_size):
    """Split the wav files found under ``data_path`` into eval and train lists.

    Args:
        data_path: root directory searched for wav files (via ``find_wav_files``).
        eval_split_size: number of items assigned to the eval split.

    Returns:
        Tuple ``(eval_paths, train_paths)``.
    """
    wav_paths = find_wav_files(data_path)
    # Shuffle with a dedicated, fixed-seed RNG so the split is reproducible
    # WITHOUT clobbering the global numpy random state (the original
    # np.random.seed(0) silently reseeded the whole process). RandomState(0)
    # produces the exact same shuffle order as seed(0) + np.random.shuffle.
    np.random.RandomState(0).shuffle(wav_paths)
    return wav_paths[:eval_split_size], wav_paths[eval_split_size:]
|
||||
|
||||
|
||||
def load_wav_feat_data(data_path, feat_path, eval_split_size):
    """Pair each wav file with its precomputed feature file and split eval/train.

    Files are matched positionally after sorting both lists, then verified
    by file stem, so every wav must have a feature file with the same name.

    Args:
        data_path: root directory containing the wav files.
        feat_path: root directory containing the precomputed ``.npy`` features.
        eval_split_size: number of ``(wav, feat)`` pairs in the eval split.

    Returns:
        Tuple ``(eval_items, train_items)`` of ``(wav_path, feat_path)`` pairs.

    Raises:
        AssertionError: if the file counts differ or a stem pair mismatches.
    """
    wav_paths = sorted(find_wav_files(data_path))
    feat_paths = sorted(find_feat_files(feat_path))
    # Fail loudly with actionable messages instead of bare asserts.
    assert len(wav_paths) == len(feat_paths), \
        f" [!] found {len(wav_paths)} wav files but {len(feat_paths)} feature files."
    for wav, feat in zip(wav_paths, feat_paths):
        wav_name = Path(wav).stem
        feat_name = Path(feat).stem
        assert wav_name == feat_name, \
            f" [!] wav file {wav} does not match feature file {feat}."

    items = list(zip(wav_paths, feat_paths))
    # Dedicated fixed-seed RNG: reproducible split without mutating the
    # global numpy random state (same shuffle order as np.random.seed(0)).
    np.random.RandomState(0).shuffle(items)
    return items[:eval_split_size], items[eval_split_size:]
|
||||
|
|
|
@ -19,7 +19,7 @@ from TTS.utils.radam import RAdam
|
|||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||
from TTS.utils.training import setup_torch_training_env
|
||||
from TTS.vocoder.datasets.gan_dataset import GANDataset
|
||||
from TTS.vocoder.datasets.preprocess import load_wav_data
|
||||
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
||||
# from distribute import (DistributedSampler, apply_gradient_allreduce,
|
||||
# init_distributed, reduce_tensor)
|
||||
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
|
||||
|
@ -543,8 +543,8 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
best_loss = save_best_model(target_loss,
|
||||
best_loss,
|
||||
model_gen,
|
||||
scheduler_gen,
|
||||
optimizer_gen,
|
||||
scheduler_gen,
|
||||
model_disc,
|
||||
optimizer_disc,
|
||||
scheduler_disc,
|
||||
|
|
Loading…
Reference in New Issue