diff --git a/bark/generation.py b/bark/generation.py
index 28d963c..4ac165c 100644
--- a/bark/generation.py
+++ b/bark/generation.py
@@ -36,6 +36,9 @@ else:
 global models
 models = {}
 
+global models_devices
+models_devices = {}
+
 
 CONTEXT_WINDOW_SIZE = 1024
 
@@ -84,6 +87,7 @@ CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno",
 
 USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)
 GLOBAL_ENABLE_MPS = os.environ.get("SUNO_ENABLE_MPS", False)
+OFFLOAD_CPU = os.environ.get("SUNO_OFFLOAD_CPU", False)
 
 
 REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"
@@ -294,8 +298,12 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
     if model_type not in ("text", "coarse", "fine"):
         raise NotImplementedError()
     global models
+    global models_devices
     device = _grab_best_device(use_gpu=use_gpu)
     model_key = f"{model_type}"
+    if OFFLOAD_CPU:
+        models_devices[model_key] = device
+        device = "cpu"
     if model_key not in models or force_reload:
         ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
         clean_models(model_key=model_key)
@@ -310,11 +318,15 @@
 
 def load_codec_model(use_gpu=True, force_reload=False):
     global models
+    global models_devices
     device = _grab_best_device(use_gpu=use_gpu)
     if device == "mps":
         # encodec doesn't support mps
         device = "cpu"
     model_key = "codec"
+    if OFFLOAD_CPU:
+        models_devices[model_key] = device
+        device = "cpu"
     if model_key not in models or force_reload:
         clean_models(model_key=model_key)
         model = _load_codec_model(device)
@@ -411,12 +423,15 @@ def generate_text_semantic(
         semantic_history = None
     # load models if not yet exist
     global models
+    global models_devices
     if "text" not in models:
         preload_models()
     model_container = models["text"]
     model = model_container["model"]
     tokenizer = model_container["tokenizer"]
     encoded_text = np.array(_tokenize(tokenizer, text)) + TEXT_ENCODING_OFFSET
+    if OFFLOAD_CPU:
+        model.to(models_devices["text"])
     device = next(model.parameters()).device
     if len(encoded_text) > 256:
         p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1)
@@ -514,6 +529,8 @@
             pbar_state = req_pbar_state
         pbar.close()
         out = x.detach().cpu().numpy().squeeze()[256 + 256 + 1 :]
+    if OFFLOAD_CPU:
+        model.to("cpu")
     assert all(0 <= out) and all(out < SEMANTIC_VOCAB_SIZE)
     _clear_cuda_cache()
     return out
@@ -602,9 +619,12 @@
         x_coarse_history = np.array([], dtype=np.int32)
     # load models if not yet exist
     global models
+    global models_devices
     if "coarse" not in models:
         preload_models()
     model = models["coarse"]
+    if OFFLOAD_CPU:
+        model.to(models_devices["coarse"])
     device = next(model.parameters()).device
     # start loop
     n_steps = int(
@@ -691,6 +711,8 @@
                 n_step += 1
             del x_in
         del x_semantic_in
+    if OFFLOAD_CPU:
+        model.to("cpu")
     gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history) :]
     del x_coarse_in
     assert len(gen_coarse_arr) == n_steps
@@ -737,9 +759,12 @@
     n_coarse = x_coarse_gen.shape[0]
     # load models if not yet exist
     global models
+    global models_devices
     if "fine" not in models:
         preload_models()
     model = models["fine"]
+    if OFFLOAD_CPU:
+        model.to(models_devices["fine"])
     device = next(model.parameters()).device
     # make input arr
     in_arr = np.vstack(
@@ -808,6 +833,8 @@
         del in_buffer
     gen_fine_arr = in_arr.detach().cpu().numpy().squeeze().T
     del in_arr
+    if OFFLOAD_CPU:
+        model.to("cpu")
     gen_fine_arr = gen_fine_arr[:, n_history:]
     if n_remove_from_end > 0:
         gen_fine_arr = gen_fine_arr[:, :-n_remove_from_end]
@@ -820,9 +847,12 @@ def codec_decode(fine_tokens):
     """Turn quantized audio codes into audio array using encodec."""
     # load models if not yet exist
     global models
+    global models_devices
     if "codec" not in models:
         preload_models()
     model = models["codec"]
+    if OFFLOAD_CPU:
+        model.to(models_devices["codec"])
     device = next(model.parameters()).device
     arr = torch.from_numpy(fine_tokens)[None]
     arr = arr.to(device)
@@ -831,4 +861,6 @@
         out = model.decoder(emb)
     audio_arr = out.detach().cpu().numpy().squeeze()
     del arr, emb, out
+    if OFFLOAD_CPU:
+        model.to("cpu")
     return audio_arr
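
Not part of the patch itself: a minimal usage sketch of the new `SUNO_OFFLOAD_CPU` flag, assuming the existing public `bark` API (`preload_models`, `generate_audio`, `SAMPLE_RATE`). Because `OFFLOAD_CPU` is evaluated at module import time via `os.environ.get("SUNO_OFFLOAD_CPU", False)`, the environment variable must be set before `bark` is imported, and any non-empty string (even `"0"`) counts as enabled.

```python
import os

# Must be set before importing bark, since generation.py reads
# SUNO_OFFLOAD_CPU at import time. Any non-empty string enables offloading.
os.environ["SUNO_OFFLOAD_CPU"] = "1"

from bark import SAMPLE_RATE, generate_audio, preload_models
from scipy.io.wavfile import write as write_wav

# Checkpoints load onto the CPU; each sub-model (text, coarse, fine, codec)
# is moved to its target device only while generating, then moved back.
preload_models()
audio_array = generate_audio("Hello, my name is Suno.")
write_wav("offload_demo.wav", SAMPLE_RATE, audio_array)
```

With offloading enabled, peak GPU memory should be bounded roughly by the largest single sub-model rather than by all four held resident at once, at the cost of one host-to-device copy per generation stage.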