From b85536b23f2c70488736044a2fea3ec9fd59cff4 Mon Sep 17 00:00:00 2001
From: WeberJulian
Date: Mon, 13 Nov 2023 13:18:45 +0100
Subject: [PATCH] fix max generation length

---
 TTS/tts/layers/xtts/gpt.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py
index 683104d8..612da260 100644
--- a/TTS/tts/layers/xtts/gpt.py
+++ b/TTS/tts/layers/xtts/gpt.py
@@ -128,6 +128,7 @@ class GPT(nn.Module):
         self.heads = heads
         self.model_dim = model_dim
         self.max_conditioning_inputs = max_conditioning_inputs
+        self.max_gen_mel_tokens = max_mel_tokens - self.max_conditioning_inputs - 2
         self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens + 2 + self.max_conditioning_inputs
         self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2
         self.max_prompt_tokens = max_prompt_tokens
@@ -598,7 +599,7 @@ class GPT(nn.Module):
             bos_token_id=self.start_audio_token,
             pad_token_id=self.stop_audio_token,
             eos_token_id=self.stop_audio_token,
-            max_length=self.max_mel_tokens,
+            max_length=self.max_gen_mel_tokens + gpt_inputs.shape[-1],
             **hf_generate_kwargs,
         )
         if "return_dict_in_generate" in hf_generate_kwargs:
@@ -611,7 +612,7 @@ class GPT(nn.Module):
             bos_token_id=self.start_audio_token,
             pad_token_id=self.stop_audio_token,
             eos_token_id=self.stop_audio_token,
-            max_length=self.max_mel_tokens,
+            max_length=self.max_gen_mel_tokens + fake_inputs.shape[-1],
             do_stream=True,
             **hf_generate_kwargs,
         )