From ec83ffbd7ad03b965824bcc86f8171e574831006 Mon Sep 17 00:00:00 2001 From: Julian WEBER Date: Wed, 27 Oct 2021 13:40:11 +0200 Subject: [PATCH] PitchExtractor --- TTS/tts/datasets/dataset.py | 106 ++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index ccfa70f1..635ffb38 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -105,6 +105,7 @@ class TTSDataset(Dataset): self.cleaners = text_cleaner self.compute_linear_spec = compute_linear_spec self.return_wav = return_wav + self.compute_f0 = compute_f0 self.min_seq_len = min_seq_len self.max_seq_len = max_seq_len self.ap = ap @@ -538,3 +539,108 @@ class TTSDataset(Dataset): ) ) ) + +class PitchExtractor: + """Pitch Extractor for computing F0 from wav files. + Args: + items (List[List]): Dataset samples. + verbose (bool): Whether to print the progress. + """ + + def __init__( + self, + items: List[List], + verbose=False, + ): + self.items = items + self.verbose = verbose + self.mean = None + self.std = None + + @staticmethod + def create_pitch_file_path(wav_file, cache_path): + file_name = os.path.splitext(os.path.basename(wav_file))[0] + pitch_file = os.path.join(cache_path, file_name + "_pitch.npy") + return pitch_file + + @staticmethod + def _compute_and_save_pitch(ap, wav_file, pitch_file=None): + wav = ap.load_wav(wav_file) + pitch = ap.compute_f0(wav) + if pitch_file: + np.save(pitch_file, pitch) + return pitch + + @staticmethod + def compute_pitch_stats(pitch_vecs): + nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs]) + mean, std = np.mean(nonzeros), np.std(nonzeros) + return mean, std + + def normalize_pitch(self, pitch): + zero_idxs = np.where(pitch == 0.0)[0] + pitch = pitch - self.mean + pitch = pitch / self.std + pitch[zero_idxs] = 0.0 + return pitch + + def denormalize_pitch(self, pitch): + zero_idxs = np.where(pitch == 0.0)[0] + pitch *= self.std + pitch += self.mean + pitch[zero_idxs] = 0.0 + return pitch + + @staticmethod + def load_or_compute_pitch(ap, wav_file, cache_path): + """ + compute pitch and return a numpy array of pitch values + """ + pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path) + if not os.path.exists(pitch_file): + pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file) + else: + pitch = np.load(pitch_file) + return pitch.astype(np.float32) + + @staticmethod + def _pitch_worker(args): + item = args[0] + ap = args[1] + cache_path = args[2] + _, wav_file, *_ = item + pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path) + if not os.path.exists(pitch_file): + pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file) + return pitch + return None + + def compute_pitch(self, ap, cache_path, num_workers=0): + """Compute the input sequences with multi-processing. + Call it before passing dataset to the data loader to cache the input sequences for faster data loading.""" + if not os.path.exists(cache_path): + os.makedirs(cache_path, exist_ok=True) + + if self.verbose: + print(" | > Computing pitch features ...") + if num_workers == 0: + pitch_vecs = [] + for _, item in enumerate(tqdm.tqdm(self.items)): + pitch_vecs += [self._pitch_worker([item, ap, cache_path])] + else: + with Pool(num_workers) as p: + pitch_vecs = list( + tqdm.tqdm( + p.imap(PitchExtractor._pitch_worker, [[item, ap, cache_path] for item in self.items]), + total=len(self.items), + ) + ) + pitch_mean, pitch_std = self.compute_pitch_stats(pitch_vecs) + pitch_stats = {"mean": pitch_mean, "std": pitch_std} + np.save(os.path.join(cache_path, "pitch_stats"), pitch_stats, allow_pickle=True) + + def load_pitch_stats(self, cache_path): + stats_path = os.path.join(cache_path, "pitch_stats.npy") + stats = np.load(stats_path, allow_pickle=True).item() + self.mean = stats["mean"].astype(np.float32) + self.std = stats["std"].astype(np.float32) \ No newline at end of file