refactor `SpeakerManager`

pull/506/head
Eren Gölge 2021-05-28 15:46:28 +02:00
parent 421194880d
commit f840268181
1 changed file with 177 additions and 22 deletions

View File

@ -1,6 +1,7 @@
import json import json
import os
import random import random
from typing import Any, List, Union from typing import Any, Dict, List, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@ -10,6 +11,71 @@ from TTS.speaker_encoder.utils.generic_utils import setup_model
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
def make_speakers_json_path(out_path):
    """Return the conventional `speakers.json` location under `out_path`."""
    speakers_json_path = os.path.join(out_path, "speakers.json")
    return speakers_json_path
def load_speaker_mapping(out_path):
    """Load the speaker mapping if already present.

    `out_path` may be the json file itself or the folder that contains
    the conventional `speakers.json`.
    """
    is_json_file = os.path.splitext(out_path)[1] == ".json"
    # inlined make_speakers_json_path(): conventional speakers.json location
    json_file = out_path if is_json_file else os.path.join(out_path, "speakers.json")
    with open(json_file) as handle:
        return json.load(handle)
def save_speaker_mapping(out_path, speaker_mapping):
    """Save the speaker mapping as `speakers.json` under `out_path` (no-op if None)."""
    if out_path is None:
        return
    # inlined make_speakers_json_path(): conventional speakers.json location
    speakers_json_path = os.path.join(out_path, "speakers.json")
    with open(speakers_json_path, "w") as handle:
        json.dump(speaker_mapping, handle, indent=4)
def get_speaker_manager(c, args, meta_data_train):
    """Initialize and return a `SpeakerManager` based on config values.

    Args:
        c: training config exposing `use_speaker_embedding`,
            `use_external_speaker_embedding_file` and
            `external_speaker_embedding_file`.
        args: training arguments; a non-empty `args.restore_path` marks a resumed run.
        meta_data_train: data samples returned by `load_meta_data()`.

    Returns:
        SpeakerManager: configured manager (empty when speaker embeddings are disabled).
    """
    speaker_manager = SpeakerManager()
    if c.use_speaker_embedding:
        speaker_manager.set_speaker_ids_from_data(meta_data_train)
        if args.restore_path:
            # restoring speaker manager from a previous run.
            if c.use_external_speaker_embedding_file:
                # restore speaker manager with the embedding file.
                # FIX: point at the speakers.json file next to the checkpoint;
                # the original used os.path.dirname() alone, so the existence
                # check always passed (the folder exists) and json loading failed.
                speakers_file = os.path.join(os.path.dirname(args.restore_path), "speakers.json")
                if not os.path.exists(speakers_file):
                    print(
                        "WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
                    )
                    if not os.path.exists(c.external_speaker_embedding_file):
                        raise RuntimeError(
                            "You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file"
                        )
                    # FIX: `load_x_vectors_file` was renamed to `set_x_vectors_from_file`
                    # in this refactor; fall back to the config-provided file instead of
                    # loading the (missing) speakers_file right after.
                    speakers_file = c.external_speaker_embedding_file
                speaker_manager.set_x_vectors_from_file(speakers_file)
            else:  # restore speaker manager with the speaker ID file.
                speakers_file = os.path.join(os.path.dirname(args.restore_path), "speakers.json")
                speaker_ids_from_data = speaker_manager.speaker_ids
                speaker_manager.set_speaker_ids_from_file(speakers_file)
                assert all(
                    speaker in speaker_manager.speaker_ids for speaker in speaker_ids_from_data
                ), " [!] You cannot introduce new speakers to a pre-trained model."
        elif c.use_external_speaker_embedding_file and c.external_speaker_embedding_file:
            # new speaker manager with external speaker embeddings.
            speaker_manager.set_x_vectors_from_file(c.external_speaker_embedding_file)
        elif c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file:
            # FIX: raising a plain string is a TypeError in Python 3; raise a real exception.
            raise ValueError(
                "use_external_speaker_embedding_file is True, so you need pass a external speaker embedding file, run GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in notebooks/ folder"
            )
        print(
            " > Training with {} speakers: {}".format(
                speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
            )
        )
    return speaker_manager
class SpeakerManager: class SpeakerManager:
"""It manages the multi-speaker setup for 🐸TTS models. It loads the speaker files and parses the information """It manages the multi-speaker setup for 🐸TTS models. It loads the speaker files and parses the information
in a way that you can query. There are 3 different scenarios considered. in a way that you can query. There are 3 different scenarios considered.
@ -64,24 +130,24 @@ class SpeakerManager:
self.speaker_encoder_ap = None self.speaker_encoder_ap = None
if data_items: if data_items:
self.speaker_ids = self.parse_speakers() self.speaker_ids, _ = self.parse_speakers_from_data(self.data_items)
if x_vectors_file_path: if x_vectors_file_path:
self.load_x_vectors_file(x_vectors_file_path) self.set_x_vectors_from_file(x_vectors_file_path)
if speaker_id_file_path: if speaker_id_file_path:
self.load_ids_file(speaker_id_file_path) self.set_speaker_ids_from_file(speaker_id_file_path)
if encoder_model_path and encoder_config_path: if encoder_model_path and encoder_config_path:
self.init_speaker_encoder(encoder_model_path, encoder_config_path) self.init_speaker_encoder(encoder_model_path, encoder_config_path)
@staticmethod @staticmethod
def _load_json(json_file_path: str): def _load_json(json_file_path: str) -> Dict:
with open(json_file_path) as f: with open(json_file_path) as f:
return json.load(f) return json.load(f)
@staticmethod @staticmethod
def _save_json(json_file_path: str, data: dict): def _save_json(json_file_path: str, data: dict) -> None:
with open(json_file_path, "w") as f: with open(json_file_path, "w") as f:
json.dump(data, f, indent=4) json.dump(data, f, indent=4)
@ -91,35 +157,101 @@ class SpeakerManager:
@property @property
def x_vector_dim(self): def x_vector_dim(self):
"""Dimensionality of x_vectors. If x_vectors are not loaded, returns zero."""
if self.x_vectors:
return len(self.x_vectors[list(self.x_vectors.keys())[0]]["embedding"]) return len(self.x_vectors[list(self.x_vectors.keys())[0]]["embedding"])
return 0
def parse_speakers_from_items(self, items: list): @staticmethod
def parse_speakers_from_data(items: list) -> Tuple[Dict, int]:
"""Parse speaker IDs from data samples retured by `load_meta_data()`.
Args:
items (list): Data sampled returned by `load_meta_data()`.
Returns:
Tuple[Dict, int]: speaker IDs and number of speakers.
"""
speakers = sorted({item[2] for item in items}) speakers = sorted({item[2] for item in items})
self.speaker_ids = {name: i for i, name in enumerate(speakers)} speaker_ids = {name: i for i, name in enumerate(speakers)}
num_speakers = len(self.speaker_ids) num_speakers = len(speaker_ids)
return self.speaker_ids, num_speakers return speaker_ids, num_speakers
def save_ids_file(self, file_path: str): def set_speaker_ids_from_data(self, items: List) -> None:
self._save_json(file_path, self.speaker_ids) """Set speaker IDs from data samples.
def load_ids_file(self, file_path: str): Args:
items (List): Data sampled returned by `load_meta_data()`.
"""
self.speaker_ids, _ = self.parse_speakers_from_data(items)
def set_speaker_ids_from_file(self, file_path: str) -> None:
"""Set speaker IDs from a file.
Args:
file_path (str): Path to the file.
"""
self.speaker_ids = self._load_json(file_path) self.speaker_ids = self._load_json(file_path)
def save_x_vectors_file(self, file_path: str): def save_speaker_ids_to_file(self, file_path: str) -> None:
"""Save speaker IDs to a json file.
Args:
file_path (str): Path to the output file.
"""
self._save_json(file_path, self.speaker_ids)
def save_x_vectors_to_file(self, file_path: str) -> None:
"""Save x_vectors to a json file.
Args:
file_path (str): Path to the output file.
"""
self._save_json(file_path, self.x_vectors) self._save_json(file_path, self.x_vectors)
def load_x_vectors_file(self, file_path: str): def set_x_vectors_from_file(self, file_path: str) -> None:
"""Load x_vectors from a json file.
Args:
file_path (str): Path to the target json file.
"""
self.x_vectors = self._load_json(file_path) self.x_vectors = self._load_json(file_path)
self.speaker_ids = list(set(sorted(x["name"] for x in self.x_vectors.values()))) self.speaker_ids = list(set(sorted(x["name"] for x in self.x_vectors.values())))
self.clip_ids = list(set(sorted(clip_name for clip_name in self.x_vectors.keys()))) self.clip_ids = list(set(sorted(clip_name for clip_name in self.x_vectors.keys())))
def get_x_vector_by_clip(self, clip_idx: str): def get_x_vector_by_clip(self, clip_idx: str) -> List:
"""Get x_vector by clip ID.
Args:
clip_idx (str): Target clip ID.
Returns:
List: x_vector as a list.
"""
return self.x_vectors[clip_idx]["embedding"] return self.x_vectors[clip_idx]["embedding"]
def get_x_vectors_by_speaker(self, speaker_idx: str): def get_x_vectors_by_speaker(self, speaker_idx: str) -> List[List]:
"""Get all x_vectors of a speaker.
Args:
speaker_idx (str): Target speaker ID.
Returns:
List[List]: all the x_vectors of the given speaker.
"""
return [x["embedding"] for x in self.x_vectors.values() if x["name"] == speaker_idx] return [x["embedding"] for x in self.x_vectors.values() if x["name"] == speaker_idx]
def get_mean_x_vector(self, speaker_idx: str, num_samples: int = None, randomize: bool = False): def get_mean_x_vector(self, speaker_idx: str, num_samples: int = None, randomize: bool = False) -> np.Array:
"""Get mean x_vector of a speaker ID.
Args:
speaker_idx (str): Target speaker ID.
num_samples (int, optional): Number of samples to be averaged. Defaults to None.
randomize (bool, optional): Pick random `num_samples`of x_vectors. Defaults to False.
Returns:
np.Array: Mean x_vector.
"""
x_vectors = self.get_x_vectors_by_speaker(speaker_idx) x_vectors = self.get_x_vectors_by_speaker(speaker_idx)
if num_samples is None: if num_samples is None:
x_vectors = np.stack(x_vectors).mean(0) x_vectors = np.stack(x_vectors).mean(0)
@ -131,13 +263,19 @@ class SpeakerManager:
x_vectors = np.stack(x_vectors[:num_samples]).mean(0) x_vectors = np.stack(x_vectors[:num_samples]).mean(0)
return x_vectors return x_vectors
def get_speakers(self): def get_speakers(self) -> List:
return self.speaker_ids return self.speaker_ids
def get_clips(self): def get_clips(self) -> List:
return sorted(self.x_vectors.keys()) return sorted(self.x_vectors.keys())
def init_speaker_encoder(self, model_path: str, config_path: str) -> None: def init_speaker_encoder(self, model_path: str, config_path: str) -> None:
"""Initialize a speaker encoder model.
Args:
model_path (str): Model file path.
config_path (str): Model config file path.
"""
self.speaker_encoder_config = load_config(config_path) self.speaker_encoder_config = load_config(config_path)
self.speaker_encoder = setup_model(self.speaker_encoder_config) self.speaker_encoder = setup_model(self.speaker_encoder_config)
self.speaker_encoder.load_checkpoint(config_path, model_path, True) self.speaker_encoder.load_checkpoint(config_path, model_path, True)
@ -147,6 +285,15 @@ class SpeakerManager:
self.speaker_encoder_ap.do_trim_silence = True self.speaker_encoder_ap.do_trim_silence = True
def compute_x_vector_from_clip(self, wav_file: Union[str, list]) -> list: def compute_x_vector_from_clip(self, wav_file: Union[str, list]) -> list:
"""Compute a x_vector from a given audio file.
Args:
wav_file (Union[str, list]): Target file path.
Returns:
list: Computed x_vector.
"""
def _compute(wav_file: str): def _compute(wav_file: str):
waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate) waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate)
spec = self.speaker_encoder_ap.melspectrogram(waveform) spec = self.speaker_encoder_ap.melspectrogram(waveform)
@ -168,7 +315,15 @@ class SpeakerManager:
x_vector = _compute(wav_file) x_vector = _compute(wav_file)
return x_vector[0].tolist() return x_vector[0].tolist()
def compute_x_vector(self, feats): def compute_x_vector(self, feats: Union[torch.Tensor, np.Array]) -> List:
"""Compute x_vector from features.
Args:
feats (Union[torch.Tensor, np.Array]): Input features.
Returns:
List: computed x_vector.
"""
if isinstance(feats, np.ndarray): if isinstance(feats, np.ndarray):
feats = torch.from_numpy(feats) feats = torch.from_numpy(feats)
if feats.ndim == 2: if feats.ndim == 2: