initial SpeakerManager implementation

pull/441/head
Eren Gölge 2021-04-21 13:11:46 +02:00
parent 09890c7421
commit ab313814de
1 changed file with 161 additions and 1 deletion


@@ -1,5 +1,12 @@
import json
import os
import random
import numpy as np
import torch
from TTS.speaker_encoder.utils.generic_utils import setup_model
from TTS.utils.io import load_config
def make_speakers_json_path(out_path):
@@ -74,10 +81,163 @@ def parse_speakers(c, args, meta_data_train, OUT_PATH):
speaker_embedding_dim = None
save_speaker_mapping(OUT_PATH, speaker_mapping)
num_speakers = len(speaker_mapping)
print(" > Training with {} speakers: {}".format(len(speakers), ", ".join(speakers)))
print(" > Training with {} speakers: {}".format(
len(speakers), ", ".join(speakers)))
else:
num_speakers = 0
speaker_embedding_dim = None
speaker_mapping = None
return num_speakers, speaker_embedding_dim, speaker_mapping
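For context, `save_speaker_mapping` above writes the mapping as plain JSON. A minimal sketch of the two metafile layouts the class docstring below describes (the file contents are illustrative, not taken from this commit):
```
# Scenario 1: models with a speaker embedding layer map speaker names to ids.
speaker_mapping = {"speakerA": 0, "speakerB": 1}

# Scenario 2: x-vector models map each clip to its speaker name and embedding.
speaker_mapping = {
    "clip_0001.wav": {"name": "speakerA", "embedding": [0.12, -0.03, 0.44]},
    "clip_0002.wav": {"name": "speakerB", "embedding": [-0.21, 0.08, 0.19]},
}
```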
class SpeakerManager:
"""It manages the multi-speaker setup for 🐸TTS models. It loads the speaker files and parses the information
in a way that you can query. Three different scenarios are considered:
1. Models using speaker embedding layers. The metafile only includes a mapping of speaker names to ids.
2. Models using external embedding vectors (x vectors). The metafile includes a dictionary in the following
format.
```
{
'clip_name.wav':{
'name': 'speakerA',
'embedding': [<x_vector_values>]
},
...
}
```
3. Computing x vectors at inference with the speaker encoder. It loads the speaker encoder model and
computes x vectors for a given instance.
>>> # load audio processor and speaker encoder
>>> ap = AudioProcessor(**config.audio)
>>> manager = SpeakerManager(encoder_model_path=encoder_model_path, encoder_config_path=encoder_config_path)
>>> # load a sample audio and compute embedding
>>> waveform = ap.load_wav(sample_wav_path)
>>> mel = ap.melspectrogram(waveform)
>>> x_vector = manager.compute_x_vector(mel.T)
Args:
x_vectors_file_path (str, optional): Path to the metafile including x vectors. Defaults to "".
speaker_id_file_path (str, optional): Path to the metafile that maps speaker names to ids used by the
TTS model. Defaults to "".
encoder_model_path (str, optional): Path to the speaker encoder model file. Defaults to "".
encoder_config_path (str, optional): Path to the speaker encoder config file. Defaults to "".
"""
def __init__(
self,
x_vectors_file_path: str = "",
speaker_id_file_path: str = "",
encoder_model_path: str = "",
encoder_config_path: str = "",
):
self.x_vectors = None
self.speaker_ids = None
self.clip_ids = None
self.speaker_encoder = None
if x_vectors_file_path:
self.load_x_vectors_file(x_vectors_file_path)
if speaker_id_file_path:
self.load_ids_file(speaker_id_file_path)
if encoder_model_path and encoder_config_path:
self.init_speaker_encoder(encoder_model_path, encoder_config_path)
@staticmethod
def _load_json(json_file_path: str):
with open(json_file_path) as f:
return json.load(f)
@staticmethod
def _save_json(json_file_path: str, data: dict):
with open(json_file_path, "w") as f:
json.dump(data, f, indent=4)
@property
def num_speakers(self):
return len(self.speaker_ids)
@property
def x_vector_dim(self):
return len(next(iter(self.x_vectors.values()))["embedding"])
def parser_speakers_from_items(self, items: list):
speaker_ids = sorted({item[2] for item in items})
self.speaker_ids = speaker_ids
num_speakers = len(speaker_ids)
return speaker_ids, num_speakers
def save_ids_file(self, file_path: str):
self._save_json(file_path, self.speaker_ids)
def load_ids_file(self, file_path: str):
self.speaker_ids = self._load_json(file_path)
def save_x_vectors_file(self, file_path: str):
self._save_json(file_path, self.x_vectors)
def load_x_vectors_file(self, file_path: str):
self.x_vectors = self._load_json(file_path)
# keep deterministic, sorted lists of unique speaker names and clip ids
self.speaker_ids = sorted({x["name"] for x in self.x_vectors.values()})
self.clip_ids = sorted(self.x_vectors.keys())
def get_x_vector_by_clip(self, clip_idx: str):
return self.x_vectors[clip_idx]["embedding"]
def get_x_vectors_by_speaker(self, speaker_idx: str):
return [
x["embedding"] for x in self.x_vectors.values()
if x["name"] == speaker_idx
]
def get_mean_x_vector(self, speaker_idx: str, num_samples: int = None, randomize: bool = False):
x_vectors = self.get_x_vectors_by_speaker(speaker_idx)
if num_samples is None:
x_vectors = np.stack(x_vectors).mean(0)
else:
assert len(x_vectors) >= num_samples, f" [!] speaker {speaker_idx} has fewer than {num_samples} samples"
if randomize:
x_vectors = np.stack(random.choices(x_vectors, k=num_samples)).mean(0)
else:
x_vectors = np.stack(x_vectors[:num_samples]).mean(0)
return x_vectors
def get_speakers(self):
return self.speaker_ids
def get_clips(self):
return sorted(self.x_vectors.keys())
def init_speaker_encoder(self, model_path: str, config_path: str):
self.speaker_encoder_config = load_config(config_path)
self.speaker_encoder = setup_model(self.speaker_encoder_config)
self.speaker_encoder.load_checkpoint(config_path, model_path, True)
def compute_x_vector(self, feats):
if isinstance(feats, np.ndarray):
feats = torch.from_numpy(feats)
if feats.ndim == 2:
feats = feats.unsqueeze(0)
return self.speaker_encoder.compute_embedding(feats)
def run_umap(self):
# TODO: implement UMAP projection of the x vectors
raise NotImplementedError
def plot_embeddings(self):
# TODO: implement embedding visualization
raise NotImplementedError
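Putting the pieces together, a minimal usage sketch covering both the metafile path and the inference path (the file names below are placeholders, not part of this commit):
```
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config

# Scenario 2: query precomputed x vectors from a metafile.
manager = SpeakerManager(x_vectors_file_path="speakers.json")
print(manager.num_speakers, manager.x_vector_dim)
# average a few randomly sampled x vectors of one speaker
mean_vec = manager.get_mean_x_vector("speakerA", num_samples=5, randomize=True)

# Scenario 3: compute an x vector at inference with the speaker encoder.
manager = SpeakerManager(
    encoder_model_path="se_model.pth.tar",
    encoder_config_path="se_config.json",
)
config = load_config("model_config.json")  # any config carrying the audio settings
ap = AudioProcessor(**config.audio)
waveform = ap.load_wav("sample.wav")
mel = ap.melspectrogram(waveform)
x_vector = manager.compute_x_vector(mel.T)
```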