From 9e92adc5ac532a444ebd7d08a3e2581b0a07173e Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 6 Nov 2023 10:53:25 -0300 Subject: [PATCH] Remove unused kwarg and add num_beams=1 as default --- TTS/tts/layers/xtts/gpt.py | 2 +- TTS/tts/models/xtts.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py index aabfa213..683104d8 100644 --- a/TTS/tts/layers/xtts/gpt.py +++ b/TTS/tts/layers/xtts/gpt.py @@ -562,7 +562,7 @@ class GPT(nn.Module): def inference(self, cond_latents, text_inputs, **hf_generate_kwargs): self.compute_embeddings(cond_latents, text_inputs) - return self.generate(cond_latents, text_inputs, input_tokens=None, **hf_generate_kwargs) + return self.generate(cond_latents, text_inputs, **hf_generate_kwargs) def compute_embeddings( self, diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index cb0aff75..58f8542b 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -642,6 +642,7 @@ class Xtts(BaseTTS): diffusion_temperature=1.0, decoder_sampler="ddim", decoder="hifigan", + num_beams=1, **hf_generate_kwargs, ): text = text.strip().lower() @@ -673,6 +674,7 @@ class Xtts(BaseTTS): top_k=top_k, temperature=temperature, num_return_sequences=self.gpt_batch_size, + num_beams=num_beams, length_penalty=length_penalty, repetition_penalty=repetition_penalty, output_attentions=False,