make kv caching default in inference

pull/62/head
Georg Kucsko 2023-04-22 15:42:30 -04:00
parent 3247106492
commit 009ff7cb62
1 changed files with 6 additions and 7 deletions

View File

@ -10,7 +10,6 @@ def text_to_semantic(
history_prompt: Optional[str] = None,
temp: float = 0.7,
silent: bool = False,
use_kv_caching = False,
):
"""Generate semantic array from text.
@ -28,7 +27,7 @@ def text_to_semantic(
history_prompt=history_prompt,
temp=temp,
silent=silent,
use_kv_caching=use_kv_caching
use_kv_caching=True
)
return x_semantic
@ -39,7 +38,6 @@ def semantic_to_waveform(
temp: float = 0.7,
silent: bool = False,
output_full: bool = False,
use_kv_caching = False
):
"""Generate audio array from semantic input.
@ -58,7 +56,7 @@ def semantic_to_waveform(
history_prompt=history_prompt,
temp=temp,
silent=silent,
use_kv_caching=use_kv_caching
use_kv_caching=True
)
fine_tokens = generate_fine(
coarse_tokens,
@ -92,7 +90,6 @@ def generate_audio(
waveform_temp: float = 0.7,
silent: bool = False,
output_full: bool = False,
use_kv_caching = False
):
"""Generate audio array from input text.
@ -108,7 +105,10 @@ def generate_audio(
numpy audio array at sample frequency 24khz
"""
semantic_tokens = text_to_semantic(
text, history_prompt=history_prompt, temp=text_temp, silent=silent, use_kv_caching=use_kv_caching
text,
history_prompt=history_prompt,
temp=text_temp,
silent=silent,
)
out = semantic_to_waveform(
semantic_tokens,
@ -116,7 +116,6 @@ def generate_audio(
temp=waveform_temp,
silent=silent,
output_full=output_full,
use_kv_caching=use_kv_caching
)
if output_full:
full_generation, audio_arr = out