fix max generation length

pull/3208/head
WeberJulian 2023-11-13 13:18:45 +01:00
parent 6f1cba2f81
commit b85536b23f
1 changed file with 3 additions and 2 deletions

View File

@ -128,6 +128,7 @@ class GPT(nn.Module):
self.heads = heads
self.model_dim = model_dim
self.max_conditioning_inputs = max_conditioning_inputs
self.max_gen_mel_tokens = max_mel_tokens - self.max_conditioning_inputs - 2
self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens + 2 + self.max_conditioning_inputs
self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2
self.max_prompt_tokens = max_prompt_tokens
@ -598,7 +599,7 @@ class GPT(nn.Module):
bos_token_id=self.start_audio_token,
pad_token_id=self.stop_audio_token,
eos_token_id=self.stop_audio_token,
max_length=self.max_mel_tokens,
max_length=self.max_gen_mel_tokens + gpt_inputs.shape[-1],
**hf_generate_kwargs,
)
if "return_dict_in_generate" in hf_generate_kwargs:
@ -611,7 +612,7 @@ class GPT(nn.Module):
bos_token_id=self.start_audio_token,
pad_token_id=self.stop_audio_token,
eos_token_id=self.stop_audio_token,
max_length=self.max_mel_tokens,
max_length=self.max_gen_mel_tokens + fake_inputs.shape[-1],
do_stream=True,
**hf_generate_kwargs,
)