diff --git a/TTS/vocoder/datasets/preprocess.py b/TTS/vocoder/datasets/preprocess.py index d8cc350a..0f69b812 100644 --- a/TTS/vocoder/datasets/preprocess.py +++ b/TTS/vocoder/datasets/preprocess.py @@ -33,8 +33,8 @@ def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor): np.save(quant_path, quant) -def find_wav_files(data_path): - wav_paths = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True) +def find_wav_files(data_path, file_ext="wav"): + wav_paths = glob.glob(os.path.join(data_path, "**", f"*.{file_ext}"), recursive=True) return wav_paths @@ -43,8 +43,9 @@ def find_feat_files(data_path): return feat_paths -def load_wav_data(data_path, eval_split_size): - wav_paths = find_wav_files(data_path) +def load_wav_data(data_path, eval_split_size, file_ext="wav"): + wav_paths = find_wav_files(data_path, file_ext=file_ext) + assert len(wav_paths) > 0, f" [!] {data_path} is empty." np.random.seed(0) np.random.shuffle(wav_paths) return wav_paths[:eval_split_size], wav_paths[eval_split_size:]