mirror of https://github.com/suno-ai/bark.git
simplify
parent 009ff7cb62
commit 8313b570f4
@@ -21,14 +21,20 @@ Bark is a transformer-based text-to-audio model created by [Suno](https://suno.a
 ## 🤖 Usage
 
 ```python
-from bark import SAMPLE_RATE, generate_audio
+from bark import SAMPLE_RATE, generate_audio, preload_models
 from IPython.display import Audio
 
+# download and load all models
+preload_models()
+
 # generate audio from text
 text_prompt = """
 Hello, my name is Suno. And, uh — and I like pizza. [laughs]
 But I also have other interests such as playing tic tac toe.
 """
 audio_array = generate_audio(text_prompt)
 
 # play text in notebook
 Audio(audio_array, rate=SAMPLE_RATE)
 ```
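Outside a notebook, the same array can be written to disk instead of played inline. A minimal sketch, assuming `scipy` is installed (it is not used by the snippet above; any WAV writer works):

```python
from bark import SAMPLE_RATE, generate_audio, preload_models
from scipy.io import wavfile  # assumed available; not a Bark requirement

preload_models()

# generate_audio returns a waveform as a numpy array sampled at SAMPLE_RATE
audio_array = generate_audio("Hello, my name is Suno.")

# write to a WAV file instead of IPython playback
wavfile.write("bark_output.wav", SAMPLE_RATE, audio_array)
```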
@@ -1,4 +1,5 @@
 import contextlib
+import gc
 import hashlib
 import os
 import re
@@ -84,36 +85,33 @@ CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno",
 USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)
 
 REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"
-if USE_SMALL_MODELS:
-    REMOTE_MODEL_PATHS = {
-        "text": {
-            "path": os.path.join(REMOTE_BASE_URL, "text.pt"),
-            "checksum": "b3e42bcbab23b688355cd44128c4cdd3",
-        },
-        "coarse": {
-            "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"),
-            "checksum": "5fe964825e3b0321f9d5f3857b89194d",
-        },
-        "fine": {
-            "path": os.path.join(REMOTE_BASE_URL, "fine.pt"),
-            "checksum": "5428d1befe05be2ba32195496e58dc90",
-        },
-    }
-else:
-    REMOTE_MODEL_PATHS = {
-        "text": {
-            "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"),
-            "checksum": "54afa89d65e318d4f5f80e8e8799026a",
-        },
-        "coarse": {
-            "path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"),
-            "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
-        },
-        "fine": {
-            "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"),
-            "checksum": "59d184ed44e3650774a2f0503a48a97b",
-        },
-    }
+REMOTE_MODEL_PATHS = {
+    "text_small": {
+        "path": os.path.join(REMOTE_BASE_URL, "text.pt"),
+        "checksum": "b3e42bcbab23b688355cd44128c4cdd3",
+    },
+    "coarse_small": {
+        "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"),
+        "checksum": "5fe964825e3b0321f9d5f3857b89194d",
+    },
+    "fine_small": {
+        "path": os.path.join(REMOTE_BASE_URL, "fine.pt"),
+        "checksum": "5428d1befe05be2ba32195496e58dc90",
+    },
+    "text": {
+        "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"),
+        "checksum": "54afa89d65e318d4f5f80e8e8799026a",
+    },
+    "coarse": {
+        "path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"),
+        "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
+    },
+    "fine": {
+        "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"),
+        "checksum": "59d184ed44e3650774a2f0503a48a97b",
+    },
+}
 
 
 if not hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
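The net effect: both checkpoint sizes now live in one flat dictionary, and the `_small` suffix is resolved at call time instead of branching on `USE_SMALL_MODELS` at import time. A minimal sketch of the lookup pattern the rest of the diff uses (`_resolve_model_key` is an illustrative helper, not a function in the repo):

```python
# illustrative stand-in for the unified REMOTE_MODEL_PATHS above
REMOTE_MODEL_PATHS = {
    "text": {"path": "text_2.pt"},
    "text_small": {"path": "text.pt"},
}

def _resolve_model_key(model_type, use_small=False):
    # same expression the diff uses inside _get_ckpt_path and _load_model
    return f"{model_type}_small" if use_small else model_type

print(REMOTE_MODEL_PATHS[_resolve_model_key("text", use_small=True)]["path"])  # text.pt
```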
@@ -137,8 +135,9 @@ def _md5(fname):
     return hash_md5.hexdigest()
 
 
-def _get_ckpt_path(model_type):
-    model_name = _string_md5(REMOTE_MODEL_PATHS[model_type]["path"])
+def _get_ckpt_path(model_type, use_small=False):
+    model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
+    model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["path"])
     return os.path.join(CACHE_DIR, f"{model_name}.pt")
 
 
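Because the cache filename is the md5 of the remote URL, the small (`text.pt`) and full (`text_2.pt`) checkpoints map to distinct files under `CACHE_DIR`, so toggling `use_small` never clobbers the other variant. A sketch, assuming `_string_md5` simply hashes the UTF-8 bytes of the string (its body is not shown in this diff):

```python
import hashlib
import os

CACHE_DIR = "/tmp/bark_cache"  # illustrative; the real path derives from XDG_CACHE_HOME

def _string_md5(text):
    # assumed implementation: md5 of the UTF-8 encoded string
    return hashlib.md5(text.encode("utf-8")).hexdigest()

base = "https://dl.suno-models.io/bark/models/v0/"
for fname in ("text.pt", "text_2.pt"):
    url = os.path.join(base, fname)
    print(os.path.join(CACHE_DIR, f"{_string_md5(url)}.pt"))  # two distinct cache entries
```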
@@ -204,9 +203,10 @@ def clean_models(model_key=None):
         if k in models:
             del models[k]
     _clear_cuda_cache()
+    gc.collect()
 
 
-def _load_model(ckpt_path, device, model_type="text"):
+def _load_model(ckpt_path, device, use_small=False, model_type="text"):
     if "cuda" not in device:
         logger.warning("No GPU being used. Careful, inference might be extremely slow!")
     if model_type == "text":
@@ -220,15 +220,17 @@ def _load_model(ckpt_path, device, model_type="text"):
         ModelClass = FineGPT
     else:
         raise NotImplementedError()
+    model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
+    model_info = REMOTE_MODEL_PATHS[model_key]
     if (
         os.path.exists(ckpt_path) and
-        _md5(ckpt_path) != REMOTE_MODEL_PATHS[model_type]["checksum"]
+        _md5(ckpt_path) != model_info["checksum"]
     ):
         logger.warning(f"found outdated {model_type} model, removing.")
         os.remove(ckpt_path)
     if not os.path.exists(ckpt_path):
         logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
-        _download(REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path)
+        _download(model_info["path"], ckpt_path)
     checkpoint = torch.load(ckpt_path, map_location=device)
     # this is a hack
     model_args = checkpoint["model_args"]
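The checksum guard above depends on `_md5`, whose tail (`return hash_md5.hexdigest()`) appears at the top of the previous hunk. A sketch of the presumably chunked implementation (the chunk size is an assumption):

```python
import hashlib

def _md5(fname):
    # hash the checkpoint in chunks so multi-GB files never sit in memory at once
    hash_md5 = hashlib.md5()
    with open(fname, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):  # 4 KiB chunks; size is an assumption
            hash_md5.update(chunk)
    return hash_md5.hexdigest()
```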
@@ -278,8 +280,8 @@ def _load_codec_model(device):
     return model
 
 
-def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="text"):
-    _load_model_f = funcy.partial(_load_model, model_type=model_type)
+def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="text"):
+    _load_model_f = funcy.partial(_load_model, model_type=model_type, use_small=use_small)
    if model_type not in ("text", "coarse", "fine"):
         raise NotImplementedError()
     global models
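For context, `funcy.partial` pre-binds `model_type` and `use_small` so the caching code below only has to supply `ckpt_path` and `device`. A minimal sketch of the same pattern using the standard library as a stand-in:

```python
from functools import partial  # the diff uses funcy.partial; functools behaves the same here

def _load_model(ckpt_path, device, use_small=False, model_type="text"):
    # stand-in for the real loader; just report what was bound
    return (ckpt_path, device, use_small, model_type)

# bind the keyword arguments up front, exactly as load_model does
_load_model_f = partial(_load_model, model_type="coarse", use_small=True)

print(_load_model_f("coarse.pt", "cuda"))  # ('coarse.pt', 'cuda', True, 'coarse')
```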
@@ -289,8 +291,7 @@ def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="tex
         device = "cuda"
     model_key = str(device) + f"__{model_type}"
     if model_key not in models or force_reload:
-        if ckpt_path is None:
-            ckpt_path = _get_ckpt_path(model_type)
+        ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
         clean_models(model_key=model_key)
         model = _load_model_f(ckpt_path, device)
         models[model_key] = model
@@ -311,17 +312,29 @@ def load_codec_model(use_gpu=True, force_reload=False):
     return models[model_key]
 
 
-def preload_models(text_ckpt_path=None, coarse_ckpt_path=None, fine_ckpt_path=None, use_gpu=True):
+def preload_models(
+    text_use_gpu=True,
+    text_use_small=False,
+    coarse_use_gpu=True,
+    coarse_use_small=False,
+    fine_use_gpu=True,
+    fine_use_small=False,
+    codec_use_gpu=True,
+    force_reload=False,
+):
     _ = load_model(
-        ckpt_path=text_ckpt_path, model_type="text", use_gpu=use_gpu, force_reload=True
+        model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload
     )
     _ = load_model(
-        ckpt_path=coarse_ckpt_path, model_type="coarse", use_gpu=use_gpu, force_reload=True
+        model_type="coarse",
+        use_gpu=coarse_use_gpu,
+        use_small=coarse_use_small,
+        force_reload=force_reload,
     )
     _ = load_model(
-        ckpt_path=fine_ckpt_path, model_type="fine", use_gpu=use_gpu, force_reload=True
+        model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload
     )
-    _ = load_codec_model(use_gpu=use_gpu, force_reload=True)
+    _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
 
 
 ####
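With the reworked signature, callers can mix checkpoint sizes and devices per model, and because the hard-coded `force_reload=True` is gone, repeated calls become cheap no-ops. A usage sketch using only parameters that appear in the diff:

```python
from bark import SAMPLE_RATE, generate_audio, preload_models

# full text model, small coarse/fine models, codec on GPU
preload_models(
    text_use_gpu=True,
    text_use_small=False,
    coarse_use_gpu=True,
    coarse_use_small=True,
    fine_use_gpu=True,
    fine_use_small=True,
    codec_use_gpu=True,
    force_reload=False,
)

audio_array = generate_audio("Testing the small checkpoints.")
```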