mirror of https://github.com/coqui-ai/TTS.git
use SpeakerManager in Synthesizer
parent
e97126314c
commit
6d0f5e0459
|
@ -6,7 +6,7 @@ import pysbd
|
|||
import torch
|
||||
|
||||
from TTS.tts.utils.generic_utils import setup_model
|
||||
from TTS.tts.utils.speakers import load_speaker_mapping
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
|
||||
# pylint: disable=unused-wildcard-import
|
||||
# pylint: disable=wildcard-import
|
||||
|
@ -49,7 +49,10 @@ class Synthesizer(object):
|
|||
self.vocoder_checkpoint = vocoder_checkpoint
|
||||
self.vocoder_config = vocoder_config
|
||||
self.use_cuda = use_cuda
|
||||
|
||||
self.tts_model = None
|
||||
self.vocoder_model = None
|
||||
self.speaker_manager = None
|
||||
self.num_speakers = 0
|
||||
self.tts_speakers = {}
|
||||
self.speaker_embedding_dim = 0
|
||||
|
@ -70,25 +73,10 @@ class Synthesizer(object):
|
|||
|
||||
def _load_speakers(self, speaker_file: str) -> None:
|
||||
print("Loading speakers ...")
|
||||
self.tts_speakers = load_speaker_mapping(speaker_file)
|
||||
self.num_speakers = len(self.tts_speakers)
|
||||
self.speaker_embedding_dim = len(self.tts_speakers[list(self.tts_speakers.keys())[0]]["embedding"])
|
||||
|
||||
def _load_speaker_embedding(self, speaker_json_key: str = ""):
|
||||
|
||||
speaker_embedding = None
|
||||
|
||||
if not speaker_json_key:
|
||||
raise ValueError(" [!] While 'use_external_speaker_embedding_file', you must pass a 'speaker_json_key'")
|
||||
|
||||
if speaker_json_key != "":
|
||||
assert self.tts_speakers
|
||||
assert (
|
||||
speaker_json_key in self.tts_speakers
|
||||
), f" [!] speaker_json_key is not in self.tts_speakers keys : '{speaker_json_key}'"
|
||||
speaker_embedding = self.tts_speakers[speaker_json_key]["embedding"]
|
||||
|
||||
return speaker_embedding
|
||||
self.speaker_manager = SpeakerManager()
|
||||
self.speaker_manager.load_x_vectors_file(self.tts_config.get("external_speaker_embedding_file", speaker_file))
|
||||
self.num_speakers = self.speaker_manager.num_speakers
|
||||
self.speaker_embedding_dim = self.speaker_manager.x_vector_dim
|
||||
|
||||
def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None:
|
||||
# pylint: disable=global-statement
|
||||
|
@ -108,7 +96,7 @@ class Synthesizer(object):
|
|||
self.input_size = len(symbols)
|
||||
|
||||
if self.tts_config.use_speaker_embedding is True:
|
||||
self._load_speakers(self.tts_config.get("external_speaker_embedding_file", self.tts_speakers_file))
|
||||
self._load_speakers(self.tts_speakers_file)
|
||||
|
||||
self.tts_model = setup_model(
|
||||
self.input_size,
|
||||
|
@ -116,7 +104,6 @@ class Synthesizer(object):
|
|||
c=self.tts_config,
|
||||
speaker_embedding_dim=self.speaker_embedding_dim,
|
||||
)
|
||||
|
||||
self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
|
||||
if use_cuda:
|
||||
self.tts_model.cuda()
|
||||
|
@ -136,13 +123,17 @@ class Synthesizer(object):
|
|||
wav = np.array(wav)
|
||||
self.ap.save_wav(wav, path, self.output_sample_rate)
|
||||
|
||||
def tts(self, text: str, speaker_json_key: str = "", style_wav=None) -> List[int]:
|
||||
def tts(self, text: str, speaker_idx: str = "", style_wav=None) -> List[int]:
|
||||
start_time = time.time()
|
||||
wavs = []
|
||||
sens = self._split_into_sentences(text)
|
||||
print(" > Text splitted to sentences.")
|
||||
print(sens)
|
||||
speaker_embedding = self._load_speaker_embedding(speaker_json_key)
|
||||
|
||||
if speaker_idx and isinstance(speaker_idx, str):
|
||||
speaker_embedding = self.speaker_manager.get_x_vectors_by_speaker(speaker_idx)[0]
|
||||
else:
|
||||
speaker_embedding = None
|
||||
|
||||
use_gl = self.vocoder_model is None
|
||||
|
||||
|
|
Loading…
Reference in New Issue