Make CLI work

pull/3405/head
WeberJulian 2023-12-11 18:49:18 +01:00
parent 0a90359a42
commit e3c9dab7a3
2 changed files with 17 additions and 14 deletions

View File

@ -11,7 +11,7 @@ from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.stream_generator import init_stream_support
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
from TTS.tts.layers.xtts.speaker_manager import SpeakerManager
from TTS.tts.layers.xtts.xtts_manager import SpeakerManager, LanguageManager
from TTS.tts.models.base_tts import BaseTTS
from TTS.utils.io import load_fsspec
@ -379,7 +379,7 @@ class Xtts(BaseTTS):
return gpt_cond_latents, speaker_embedding
def synthesize(self, text, config, speaker_wav, language, **kwargs):
def synthesize(self, text, config, speaker_wav, language, speaker_id, **kwargs):
"""Synthesize speech with the given input text.
Args:
@ -394,12 +394,6 @@ class Xtts(BaseTTS):
`text_input` as text token IDs after tokenizer, `voice_samples` as samples used for cloning, `conditioning_latents`
as latents used at inference.
"""
return self.inference_with_config(text, config, ref_audio_path=speaker_wav, language=language, **kwargs)
def inference_with_config(self, text, config, ref_audio_path, language, **kwargs):
"""
inference with config
"""
assert (
"zh-cn" if language == "zh" else language in self.config.languages
@ -411,13 +405,18 @@ class Xtts(BaseTTS):
"repetition_penalty": config.repetition_penalty,
"top_k": config.top_k,
"top_p": config.top_p,
}
settings.update(kwargs) # allow overriding of preset settings with kwargs
if speaker_id is not None:
gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)
settings.update({
"gpt_cond_len": config.gpt_cond_len,
"gpt_cond_chunk_len": config.gpt_cond_chunk_len,
"max_ref_len": config.max_ref_len,
"sound_norm_refs": config.sound_norm_refs,
}
settings.update(kwargs) # allow overriding of preset settings with kwargs
return self.full_inference(text, ref_audio_path, language, **settings)
})
return self.full_inference(text, speaker_wav, language, **settings)
@torch.inference_mode()
def full_inference(
@ -753,8 +752,9 @@ class Xtts(BaseTTS):
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")
speaker_file_path = speaker_file_path or os.path.join(checkpoint_dir, "speakers.json")
speaker_file_path = speaker_file_path or os.path.join(checkpoint_dir, "speakers_xtts.pth")
self.language_manager = LanguageManager(config)
self.speaker_manager = None
if os.path.exists(speaker_file_path):
self.speaker_manager = SpeakerManager(speaker_file_path)

View File

@ -305,7 +305,7 @@ class Synthesizer(nn.Module):
speaker_embedding = None
speaker_id = None
if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "name_to_id"):
if speaker_name and isinstance(speaker_name, str):
if speaker_name and isinstance(speaker_name, str) and not self.tts_config.model == "xtts":
if self.tts_config.use_d_vector_file:
# get the average speaker embedding from the saved d_vectors.
speaker_embedding = self.tts_model.speaker_manager.get_mean_embedding(
@ -335,7 +335,9 @@ class Synthesizer(nn.Module):
# handle multi-lingual
language_id = None
if self.tts_languages_file or (
hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None
hasattr(self.tts_model, "language_manager")
and self.tts_model.language_manager is not None
and not self.tts_config.model == "xtts"
):
if len(self.tts_model.language_manager.name_to_id) == 1:
language_id = list(self.tts_model.language_manager.name_to_id.values())[0]
@ -366,6 +368,7 @@ class Synthesizer(nn.Module):
if (
speaker_wav is not None
and self.tts_model.speaker_manager is not None
and hasattr(self.tts_model.speaker_manager, "encoder_ap")
and self.tts_model.speaker_manager.encoder_ap is not None
):
speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)