refactor `SpeakerManager`

pull/506/head
Eren Gölge 2021-05-28 15:46:28 +02:00
parent 421194880d
commit f840268181
1 changed files with 177 additions and 22 deletions

View File

@ -1,6 +1,7 @@
import json
import os
import random
from typing import Any, List, Union
from typing import Any, Dict, List, Tuple, Union
import numpy as np
import torch
@ -10,6 +11,71 @@ from TTS.speaker_encoder.utils.generic_utils import setup_model
from TTS.utils.audio import AudioProcessor
def make_speakers_json_path(out_path):
"""Returns conventional speakers.json location."""
return os.path.join(out_path, "speakers.json")
def load_speaker_mapping(out_path):
"""Loads speaker mapping if already present."""
if os.path.splitext(out_path)[1] == ".json":
json_file = out_path
else:
json_file = make_speakers_json_path(out_path)
with open(json_file) as f:
return json.load(f)
def save_speaker_mapping(out_path, speaker_mapping):
"""Saves speaker mapping if not yet present."""
if out_path is not None:
speakers_json_path = make_speakers_json_path(out_path)
with open(speakers_json_path, "w") as f:
json.dump(speaker_mapping, f, indent=4)
def get_speaker_manager(c, args, meta_data_train):
"""Inititalize and return a `SpeakerManager` based on config values"""
speaker_manager = SpeakerManager()
if c.use_speaker_embedding:
speaker_manager.set_speaker_ids_from_data(meta_data_train)
if args.restore_path:
# restoring speaker manager from a previous run.
if c.use_external_speaker_embedding_file:
# restore speaker manager with the embedding file
speakers_file = os.path.dirname(args.restore_path)
if not os.path.exists(speakers_file):
print(
"WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
)
if not os.path.exists(c.external_speaker_embedding_file):
raise RuntimeError(
"You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file"
)
speaker_manager.load_x_vectors_file(c.external_speaker_embedding_file)
speaker_manager.set_x_vectors_from_file(speakers_file)
elif not c.use_external_speaker_embedding_file: # restor speaker manager with speaker ID file.
speakers_file = os.path.dirname(args.restore_path)
speaker_ids_from_data = speaker_manager.speaker_ids
speaker_manager.set_speaker_ids_from_file(speakers_file)
assert all(
speaker in speaker_manager.speaker_ids for speaker in speaker_ids_from_data
), " [!] You cannot introduce new speakers to a pre-trained model."
elif c.use_external_speaker_embedding_file and c.external_speaker_embedding_file:
# new speaker manager with external speaker embeddings.
speaker_manager.set_x_vectors_from_file(c.external_speaker_embedding_file)
elif (
c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file
): # new speaker manager with speaker IDs file.
raise "use_external_speaker_embedding_file is True, so you need pass a external speaker embedding file, run GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in notebooks/ folder"
print(
" > Training with {} speakers: {}".format(
speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
)
)
return speaker_manager
class SpeakerManager:
"""It manages the multi-speaker setup for 🐸TTS models. It loads the speaker files and parses the information
in a way that you can query. There are 3 different scenarios considered.
@ -64,24 +130,24 @@ class SpeakerManager:
self.speaker_encoder_ap = None
if data_items:
self.speaker_ids = self.parse_speakers()
self.speaker_ids, _ = self.parse_speakers_from_data(self.data_items)
if x_vectors_file_path:
self.load_x_vectors_file(x_vectors_file_path)
self.set_x_vectors_from_file(x_vectors_file_path)
if speaker_id_file_path:
self.load_ids_file(speaker_id_file_path)
self.set_speaker_ids_from_file(speaker_id_file_path)
if encoder_model_path and encoder_config_path:
self.init_speaker_encoder(encoder_model_path, encoder_config_path)
@staticmethod
def _load_json(json_file_path: str):
def _load_json(json_file_path: str) -> Dict:
with open(json_file_path) as f:
return json.load(f)
@staticmethod
def _save_json(json_file_path: str, data: dict):
def _save_json(json_file_path: str, data: dict) -> None:
with open(json_file_path, "w") as f:
json.dump(data, f, indent=4)
@ -91,35 +157,101 @@ class SpeakerManager:
@property
def x_vector_dim(self):
return len(self.x_vectors[list(self.x_vectors.keys())[0]]["embedding"])
"""Dimensionality of x_vectors. If x_vectors are not loaded, returns zero."""
if self.x_vectors:
return len(self.x_vectors[list(self.x_vectors.keys())[0]]["embedding"])
return 0
def parse_speakers_from_items(self, items: list):
@staticmethod
def parse_speakers_from_data(items: list) -> Tuple[Dict, int]:
"""Parse speaker IDs from data samples retured by `load_meta_data()`.
Args:
items (list): Data sampled returned by `load_meta_data()`.
Returns:
Tuple[Dict, int]: speaker IDs and number of speakers.
"""
speakers = sorted({item[2] for item in items})
self.speaker_ids = {name: i for i, name in enumerate(speakers)}
num_speakers = len(self.speaker_ids)
return self.speaker_ids, num_speakers
speaker_ids = {name: i for i, name in enumerate(speakers)}
num_speakers = len(speaker_ids)
return speaker_ids, num_speakers
def save_ids_file(self, file_path: str):
self._save_json(file_path, self.speaker_ids)
def set_speaker_ids_from_data(self, items: List) -> None:
"""Set speaker IDs from data samples.
def load_ids_file(self, file_path: str):
Args:
items (List): Data sampled returned by `load_meta_data()`.
"""
self.speaker_ids, _ = self.parse_speakers_from_data(items)
def set_speaker_ids_from_file(self, file_path: str) -> None:
"""Set speaker IDs from a file.
Args:
file_path (str): Path to the file.
"""
self.speaker_ids = self._load_json(file_path)
def save_x_vectors_file(self, file_path: str):
def save_speaker_ids_to_file(self, file_path: str) -> None:
"""Save speaker IDs to a json file.
Args:
file_path (str): Path to the output file.
"""
self._save_json(file_path, self.speaker_ids)
def save_x_vectors_to_file(self, file_path: str) -> None:
"""Save x_vectors to a json file.
Args:
file_path (str): Path to the output file.
"""
self._save_json(file_path, self.x_vectors)
def load_x_vectors_file(self, file_path: str):
def set_x_vectors_from_file(self, file_path: str) -> None:
"""Load x_vectors from a json file.
Args:
file_path (str): Path to the target json file.
"""
self.x_vectors = self._load_json(file_path)
self.speaker_ids = list(set(sorted(x["name"] for x in self.x_vectors.values())))
self.clip_ids = list(set(sorted(clip_name for clip_name in self.x_vectors.keys())))
def get_x_vector_by_clip(self, clip_idx: str):
def get_x_vector_by_clip(self, clip_idx: str) -> List:
"""Get x_vector by clip ID.
Args:
clip_idx (str): Target clip ID.
Returns:
List: x_vector as a list.
"""
return self.x_vectors[clip_idx]["embedding"]
def get_x_vectors_by_speaker(self, speaker_idx: str):
def get_x_vectors_by_speaker(self, speaker_idx: str) -> List[List]:
"""Get all x_vectors of a speaker.
Args:
speaker_idx (str): Target speaker ID.
Returns:
List[List]: all the x_vectors of the given speaker.
"""
return [x["embedding"] for x in self.x_vectors.values() if x["name"] == speaker_idx]
def get_mean_x_vector(self, speaker_idx: str, num_samples: int = None, randomize: bool = False):
def get_mean_x_vector(self, speaker_idx: str, num_samples: int = None, randomize: bool = False) -> np.Array:
"""Get mean x_vector of a speaker ID.
Args:
speaker_idx (str): Target speaker ID.
num_samples (int, optional): Number of samples to be averaged. Defaults to None.
randomize (bool, optional): Pick random `num_samples`of x_vectors. Defaults to False.
Returns:
np.Array: Mean x_vector.
"""
x_vectors = self.get_x_vectors_by_speaker(speaker_idx)
if num_samples is None:
x_vectors = np.stack(x_vectors).mean(0)
@ -131,13 +263,19 @@ class SpeakerManager:
x_vectors = np.stack(x_vectors[:num_samples]).mean(0)
return x_vectors
def get_speakers(self):
def get_speakers(self) -> List:
return self.speaker_ids
def get_clips(self):
def get_clips(self) -> List:
return sorted(self.x_vectors.keys())
def init_speaker_encoder(self, model_path: str, config_path: str) -> None:
"""Initialize a speaker encoder model.
Args:
model_path (str): Model file path.
config_path (str): Model config file path.
"""
self.speaker_encoder_config = load_config(config_path)
self.speaker_encoder = setup_model(self.speaker_encoder_config)
self.speaker_encoder.load_checkpoint(config_path, model_path, True)
@ -147,6 +285,15 @@ class SpeakerManager:
self.speaker_encoder_ap.do_trim_silence = True
def compute_x_vector_from_clip(self, wav_file: Union[str, list]) -> list:
"""Compute a x_vector from a given audio file.
Args:
wav_file (Union[str, list]): Target file path.
Returns:
list: Computed x_vector.
"""
def _compute(wav_file: str):
waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate)
spec = self.speaker_encoder_ap.melspectrogram(waveform)
@ -168,7 +315,15 @@ class SpeakerManager:
x_vector = _compute(wav_file)
return x_vector[0].tolist()
def compute_x_vector(self, feats):
def compute_x_vector(self, feats: Union[torch.Tensor, np.Array]) -> List:
"""Compute x_vector from features.
Args:
feats (Union[torch.Tensor, np.Array]): Input features.
Returns:
List: computed x_vector.
"""
if isinstance(feats, np.ndarray):
feats = torch.from_numpy(feats)
if feats.ndim == 2: