diff --git a/README.md b/README.md
index ece2a20..a44887b 100644
--- a/README.md
+++ b/README.md
@@ -21,14 +21,20 @@ Bark is a transformer-based text-to-audio model created by [Suno](https://suno.a
 ## 🤖 Usage
 
 ```python
-from bark import SAMPLE_RATE, generate_audio
+from bark import SAMPLE_RATE, generate_audio, preload_models
 from IPython.display import Audio
 
+# download and load all models
+preload_models()
+
+# generate audio from text
 text_prompt = """
      Hello, my name is Suno. And, uh — and I like pizza. [laughs]
      But I also have other interests such as playing tic tac toe.
 """
 audio_array = generate_audio(text_prompt)
+
+# play audio in notebook
 Audio(audio_array, rate=SAMPLE_RATE)
 ```
 
diff --git a/bark/generation.py b/bark/generation.py
index 4e860d8..4aa805e 100644
--- a/bark/generation.py
+++ b/bark/generation.py
@@ -1,4 +1,5 @@
 import contextlib
+import gc
 import hashlib
 import os
 import re
@@ -84,36 +85,33 @@ CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno",
 USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)
 
 REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"
-if USE_SMALL_MODELS:
-    REMOTE_MODEL_PATHS = {
-        "text": {
-            "path": os.path.join(REMOTE_BASE_URL, "text.pt"),
-            "checksum": "b3e42bcbab23b688355cd44128c4cdd3",
-        },
-        "coarse": {
-            "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"),
-            "checksum": "5fe964825e3b0321f9d5f3857b89194d",
-        },
-        "fine": {
-            "path": os.path.join(REMOTE_BASE_URL, "fine.pt"),
-            "checksum": "5428d1befe05be2ba32195496e58dc90",
-        },
-    }
-else:
-    REMOTE_MODEL_PATHS = {
-        "text": {
-            "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"),
-            "checksum": "54afa89d65e318d4f5f80e8e8799026a",
-        },
-        "coarse": {
-            "path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"),
-            "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
-        },
-        "fine": {
-            "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"),
-            "checksum": "59d184ed44e3650774a2f0503a48a97b",
-        },
-    }
+
+REMOTE_MODEL_PATHS = {
+    "text_small": {
+        "path": os.path.join(REMOTE_BASE_URL, "text.pt"),
+        "checksum": "b3e42bcbab23b688355cd44128c4cdd3",
+    },
+    "coarse_small": {
+        "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"),
+        "checksum": "5fe964825e3b0321f9d5f3857b89194d",
+    },
+    "fine_small": {
+        "path": os.path.join(REMOTE_BASE_URL, "fine.pt"),
+        "checksum": "5428d1befe05be2ba32195496e58dc90",
+    },
+    "text": {
+        "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"),
+        "checksum": "54afa89d65e318d4f5f80e8e8799026a",
+    },
+    "coarse": {
+        "path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"),
+        "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
+    },
+    "fine": {
+        "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"),
+        "checksum": "59d184ed44e3650774a2f0503a48a97b",
+    },
+}
 
 
 if not hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
@@ -137,8 +135,9 @@ def _md5(fname):
     return hash_md5.hexdigest()
 
 
-def _get_ckpt_path(model_type):
-    model_name = _string_md5(REMOTE_MODEL_PATHS[model_type]["path"])
+def _get_ckpt_path(model_type, use_small=False):
+    model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
+    model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["path"])
     return os.path.join(CACHE_DIR, f"{model_name}.pt")
 
 
@@ -204,9 +203,10 @@ def clean_models(model_key=None):
         if k in models:
             del models[k]
     _clear_cuda_cache()
+    gc.collect()
 
 
-def _load_model(ckpt_path, device, model_type="text"):
+def _load_model(ckpt_path, device, use_small=False, model_type="text"):
     if "cuda" not in device:
         logger.warning("No GPU being used. Careful, inference might be extremely slow!")
     if model_type == "text":
@@ -220,15 +220,17 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
         ModelClass = FineGPT
     else:
         raise NotImplementedError()
+    model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
+    model_info = REMOTE_MODEL_PATHS[model_key]
     if (
         os.path.exists(ckpt_path) and
-        _md5(ckpt_path) != REMOTE_MODEL_PATHS[model_type]["checksum"]
+        _md5(ckpt_path) != model_info["checksum"]
     ):
         logger.warning(f"found outdated {model_type} model, removing.")
         os.remove(ckpt_path)
     if not os.path.exists(ckpt_path):
         logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
-        _download(REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path)
+        _download(model_info["path"], ckpt_path)
     checkpoint = torch.load(ckpt_path, map_location=device)
     # this is a hack
     model_args = checkpoint["model_args"]
@@ -278,8 +280,8 @@ def _load_codec_model(device):
     return model
 
 
-def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="text"):
-    _load_model_f = funcy.partial(_load_model, model_type=model_type)
+def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="text"):
+    _load_model_f = funcy.partial(_load_model, model_type=model_type, use_small=use_small)
     if model_type not in ("text", "coarse", "fine"):
         raise NotImplementedError()
     global models
@@ -289,8 +291,7 @@ def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="tex
         device = "cuda"
     model_key = str(device) + f"__{model_type}"
     if model_key not in models or force_reload:
-        if ckpt_path is None:
-            ckpt_path = _get_ckpt_path(model_type)
+        ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
         clean_models(model_key=model_key)
         model = _load_model_f(ckpt_path, device)
         models[model_key] = model
@@ -311,17 +312,29 @@ def load_codec_model(use_gpu=True, force_reload=False):
     return models[model_key]
 
 
-def preload_models(text_ckpt_path=None, coarse_ckpt_path=None, fine_ckpt_path=None, use_gpu=True):
+def preload_models(
+    text_use_gpu=True,
+    text_use_small=False,
+    coarse_use_gpu=True,
+    coarse_use_small=False,
+    fine_use_gpu=True,
+    fine_use_small=False,
+    codec_use_gpu=True,
+    force_reload=False,
+):
     _ = load_model(
-        ckpt_path=text_ckpt_path, model_type="text", use_gpu=use_gpu, force_reload=True
+        model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload
     )
     _ = load_model(
-        ckpt_path=coarse_ckpt_path, model_type="coarse", use_gpu=use_gpu, force_reload=True
+        model_type="coarse",
+        use_gpu=coarse_use_gpu,
+        use_small=coarse_use_small,
+        force_reload=force_reload,
     )
     _ = load_model(
-        ckpt_path=fine_ckpt_path, model_type="fine", use_gpu=use_gpu, force_reload=True
+        model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload
    )
-    _ = load_codec_model(use_gpu=use_gpu, force_reload=True)
+    _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
 
 
 ####
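
A quick usage sketch for reviewers, not part of the patch: it exercises the reworked `preload_models` keyword arguments from the diff above. All parameter names come from the new signature; the one-line prompt is a made-up example.

```python
from bark import SAMPLE_RATE, generate_audio, preload_models

# Per-model knobs replace the old ckpt_path arguments: text, coarse, and
# fine can each independently target the GPU and/or the small checkpoints;
# the codec model only has a GPU toggle.
preload_models(
    text_use_gpu=True,
    text_use_small=True,    # resolves "text_small" -> text.pt
    coarse_use_gpu=True,
    coarse_use_small=True,  # resolves "coarse_small" -> coarse.pt
    fine_use_gpu=True,
    fine_use_small=True,    # resolves "fine_small" -> fine.pt
    codec_use_gpu=True,
    force_reload=False,     # cached entries in `models` are reused
)

audio_array = generate_audio("Hello, my name is Suno.")  # illustrative prompt
```

Worth noting: the old `preload_models` passed `force_reload=True` to every `load_model` call, so each invocation reloaded everything. With the new `force_reload=False` default, repeated calls are cheap no-ops once the models are cached, which is what lets the README place an unconditional `preload_models()` at the top of its snippet.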
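The `SUNO_USE_SMALL_MODELS` environment variable remains a second way to opt into the small checkpoints. A sketch of that path, under the assumption, visible in the context lines above, that `USE_SMALL_MODELS` is evaluated once when `bark.generation` is imported:

```python
import os

# Must be set before importing bark, because generation.py evaluates
# USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False) at import.
# The value comes back as a string, so any non-empty value is truthy.
os.environ["SUNO_USE_SMALL_MODELS"] = "1"

from bark import preload_models

# use_small defaults to False, but _get_ckpt_path and _load_model check
# `use_small or USE_SMALL_MODELS`, so the small checkpoints win here too.
preload_models()
```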