use speaker manager on compute embeddings script

pull/581/head
Edresson 2021-06-27 03:35:34 -03:00
parent eb84bb2bc8
commit 1c4e806f54
4 changed files with 27 additions and 37 deletions

View File

@ -1,15 +1,11 @@
import argparse
import os
import torch
import numpy as np
from tqdm import tqdm
from TTS.config import load_config
from TTS.speaker_encoder.utils.generic_utils import setup_model
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor
parser = argparse.ArgumentParser(
description='Compute embedding vectors for each wav file in a dataset.'
@ -28,25 +24,14 @@ parser.add_argument(
)
parser.add_argument("output_path", type=str, help="path for output speakers.json and/or speakers.npy.")
parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
parser.add_argument("--save_npy", type=bool, help="flag to set cuda.", default=False)
args = parser.parse_args()
c = load_config(args.config_path)
c_dataset = load_config(args.config_dataset_path)
ap = AudioProcessor(**c["audio"])
train_files, dev_files = load_meta_data(c_dataset.datasets, eval_split=True, ignore_generated_eval=True)
wav_files = train_files + dev_files
# define Encoder model
model = setup_model(c)
model.load_state_dict(torch.load(args.model_path)["model"])
model.eval()
if args.use_cuda:
model.cuda()
speaker_manager = SpeakerManager(encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda)
# compute speaker embeddings
speaker_mapping = {}
@ -57,36 +42,24 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
else:
speaker_name = None
mel_spec = ap.melspectrogram(ap.load_wav(wav_file, sr=ap.sample_rate)).T
mel_spec = torch.FloatTensor(mel_spec[None, :, :])
if args.use_cuda:
mel_spec = mel_spec.cuda()
embedd = model.compute_embedding(mel_spec)
embedd = embedd.detach().cpu().numpy()
# extract the embedding
embedd = speaker_manager.compute_x_vector_from_clip(wav_file)
# create speaker_mapping if target dataset is defined
wav_file_name = os.path.basename(wav_file)
speaker_mapping[wav_file_name] = {}
speaker_mapping[wav_file_name]["name"] = speaker_name
speaker_mapping[wav_file_name]["embedding"] = embedd.flatten().tolist()
speaker_mapping[wav_file_name]["embedding"] = embedd
if speaker_mapping:
# save speaker_mapping if target dataset is defined
if '.json' not in args.output_path and '.npy' not in args.output_path:
if '.json' not in args.output_path:
mapping_file_path = os.path.join(args.output_path, "speakers.json")
mapping_npy_file_path = os.path.join(args.output_path, "speakers.npy")
else:
mapping_file_path = args.output_path.replace(".npy", ".json")
mapping_npy_file_path = mapping_file_path.replace(".json", ".npy")
mapping_file_path = args.output_path
os.makedirs(os.path.dirname(mapping_file_path), exist_ok=True)
if args.save_npy:
np.save(mapping_npy_file_path, speaker_mapping)
print("Speaker embeddings saved at:", mapping_npy_file_path)
speaker_manager = SpeakerManager()
# pylint: disable=W0212
speaker_manager._save_json(mapping_file_path, speaker_mapping)
print("Speaker embeddings saved at:", mapping_file_path)

View File

@ -119,9 +119,11 @@ class LSTMSpeakerEncoder(nn.Module):
return embed / num_iters
# pylint: disable=unused-argument, redefined-builtin
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False):
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if use_cuda:
self.cuda()
if eval:
self.eval()
assert not self.training

View File

@ -199,3 +199,12 @@ class ResNetSpeakerEncoder(nn.Module):
embeddings = torch.mean(embeddings, dim=0, keepdim=True)
return embeddings
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if use_cuda:
self.cuda()
if eval:
self.eval()
assert not self.training

View File

@ -133,6 +133,7 @@ class SpeakerManager:
speaker_id_file_path: str = "",
encoder_model_path: str = "",
encoder_config_path: str = "",
use_cuda: bool = False,
):
self.x_vectors = None
@ -140,6 +141,7 @@ class SpeakerManager:
self.clip_ids = None
self.speaker_encoder = None
self.speaker_encoder_ap = None
self.use_cuda = use_cuda
if x_vectors_file_path:
self.load_x_vectors_file(x_vectors_file_path)
@ -215,17 +217,19 @@ class SpeakerManager:
def init_speaker_encoder(self, model_path: str, config_path: str) -> None:
self.speaker_encoder_config = load_config(config_path)
self.speaker_encoder = setup_model(self.speaker_encoder_config)
self.speaker_encoder.load_checkpoint(config_path, model_path, True)
self.speaker_encoder.load_checkpoint(config_path, model_path, eval=True, use_cuda=self.use_cuda)
self.speaker_encoder_ap = AudioProcessor(**self.speaker_encoder_config.audio)
# normalize the input audio level and trim silences
self.speaker_encoder_ap.do_sound_norm = True
self.speaker_encoder_ap.do_trim_silence = True
# self.speaker_encoder_ap.do_sound_norm = True
# self.speaker_encoder_ap.do_trim_silence = True
def compute_x_vector_from_clip(self, wav_file: Union[str, list]) -> list:
def _compute(wav_file: str):
waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate)
spec = self.speaker_encoder_ap.melspectrogram(waveform)
spec = torch.from_numpy(spec.T)
if self.use_cuda:
spec = spec.cuda()
spec = spec.unsqueeze(0)
x_vector = self.speaker_encoder.compute_embedding(spec)
return x_vector
@ -248,6 +252,8 @@ class SpeakerManager:
feats = torch.from_numpy(feats)
if feats.ndim == 2:
feats = feats.unsqueeze(0)
if self.use_cuda:
feats = feats.cuda()
return self.speaker_encoder.compute_embedding(feats)
def run_umap(self):