diff --git a/bark/api.py b/bark/api.py
index 300459c..8033dc6 100644
--- a/bark/api.py
+++ b/bark/api.py
@@ -36,6 +36,7 @@ def semantic_to_waveform(
     history_prompt: Optional[str] = None,
     temp: float = 0.7,
     silent: bool = False,
+    output_full: bool = False,
 ):
     """Generate audio array from semantic input.
 
@@ -44,31 +45,49 @@ def semantic_to_waveform(
         history_prompt: history choice for audio cloning
         temp: generation temperature (1.0 more diverse, 0.0 more conservative)
         silent: disable progress bar
+        output_full: return full generation to be used as a history prompt
 
     Returns:
         numpy audio array at sample frequency 24khz
     """
-    x_coarse_gen = generate_coarse(
+    coarse_tokens = generate_coarse(
         semantic_tokens,
         history_prompt=history_prompt,
         temp=temp,
         silent=silent,
     )
-    x_fine_gen = generate_fine(
-        x_coarse_gen,
+    fine_tokens = generate_fine(
+        coarse_tokens,
         history_prompt=history_prompt,
         temp=0.5,
     )
-    audio_arr = codec_decode(x_fine_gen)
+    audio_arr = codec_decode(fine_tokens)
+    if output_full:
+        full_generation = {
+            "semantic_prompt": semantic_tokens,
+            "coarse_prompt": coarse_tokens,
+            "fine_prompt": fine_tokens,
+        }
+        return full_generation, audio_arr
     return audio_arr
 
 
+def save_as_prompt(filepath, full_generation):
+    assert(filepath.endswith(".npz"))
+    assert(isinstance(full_generation, dict))
+    assert("semantic_prompt" in full_generation)
+    assert("coarse_prompt" in full_generation)
+    assert("fine_prompt" in full_generation)
+    np.savez(filepath, **full_generation)
+
+
 def generate_audio(
     text: str,
     history_prompt: Optional[str] = None,
     text_temp: float = 0.7,
     waveform_temp: float = 0.7,
     silent: bool = False,
+    output_full: bool = False,
 ):
     """Generate audio array from input text.
 
@@ -78,14 +97,24 @@ def generate_audio(
         text_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
         waveform_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
         silent: disable progress bar
+        output_full: return full generation to be used as a history prompt
 
     Returns:
         numpy audio array at sample frequency 24khz
     """
-    x_semantic = text_to_semantic(
+    semantic_tokens = text_to_semantic(
         text, history_prompt=history_prompt, temp=text_temp, silent=silent,
     )
-    audio_arr = semantic_to_waveform(
-        x_semantic, history_prompt=history_prompt, temp=waveform_temp, silent=silent,
+    out = semantic_to_waveform(
+        semantic_tokens,
+        history_prompt=history_prompt,
+        temp=waveform_temp,
+        silent=silent,
+        output_full=output_full,
     )
+    if output_full:
+        full_generation, audio_arr = out
+        return full_generation, audio_arr
+    else:
+        audio_arr = out
     return audio_arr
diff --git a/bark/generation.py b/bark/generation.py
index fa54388..b5476bc 100644
--- a/bark/generation.py
+++ b/bark/generation.py
@@ -365,10 +365,13 @@ def generate_text_semantic(
     text = _normalize_whitespace(text)
     assert len(text.strip()) > 0
     if history_prompt is not None:
-        assert (history_prompt in ALLOWED_PROMPTS)
-        semantic_history = np.load(
-            os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
-        )["semantic_prompt"]
+        if history_prompt.endswith(".npz"):
+            semantic_history = np.load(history_prompt)["semantic_prompt"]
+        else:
+            assert (history_prompt in ALLOWED_PROMPTS)
+            semantic_history = np.load(
+                os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
+            )["semantic_prompt"]
         assert (
             isinstance(semantic_history, np.ndarray)
             and len(semantic_history.shape) == 1
@@ -509,10 +512,13 @@ def generate_coarse(
     semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS
     max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
     if history_prompt is not None:
-        assert (history_prompt in ALLOWED_PROMPTS)
-        x_history = np.load(
-            os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
-        )
+        if history_prompt.endswith(".npz"):
+            x_history = np.load(history_prompt)
+        else:
+            assert (history_prompt in ALLOWED_PROMPTS)
+            x_history = np.load(
+                os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
+            )
         x_semantic_history = x_history["semantic_prompt"]
         x_coarse_history = x_history["coarse_prompt"]
         assert (
@@ -652,10 +658,13 @@ def generate_fine(
         and x_coarse_gen.max() <= CODEBOOK_SIZE - 1
     )
     if history_prompt is not None:
-        assert (history_prompt in ALLOWED_PROMPTS)
-        x_fine_history = np.load(
-            os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
-        )["fine_prompt"]
+        if history_prompt.endswith(".npz"):
+            x_fine_history = np.load(history_prompt)["fine_prompt"]
+        else:
+            assert (history_prompt in ALLOWED_PROMPTS)
+            x_fine_history = np.load(
+                os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
+            )["fine_prompt"]
         assert (
             isinstance(x_fine_history, np.ndarray)
             and len(x_fine_history.shape) == 2