diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py
index 683104d8..612da260 100644
--- a/TTS/tts/layers/xtts/gpt.py
+++ b/TTS/tts/layers/xtts/gpt.py
@@ -128,6 +128,7 @@ class GPT(nn.Module):
         self.heads = heads
         self.model_dim = model_dim
         self.max_conditioning_inputs = max_conditioning_inputs
+        self.max_gen_mel_tokens = max_mel_tokens - self.max_conditioning_inputs - 2
         self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens + 2 + self.max_conditioning_inputs
         self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2
         self.max_prompt_tokens = max_prompt_tokens
@@ -598,7 +599,7 @@ class GPT(nn.Module):
             bos_token_id=self.start_audio_token,
             pad_token_id=self.stop_audio_token,
             eos_token_id=self.stop_audio_token,
-            max_length=self.max_mel_tokens,
+            max_length=self.max_gen_mel_tokens + gpt_inputs.shape[-1],
             **hf_generate_kwargs,
         )
         if "return_dict_in_generate" in hf_generate_kwargs:
@@ -611,7 +612,7 @@ class GPT(nn.Module):
             bos_token_id=self.start_audio_token,
             pad_token_id=self.stop_audio_token,
             eos_token_id=self.stop_audio_token,
-            max_length=self.max_mel_tokens,
+            max_length=self.max_gen_mel_tokens + fake_inputs.shape[-1],
             do_stream=True,
             **hf_generate_kwargs,
         )