mirror of https://github.com/coqui-ai/TTS.git
Make CLI work
parent
0a90359a42
commit
e3c9dab7a3
|
@ -11,7 +11,7 @@ from TTS.tts.layers.xtts.gpt import GPT
|
|||
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
|
||||
from TTS.tts.layers.xtts.stream_generator import init_stream_support
|
||||
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
|
||||
from TTS.tts.layers.xtts.speaker_manager import SpeakerManager
|
||||
from TTS.tts.layers.xtts.xtts_manager import SpeakerManager, LanguageManager
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
|
@ -379,7 +379,7 @@ class Xtts(BaseTTS):
|
|||
|
||||
return gpt_cond_latents, speaker_embedding
|
||||
|
||||
def synthesize(self, text, config, speaker_wav, language, **kwargs):
|
||||
def synthesize(self, text, config, speaker_wav, language, speaker_id, **kwargs):
|
||||
"""Synthesize speech with the given input text.
|
||||
|
||||
Args:
|
||||
|
@ -394,12 +394,6 @@ class Xtts(BaseTTS):
|
|||
`text_input` as text token IDs after tokenizer, `voice_samples` as samples used for cloning, `conditioning_latents`
|
||||
as latents used at inference.
|
||||
|
||||
"""
|
||||
return self.inference_with_config(text, config, ref_audio_path=speaker_wav, language=language, **kwargs)
|
||||
|
||||
def inference_with_config(self, text, config, ref_audio_path, language, **kwargs):
|
||||
"""
|
||||
inference with config
|
||||
"""
|
||||
assert (
|
||||
"zh-cn" if language == "zh" else language in self.config.languages
|
||||
|
@ -411,13 +405,18 @@ class Xtts(BaseTTS):
|
|||
"repetition_penalty": config.repetition_penalty,
|
||||
"top_k": config.top_k,
|
||||
"top_p": config.top_p,
|
||||
}
|
||||
settings.update(kwargs) # allow overriding of preset settings with kwargs
|
||||
if speaker_id is not None:
|
||||
gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
|
||||
return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)
|
||||
settings.update({
|
||||
"gpt_cond_len": config.gpt_cond_len,
|
||||
"gpt_cond_chunk_len": config.gpt_cond_chunk_len,
|
||||
"max_ref_len": config.max_ref_len,
|
||||
"sound_norm_refs": config.sound_norm_refs,
|
||||
}
|
||||
settings.update(kwargs) # allow overriding of preset settings with kwargs
|
||||
return self.full_inference(text, ref_audio_path, language, **settings)
|
||||
})
|
||||
return self.full_inference(text, speaker_wav, language, **settings)
|
||||
|
||||
@torch.inference_mode()
|
||||
def full_inference(
|
||||
|
@ -753,8 +752,9 @@ class Xtts(BaseTTS):
|
|||
|
||||
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
|
||||
vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")
|
||||
speaker_file_path = speaker_file_path or os.path.join(checkpoint_dir, "speakers.json")
|
||||
speaker_file_path = speaker_file_path or os.path.join(checkpoint_dir, "speakers_xtts.pth")
|
||||
|
||||
self.language_manager = LanguageManager(config)
|
||||
self.speaker_manager = None
|
||||
if os.path.exists(speaker_file_path):
|
||||
self.speaker_manager = SpeakerManager(speaker_file_path)
|
||||
|
|
|
@ -305,7 +305,7 @@ class Synthesizer(nn.Module):
|
|||
speaker_embedding = None
|
||||
speaker_id = None
|
||||
if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "name_to_id"):
|
||||
if speaker_name and isinstance(speaker_name, str):
|
||||
if speaker_name and isinstance(speaker_name, str) and not self.tts_config.model == "xtts":
|
||||
if self.tts_config.use_d_vector_file:
|
||||
# get the average speaker embedding from the saved d_vectors.
|
||||
speaker_embedding = self.tts_model.speaker_manager.get_mean_embedding(
|
||||
|
@ -335,7 +335,9 @@ class Synthesizer(nn.Module):
|
|||
# handle multi-lingual
|
||||
language_id = None
|
||||
if self.tts_languages_file or (
|
||||
hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None
|
||||
hasattr(self.tts_model, "language_manager")
|
||||
and self.tts_model.language_manager is not None
|
||||
and not self.tts_config.model == "xtts"
|
||||
):
|
||||
if len(self.tts_model.language_manager.name_to_id) == 1:
|
||||
language_id = list(self.tts_model.language_manager.name_to_id.values())[0]
|
||||
|
@ -366,6 +368,7 @@ class Synthesizer(nn.Module):
|
|||
if (
|
||||
speaker_wav is not None
|
||||
and self.tts_model.speaker_manager is not None
|
||||
and hasattr(self.tts_model.speaker_manager, "encoder_ap")
|
||||
and self.tts_model.speaker_manager.encoder_ap is not None
|
||||
):
|
||||
speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)
|
||||
|
|
Loading…
Reference in New Issue