mirror of https://github.com/coqui-ai/TTS.git
use speaker manager on compute embeddings script
parent
eb84bb2bc8
commit
1c4e806f54
|
@ -1,15 +1,11 @@
|
|||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from TTS.config import load_config
|
||||
|
||||
from TTS.speaker_encoder.utils.generic_utils import setup_model
|
||||
from TTS.tts.datasets.preprocess import load_meta_data
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Compute embedding vectors for each wav file in a dataset.'
|
||||
|
@ -28,25 +24,14 @@ parser.add_argument(
|
|||
)
|
||||
parser.add_argument("output_path", type=str, help="path for output speakers.json and/or speakers.npy.")
|
||||
parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
|
||||
parser.add_argument("--save_npy", type=bool, help="flag to set cuda.", default=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
c = load_config(args.config_path)
|
||||
c_dataset = load_config(args.config_dataset_path)
|
||||
|
||||
ap = AudioProcessor(**c["audio"])
|
||||
|
||||
train_files, dev_files = load_meta_data(c_dataset.datasets, eval_split=True, ignore_generated_eval=True)
|
||||
|
||||
wav_files = train_files + dev_files
|
||||
|
||||
# define Encoder model
|
||||
model = setup_model(c)
|
||||
model.load_state_dict(torch.load(args.model_path)["model"])
|
||||
model.eval()
|
||||
if args.use_cuda:
|
||||
model.cuda()
|
||||
speaker_manager = SpeakerManager(encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda)
|
||||
|
||||
# compute speaker embeddings
|
||||
speaker_mapping = {}
|
||||
|
@ -57,36 +42,24 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
|
|||
else:
|
||||
speaker_name = None
|
||||
|
||||
mel_spec = ap.melspectrogram(ap.load_wav(wav_file, sr=ap.sample_rate)).T
|
||||
mel_spec = torch.FloatTensor(mel_spec[None, :, :])
|
||||
if args.use_cuda:
|
||||
mel_spec = mel_spec.cuda()
|
||||
embedd = model.compute_embedding(mel_spec)
|
||||
embedd = embedd.detach().cpu().numpy()
|
||||
# extract the embedding
|
||||
embedd = speaker_manager.compute_x_vector_from_clip(wav_file)
|
||||
|
||||
# create speaker_mapping if target dataset is defined
|
||||
wav_file_name = os.path.basename(wav_file)
|
||||
speaker_mapping[wav_file_name] = {}
|
||||
speaker_mapping[wav_file_name]["name"] = speaker_name
|
||||
speaker_mapping[wav_file_name]["embedding"] = embedd.flatten().tolist()
|
||||
speaker_mapping[wav_file_name]["embedding"] = embedd
|
||||
|
||||
if speaker_mapping:
|
||||
# save speaker_mapping if target dataset is defined
|
||||
if '.json' not in args.output_path and '.npy' not in args.output_path:
|
||||
|
||||
if '.json' not in args.output_path:
|
||||
mapping_file_path = os.path.join(args.output_path, "speakers.json")
|
||||
mapping_npy_file_path = os.path.join(args.output_path, "speakers.npy")
|
||||
else:
|
||||
mapping_file_path = args.output_path.replace(".npy", ".json")
|
||||
mapping_npy_file_path = mapping_file_path.replace(".json", ".npy")
|
||||
mapping_file_path = args.output_path
|
||||
|
||||
os.makedirs(os.path.dirname(mapping_file_path), exist_ok=True)
|
||||
|
||||
if args.save_npy:
|
||||
np.save(mapping_npy_file_path, speaker_mapping)
|
||||
print("Speaker embeddings saved at:", mapping_npy_file_path)
|
||||
|
||||
speaker_manager = SpeakerManager()
|
||||
# pylint: disable=W0212
|
||||
speaker_manager._save_json(mapping_file_path, speaker_mapping)
|
||||
print("Speaker embeddings saved at:", mapping_file_path)
|
||||
|
|
|
@ -119,9 +119,11 @@ class LSTMSpeakerEncoder(nn.Module):
|
|||
return embed / num_iters
|
||||
|
||||
# pylint: disable=unused-argument, redefined-builtin
|
||||
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False):
|
||||
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if use_cuda:
|
||||
self.cuda()
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
|
|
|
@ -199,3 +199,12 @@ class ResNetSpeakerEncoder(nn.Module):
|
|||
embeddings = torch.mean(embeddings, dim=0, keepdim=True)
|
||||
|
||||
return embeddings
|
||||
|
||||
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if use_cuda:
|
||||
self.cuda()
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
|
@ -133,6 +133,7 @@ class SpeakerManager:
|
|||
speaker_id_file_path: str = "",
|
||||
encoder_model_path: str = "",
|
||||
encoder_config_path: str = "",
|
||||
use_cuda: bool = False,
|
||||
):
|
||||
|
||||
self.x_vectors = None
|
||||
|
@ -140,6 +141,7 @@ class SpeakerManager:
|
|||
self.clip_ids = None
|
||||
self.speaker_encoder = None
|
||||
self.speaker_encoder_ap = None
|
||||
self.use_cuda = use_cuda
|
||||
|
||||
if x_vectors_file_path:
|
||||
self.load_x_vectors_file(x_vectors_file_path)
|
||||
|
@ -215,17 +217,19 @@ class SpeakerManager:
|
|||
def init_speaker_encoder(self, model_path: str, config_path: str) -> None:
|
||||
self.speaker_encoder_config = load_config(config_path)
|
||||
self.speaker_encoder = setup_model(self.speaker_encoder_config)
|
||||
self.speaker_encoder.load_checkpoint(config_path, model_path, True)
|
||||
self.speaker_encoder.load_checkpoint(config_path, model_path, eval=True, use_cuda=self.use_cuda)
|
||||
self.speaker_encoder_ap = AudioProcessor(**self.speaker_encoder_config.audio)
|
||||
# normalize the input audio level and trim silences
|
||||
self.speaker_encoder_ap.do_sound_norm = True
|
||||
self.speaker_encoder_ap.do_trim_silence = True
|
||||
# self.speaker_encoder_ap.do_sound_norm = True
|
||||
# self.speaker_encoder_ap.do_trim_silence = True
|
||||
|
||||
def compute_x_vector_from_clip(self, wav_file: Union[str, list]) -> list:
|
||||
def _compute(wav_file: str):
|
||||
waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate)
|
||||
spec = self.speaker_encoder_ap.melspectrogram(waveform)
|
||||
spec = torch.from_numpy(spec.T)
|
||||
if self.use_cuda:
|
||||
spec = spec.cuda()
|
||||
spec = spec.unsqueeze(0)
|
||||
x_vector = self.speaker_encoder.compute_embedding(spec)
|
||||
return x_vector
|
||||
|
@ -248,6 +252,8 @@ class SpeakerManager:
|
|||
feats = torch.from_numpy(feats)
|
||||
if feats.ndim == 2:
|
||||
feats = feats.unsqueeze(0)
|
||||
if self.use_cuda:
|
||||
feats = feats.cuda()
|
||||
return self.speaker_encoder.compute_embedding(feats)
|
||||
|
||||
def run_umap(self):
|
||||
|
|
Loading…
Reference in New Issue