From 1c4e806f549923169056ae90c51795ffae772f65 Mon Sep 17 00:00:00 2001 From: Edresson Date: Sun, 27 Jun 2021 03:35:34 -0300 Subject: [PATCH] use speaker manager on compute embeddings script --- TTS/bin/compute_embeddings.py | 39 +++++----------------------- TTS/speaker_encoder/models/lstm.py | 4 ++- TTS/speaker_encoder/models/resnet.py | 9 +++++++ TTS/tts/utils/speakers.py | 12 ++++++--- 4 files changed, 27 insertions(+), 37 deletions(-) diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py index 5332123d..e843150b 100644 --- a/TTS/bin/compute_embeddings.py +++ b/TTS/bin/compute_embeddings.py @@ -1,15 +1,11 @@ import argparse import os -import torch -import numpy as np from tqdm import tqdm from TTS.config import load_config -from TTS.speaker_encoder.utils.generic_utils import setup_model from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.utils.speakers import SpeakerManager -from TTS.utils.audio import AudioProcessor parser = argparse.ArgumentParser( description='Compute embedding vectors for each wav file in a dataset.' @@ -28,25 +24,14 @@ parser.add_argument( ) parser.add_argument("output_path", type=str, help="path for output speakers.json and/or speakers.npy.") parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True) -parser.add_argument("--save_npy", type=bool, help="flag to set cuda.", default=False) args = parser.parse_args() - -c = load_config(args.config_path) c_dataset = load_config(args.config_dataset_path) -ap = AudioProcessor(**c["audio"]) - train_files, dev_files = load_meta_data(c_dataset.datasets, eval_split=True, ignore_generated_eval=True) - wav_files = train_files + dev_files -# define Encoder model -model = setup_model(c) -model.load_state_dict(torch.load(args.model_path)["model"]) -model.eval() -if args.use_cuda: - model.cuda() +speaker_manager = SpeakerManager(encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda) # compute speaker embeddings speaker_mapping = {} @@ -57,36 +42,24 @@ for idx, wav_file in enumerate(tqdm(wav_files)): else: speaker_name = None - mel_spec = ap.melspectrogram(ap.load_wav(wav_file, sr=ap.sample_rate)).T - mel_spec = torch.FloatTensor(mel_spec[None, :, :]) - if args.use_cuda: - mel_spec = mel_spec.cuda() - embedd = model.compute_embedding(mel_spec) - embedd = embedd.detach().cpu().numpy() + # extract the embedding + embedd = speaker_manager.compute_x_vector_from_clip(wav_file) # create speaker_mapping if target dataset is defined wav_file_name = os.path.basename(wav_file) speaker_mapping[wav_file_name] = {} speaker_mapping[wav_file_name]["name"] = speaker_name - speaker_mapping[wav_file_name]["embedding"] = embedd.flatten().tolist() + speaker_mapping[wav_file_name]["embedding"] = embedd if speaker_mapping: # save speaker_mapping if target dataset is defined - if '.json' not in args.output_path and '.npy' not in args.output_path: - + if '.json' not in args.output_path: mapping_file_path = os.path.join(args.output_path, "speakers.json") - mapping_npy_file_path = os.path.join(args.output_path, "speakers.npy") else: - mapping_file_path = args.output_path.replace(".npy", ".json") - mapping_npy_file_path = mapping_file_path.replace(".json", ".npy") + mapping_file_path = args.output_path os.makedirs(os.path.dirname(mapping_file_path), exist_ok=True) - if args.save_npy: - np.save(mapping_npy_file_path, speaker_mapping) - print("Speaker embeddings saved at:", mapping_npy_file_path) - - speaker_manager = SpeakerManager() # pylint: disable=W0212 speaker_manager._save_json(mapping_file_path, speaker_mapping) print("Speaker embeddings saved at:", mapping_file_path) diff --git a/TTS/speaker_encoder/models/lstm.py b/TTS/speaker_encoder/models/lstm.py index fadada70..21439d6b 100644 --- a/TTS/speaker_encoder/models/lstm.py +++ b/TTS/speaker_encoder/models/lstm.py @@ -119,9 +119,11 @@ class LSTMSpeakerEncoder(nn.Module): return embed / num_iters # pylint: disable=unused-argument, redefined-builtin - def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False): + def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False): state = torch.load(checkpoint_path, map_location=torch.device("cpu")) self.load_state_dict(state["model"]) + if use_cuda: + self.cuda() if eval: self.eval() assert not self.training diff --git a/TTS/speaker_encoder/models/resnet.py b/TTS/speaker_encoder/models/resnet.py index ce86b01f..29f3ae61 100644 --- a/TTS/speaker_encoder/models/resnet.py +++ b/TTS/speaker_encoder/models/resnet.py @@ -199,3 +199,12 @@ class ResNetSpeakerEncoder(nn.Module): embeddings = torch.mean(embeddings, dim=0, keepdim=True) return embeddings + + def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False): + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if use_cuda: + self.cuda() + if eval: + self.eval() + assert not self.training \ No newline at end of file diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index 84da1f72..1b8c054d 100755 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -133,6 +133,7 @@ class SpeakerManager: speaker_id_file_path: str = "", encoder_model_path: str = "", encoder_config_path: str = "", + use_cuda: bool = False, ): self.x_vectors = None @@ -140,6 +141,7 @@ class SpeakerManager: self.clip_ids = None self.speaker_encoder = None self.speaker_encoder_ap = None + self.use_cuda = use_cuda if x_vectors_file_path: self.load_x_vectors_file(x_vectors_file_path) @@ -215,17 +217,19 @@ class SpeakerManager: def init_speaker_encoder(self, model_path: str, config_path: str) -> None: self.speaker_encoder_config = load_config(config_path) self.speaker_encoder = setup_model(self.speaker_encoder_config) - self.speaker_encoder.load_checkpoint(config_path, model_path, True) + self.speaker_encoder.load_checkpoint(config_path, model_path, eval=True, use_cuda=self.use_cuda) self.speaker_encoder_ap = AudioProcessor(**self.speaker_encoder_config.audio) # normalize the input audio level and trim silences - self.speaker_encoder_ap.do_sound_norm = True - self.speaker_encoder_ap.do_trim_silence = True + # self.speaker_encoder_ap.do_sound_norm = True + # self.speaker_encoder_ap.do_trim_silence = True def compute_x_vector_from_clip(self, wav_file: Union[str, list]) -> list: def _compute(wav_file: str): waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate) spec = self.speaker_encoder_ap.melspectrogram(waveform) spec = torch.from_numpy(spec.T) + if self.use_cuda: + spec = spec.cuda() spec = spec.unsqueeze(0) x_vector = self.speaker_encoder.compute_embedding(spec) return x_vector @@ -248,6 +252,8 @@ class SpeakerManager: feats = torch.from_numpy(feats) if feats.ndim == 2: feats = feats.unsqueeze(0) + if self.use_cuda: + feats = feats.cuda() return self.speaker_encoder.compute_embedding(feats) def run_umap(self):