diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index aa66924a..5b1bc3fd 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -198,13 +198,12 @@ class XttsArgs(Coqpit):
     Args:
         gpt_batch_size (int): The size of the auto-regressive batch.
         enable_redaction (bool, optional): Whether to enable redaction. Defaults to True.
-        lazy_load (bool, optional): Whether to load models on demand. It reduces VRAM usage. Defaults to False.
         kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True.
         gpt_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None.
         clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None.
         decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
         num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
-        vocoder (VocType, optional): The vocoder to use for synthesis. Defaults to VocConf.Univnet.
+        use_hifigan (bool, optional): Whether to use hifigan or diffusion + univnet as a decoder. Defaults to True.

     For GPT model:
         ar_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
@@ -234,12 +233,12 @@ class XttsArgs(Coqpit):

     gpt_batch_size: int = 1
     enable_redaction: bool = False
-    lazy_load: bool = True
     kv_cache: bool = True
     gpt_checkpoint: str = None
     clvp_checkpoint: str = None
     decoder_checkpoint: str = None
     num_chars: int = 255
+    use_hifigan: bool = True

     # XTTS GPT Encoder params
     tokenizer_file: str = ""
@@ -297,7 +296,6 @@ class Xtts(BaseTTS):

     def __init__(self, config: Coqpit):
         super().__init__(config, ap=None, tokenizer=None)
-        self.lazy_load = self.args.lazy_load
         self.mel_stats_path = None
         self.config = config
         self.gpt_checkpoint = self.args.gpt_checkpoint
@@ -307,7 +305,6 @@ class Xtts(BaseTTS):
         self.tokenizer = VoiceBpeTokenizer()

         self.gpt = None
-        self.diffusion_decoder = None
         self.init_models()

         self.register_buffer("mel_stats", torch.ones(80))
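The hunks above drop the old `lazy_load` switch and add a single `use_hifigan` flag to `XttsArgs`, so the decoder stack is chosen once at construction time instead of being shuttled on and off the GPU. A minimal sketch of toggling the flag, assuming the usual `XttsConfig` wrapper from `TTS.tts.configs.xtts_config` (the import paths and the `XttsConfig` usage are not part of this diff):

# Hedged sketch: select the decoder via model_args; after this change
# init_models() builds only the branch that was selected.
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts, XttsArgs

config = XttsConfig(model_args=XttsArgs(use_hifigan=True))  # HiFi-GAN decoder (default)
# config = XttsConfig(model_args=XttsArgs(use_hifigan=False))  # diffusion + UnivNet path

model = Xtts(config)  # __init__ calls init_models()
assert hasattr(model, "hifigan_decoder")  # diffusion_decoder/vocoder are never built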
@@ -334,50 +331,38 @@ class Xtts(BaseTTS):
                 stop_audio_token=self.args.gpt_stop_audio_token,
             )

-        self.hifigan_decoder = HifiDecoder(
-            input_sample_rate=self.args.input_sample_rate,
-            output_sample_rate=self.args.output_sample_rate,
-            output_hop_length=self.args.output_hop_length,
-            ar_mel_length_compression=self.args.ar_mel_length_compression,
-            decoder_input_dim=self.args.decoder_input_dim,
-            d_vector_dim=self.args.d_vector_dim,
-            cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
-        )
-        self.diffusion_decoder = DiffusionTts(
-            model_channels=self.args.diff_model_channels,
-            num_layers=self.args.diff_num_layers,
-            in_channels=self.args.diff_in_channels,
-            out_channels=self.args.diff_out_channels,
-            in_latent_channels=self.args.diff_in_latent_channels,
-            in_tokens=self.args.diff_in_tokens,
-            dropout=self.args.diff_dropout,
-            use_fp16=self.args.diff_use_fp16,
-            num_heads=self.args.diff_num_heads,
-            layer_drop=self.args.diff_layer_drop,
-            unconditioned_percentage=self.args.diff_unconditioned_percentage,
-        )
+        if self.args.use_hifigan:
+            self.hifigan_decoder = HifiDecoder(
+                input_sample_rate=self.args.input_sample_rate,
+                output_sample_rate=self.args.output_sample_rate,
+                output_hop_length=self.args.output_hop_length,
+                ar_mel_length_compression=self.args.ar_mel_length_compression,
+                decoder_input_dim=self.args.decoder_input_dim,
+                d_vector_dim=self.args.d_vector_dim,
+                cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
+            )
-        self.vocoder = UnivNetGenerator()
+        else:
+            self.diffusion_decoder = DiffusionTts(
+                model_channels=self.args.diff_model_channels,
+                num_layers=self.args.diff_num_layers,
+                in_channels=self.args.diff_in_channels,
+                out_channels=self.args.diff_out_channels,
+                in_latent_channels=self.args.diff_in_latent_channels,
+                in_tokens=self.args.diff_in_tokens,
+                dropout=self.args.diff_dropout,
+                use_fp16=self.args.diff_use_fp16,
+                num_heads=self.args.diff_num_heads,
+                layer_drop=self.args.diff_layer_drop,
+                unconditioned_percentage=self.args.diff_unconditioned_percentage,
+            )
+            self.vocoder = UnivNetGenerator()

     @property
     def device(self):
         return next(self.parameters()).device

-    @contextmanager
-    def lazy_load_model(self, model):
-        """Context to load a model on demand.
-
-        Args:
-            model (nn.Module): The model to be loaded.
-        """
-        if self.lazy_load:
-            yield model
-        else:
-            m = model.to(self.device)
-            yield m
-            m = model.cpu()
-
     def get_gpt_cond_latents(self, audio_path: str, length: int = 3):
         """Compute the conditioning latents for the GPT model from the given audio.
@@ -411,8 +396,7 @@ class Xtts(BaseTTS):
             )
             diffusion_conds.append(cond_mel)
         diffusion_conds = torch.stack(diffusion_conds, dim=1)
-        with self.lazy_load_model(self.diffusion_decoder) as diffusion:
-            diffusion_latent = diffusion.get_conditioning(diffusion_conds)
+        diffusion_latent = self.diffusion_decoder.get_conditioning(diffusion_conds)
         return diffusion_latent

     def get_speaker_embedding(
@@ -429,11 +413,15 @@ class Xtts(BaseTTS):
         self,
         audio_path,
         gpt_cond_len=3,
-    ):
+    ):
+        speaker_embedding = None
+        diffusion_cond_latents = None
+        if self.args.use_hifigan:
+            speaker_embedding = self.get_speaker_embedding(audio_path)
+        else:
+            diffusion_cond_latents = self.get_diffusion_cond_latents(audio_path)
         gpt_cond_latents = self.get_gpt_cond_latents(audio_path, length=gpt_cond_len)  # [1, 1024, T]
-        diffusion_cond_latents = self.get_diffusion_cond_latents(audio_path)
-        speaker_embedding = self.get_speaker_embedding(audio_path)
-        return gpt_cond_latents.to(self.device), diffusion_cond_latents.to(self.device), speaker_embedding
+        return gpt_cond_latents, diffusion_cond_latents, speaker_embedding

     def synthesize(self, text, config, speaker_wav, language, **kwargs):
         """Synthesize speech with the given input text.
@@ -500,7 +488,6 @@ class Xtts(BaseTTS):
         cond_free_k=2,
         diffusion_temperature=1.0,
         decoder_sampler="ddim",
-        use_hifigan=True,
         **hf_generate_kwargs,
     ):
         """
@@ -579,7 +566,6 @@ class Xtts(BaseTTS):
             cond_free_k=cond_free_k,
             diffusion_temperature=diffusion_temperature,
             decoder_sampler=decoder_sampler,
-            use_hifigan=use_hifigan,
             **hf_generate_kwargs,
         )
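Note the changed contract of `get_conditioning_latents()` above: the latents for the unselected decoder now come back as `None` instead of being computed unconditionally, and the results are no longer moved to `self.device` inside the method. A caller-side sketch of what that implies (`model` and the wav path are placeholders, not part of the diff):

# Hedged sketch of the new return contract: exactly one of the two
# decoder-specific outputs is populated, depending on use_hifigan.
gpt_cond, diffusion_cond, spk_emb = model.get_conditioning_latents("speaker.wav")

if model.args.use_hifigan:
    assert diffusion_cond is None  # only the speaker embedding was computed
else:
    assert spk_emb is None  # only the diffusion conditioning was computed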
@@ -614,7 +600,7 @@ class Xtts(BaseTTS):
             text_tokens.shape[-1] < self.args.gpt_max_text_tokens
         ), " ❗ XTTS can only generate text with a maximum of 400 tokens."

-        if not use_hifigan:
+        if not self.args.use_hifigan:
             diffuser = load_discrete_vocoder_diffuser(
                 desired_diffusion_steps=decoder_iterations,
                 cond_free=cond_free,
@@ -623,60 +609,55 @@ class Xtts(BaseTTS):
             )

         with torch.no_grad():
-            self.gpt = self.gpt.to(self.device)
-            with self.lazy_load_model(self.gpt) as gpt:
-                gpt_codes = gpt.generate(
-                    cond_latents=gpt_cond_latent,
-                    text_inputs=text_tokens,
-                    input_tokens=None,
-                    do_sample=do_sample,
-                    top_p=top_p,
-                    top_k=top_k,
-                    temperature=temperature,
-                    num_return_sequences=self.gpt_batch_size,
-                    length_penalty=length_penalty,
-                    repetition_penalty=repetition_penalty,
-                    output_attentions=False,
-                    **hf_generate_kwargs,
-                )
-                expected_output_len = torch.tensor(
-                    [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
-                )
-                text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
-                gpt_latents = gpt(
-                    text_tokens,
-                    text_len,
-                    gpt_codes,
-                    expected_output_len,
-                    cond_latents=gpt_cond_latent,
-                    return_attentions=False,
-                    return_latent=True,
-                )
-                silence_token = 83
-                ctokens = 0
-                for k in range(gpt_codes.shape[-1]):
-                    if gpt_codes[0, k] == silence_token:
-                        ctokens += 1
-                    else:
-                        ctokens = 0
-                    if ctokens > 8:
-                        gpt_latents = gpt_latents[:, :k]
-                        break
+            gpt_codes = self.gpt.generate(
+                cond_latents=gpt_cond_latent,
+                text_inputs=text_tokens,
+                input_tokens=None,
+                do_sample=do_sample,
+                top_p=top_p,
+                top_k=top_k,
+                temperature=temperature,
+                num_return_sequences=self.gpt_batch_size,
+                length_penalty=length_penalty,
+                repetition_penalty=repetition_penalty,
+                output_attentions=False,
+                **hf_generate_kwargs,
+            )
+            expected_output_len = torch.tensor(
+                [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
+            )
+            text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
+            gpt_latents = self.gpt(
+                text_tokens,
+                text_len,
+                gpt_codes,
+                expected_output_len,
+                cond_latents=gpt_cond_latent,
+                return_attentions=False,
+                return_latent=True,
+            )
+            silence_token = 83
+            ctokens = 0
+            for k in range(gpt_codes.shape[-1]):
+                if gpt_codes[0, k] == silence_token:
+                    ctokens += 1
+                else:
+                    ctokens = 0
+                if ctokens > 8:
+                    gpt_latents = gpt_latents[:, :k]
+                    break

-            if use_hifigan:
-                with self.lazy_load_model(self.hifigan_decoder) as hifigan_decoder:
-                    wav = hifigan_decoder(gpt_latents, g=speaker_embedding)
+            if self.args.use_hifigan:
+                wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
             else:
-                with self.lazy_load_model(self.diffusion_decoder) as diffusion:
-                    mel = do_spectrogram_diffusion(
-                        diffusion,
-                        diffuser,
-                        gpt_latents,
-                        diffusion_conditioning,
-                        temperature=diffusion_temperature,
-                    )
-                with self.lazy_load_model(self.vocoder) as vocoder:
-                    wav = vocoder.inference(mel)
+                mel = do_spectrogram_diffusion(
+                    self.diffusion_decoder,
+                    diffuser,
+                    gpt_latents,
+                    diffusion_conditioning,
+                    temperature=diffusion_temperature,
+                )
+                wav = self.vocoder.inference(mel)

         return {"wav": wav.cpu().numpy().squeeze()}
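With `lazy_load_model` gone, the whole module is expected to live on a single device, and the decoder branch is read from `self.args.use_hifigan` rather than a per-call `use_hifigan` kwarg. A hedged sketch of the updated call into `inference()`, reusing the conditioning variables from the sketch above (parameter names follow the surrounding diff; the text and defaults are placeholders, not verified against the full signature):

# Hedged sketch: use_hifigan is no longer passed at call time; the
# decoder path is fixed by config.model_args when the model is built.
out = model.inference(
    text="Hello world.",
    language="en",
    gpt_cond_latent=gpt_cond,
    speaker_embedding=spk_emb,  # consumed by the HiFi-GAN branch
    diffusion_conditioning=diffusion_cond,  # consumed by the diffusion branch
)
wav = out["wav"]  # 1-D numpy array, per the return statement above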
text = f"[{language}]{text.strip().lower()}" text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device) @@ -781,7 +763,7 @@ class Xtts(BaseTTS): vocab_path=None, eval=False, strict=True, - use_deepspeed=False + use_deepspeed=False, ): """ Loads a checkpoint from disk and initializes the model's state and tokenizer. @@ -807,14 +789,20 @@ class Xtts(BaseTTS): self.init_models() if eval: self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache) - self.load_state_dict(load_fsspec(model_path, map_location=self.device)["model"], strict=strict) + + checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"] + ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan else ["hifigan_decoder"] + for key in list(checkpoint.keys()): + if key.split(".")[0] in ignore_keys: + del checkpoint[key] + self.load_state_dict(checkpoint, strict=strict) if eval: + if hasattr(self, "hifigan_decoder"): self.hifigan_decoder.eval() + if hasattr(self, "diffusion_decoder"): self.diffusion_decoder.eval() + if hasattr(self, "vocoder"): self.vocoder.eval() self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed) self.gpt.eval() - self.diffusion_decoder.eval() - self.vocoder.eval() - self.hifigan_decoder.eval() def train_step(self): raise NotImplementedError("XTTS Training is not implemented")