mirror of https://github.com/coqui-ai/TTS.git
fix max generation length
parent
6f1cba2f81
commit
b85536b23f
|
@ -128,6 +128,7 @@ class GPT(nn.Module):
|
|||
self.heads = heads
|
||||
self.model_dim = model_dim
|
||||
self.max_conditioning_inputs = max_conditioning_inputs
|
||||
self.max_gen_mel_tokens = max_mel_tokens - self.max_conditioning_inputs - 2
|
||||
self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens + 2 + self.max_conditioning_inputs
|
||||
self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2
|
||||
self.max_prompt_tokens = max_prompt_tokens
|
||||
|
@ -598,7 +599,7 @@ class GPT(nn.Module):
|
|||
bos_token_id=self.start_audio_token,
|
||||
pad_token_id=self.stop_audio_token,
|
||||
eos_token_id=self.stop_audio_token,
|
||||
max_length=self.max_mel_tokens,
|
||||
max_length=self.max_gen_mel_tokens + gpt_inputs.shape[-1],
|
||||
**hf_generate_kwargs,
|
||||
)
|
||||
if "return_dict_in_generate" in hf_generate_kwargs:
|
||||
|
@ -611,7 +612,7 @@ class GPT(nn.Module):
|
|||
bos_token_id=self.start_audio_token,
|
||||
pad_token_id=self.stop_audio_token,
|
||||
eos_token_id=self.stop_audio_token,
|
||||
max_length=self.max_mel_tokens,
|
||||
max_length=self.max_gen_mel_tokens + fake_inputs.shape[-1],
|
||||
do_stream=True,
|
||||
**hf_generate_kwargs,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue