diff --git a/bark/generation.py b/bark/generation.py
index 54f9870..9d7bd92 100644
--- a/bark/generation.py
+++ b/bark/generation.py
@@ -264,6 +264,8 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
         raise NotImplementedError()
     global models
     global models_devices
+    gc.collect()
+    torch.cuda.empty_cache()
     device = _grab_best_device(use_gpu=use_gpu)
     model_key = f"{model_type}"
     if OFFLOAD_CPU:
@@ -301,6 +303,7 @@ def load_codec_model(use_gpu=True, force_reload=False):
 
 
 def preload_models(
+    preload_type="text",
     text_use_gpu=True,
     text_use_small=False,
     coarse_use_gpu=True,
@@ -315,19 +318,24 @@ def preload_models(
         text_use_gpu or coarse_use_gpu or fine_use_gpu or codec_use_gpu
     ):
         logger.warning("No GPU being used. Careful, inference might be very slow!")
-    _ = load_model(
-        model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload
-    )
-    _ = load_model(
-        model_type="coarse",
-        use_gpu=coarse_use_gpu,
-        use_small=coarse_use_small,
-        force_reload=force_reload,
-    )
-    _ = load_model(
-        model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload
-    )
-    _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
+    models.clear()
+    if preload_type == "text":
+        _ = load_model(
+            model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload
+        )
+    elif preload_type == "coarse":
+        _ = load_model(
+            model_type="coarse",
+            use_gpu=coarse_use_gpu,
+            use_small=coarse_use_small,
+            force_reload=force_reload,
+        )
+    elif preload_type == "fine":
+        _ = load_model(
+            model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload
+        )
+    elif preload_type == "codec":
+        _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
 
 
 ####
@@ -593,7 +601,7 @@ def generate_coarse(
     global models
     global models_devices
     if "coarse" not in models:
-        preload_models()
+        preload_models(preload_type="coarse")
     model = models["coarse"]
     if OFFLOAD_CPU:
         model.to(models_devices["coarse"])
@@ -721,7 +729,7 @@ def generate_fine(
     global models
     global models_devices
     if "fine" not in models:
-        preload_models()
+        preload_models(preload_type="fine")
     model = models["fine"]
     if OFFLOAD_CPU:
         model.to(models_devices["fine"])
@@ -803,7 +811,7 @@ def codec_decode(fine_tokens):
     global models
     global models_devices
     if "codec" not in models:
-        preload_models()
+        preload_models(preload_type="codec")
     model = models["codec"]
     if OFFLOAD_CPU:
         model.to(models_devices["codec"])
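
For context, here is a minimal usage sketch of the pipeline with the new `preload_type` parameter. This is an illustration, not part of the patch: with this change, each stage loads only its own model, and the `models.clear()` plus the new `gc.collect()` / `torch.cuda.empty_cache()` calls evict the previously loaded one, so at most one of the four models is resident at a time.

```python
# Sketch of running the Bark pipeline stage by stage under this patch.
# Function names are the existing ones from bark.generation; the explicit
# preload_models() calls are optional, since generate_coarse, generate_fine,
# and codec_decode now request their own model type on demand.
from bark.generation import (
    preload_models,
    generate_text_semantic,
    generate_coarse,
    generate_fine,
    codec_decode,
)

text = "Hello, world!"

# Load only the text model; any previously loaded model is cleared first.
preload_models(preload_type="text")
semantic_tokens = generate_text_semantic(text)

# Each of the following stages triggers preload_models(preload_type=...)
# internally if its model is missing, evicting the prior stage's model.
coarse_tokens = generate_coarse(semantic_tokens)
fine_tokens = generate_fine(coarse_tokens)
audio_array = codec_decode(fine_tokens)
```

The trade-off is peak memory versus reload time: every stage transition pays a model load, but the process never holds more than one model in (GPU) memory at once.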