diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py index aabfa213..683104d8 100644 --- a/TTS/tts/layers/xtts/gpt.py +++ b/TTS/tts/layers/xtts/gpt.py @@ -562,7 +562,7 @@ class GPT(nn.Module): def inference(self, cond_latents, text_inputs, **hf_generate_kwargs): self.compute_embeddings(cond_latents, text_inputs) - return self.generate(cond_latents, text_inputs, input_tokens=None, **hf_generate_kwargs) + return self.generate(cond_latents, text_inputs, **hf_generate_kwargs) def compute_embeddings( self, diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index cb0aff75..58f8542b 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -642,6 +642,7 @@ class Xtts(BaseTTS): diffusion_temperature=1.0, decoder_sampler="ddim", decoder="hifigan", + num_beams=1, **hf_generate_kwargs, ): text = text.strip().lower() @@ -673,6 +674,7 @@ class Xtts(BaseTTS): top_k=top_k, temperature=temperature, num_return_sequences=self.gpt_batch_size, + num_beams=num_beams, length_penalty=length_penalty, repetition_penalty=repetition_penalty, output_attentions=False,