Update generation.py

Partially solve the out-of-memory (OOM) problem by preloading models on demand instead of preloading all models at the same time.
After these changes, generation runs smoothly on my 8 GB graphics card.
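In short, preload_models() gains a preload_type parameter and loads one model per call instead of all four. A minimal usage sketch (the bark.generation import path is assumed from upstream Bark, not shown in this diff):

    from bark.generation import preload_models

    # Load only the text model up front; the coarse, fine, and codec models
    # are pulled in later, on demand, by their respective generation stages.
    preload_models(preload_type="text", text_use_gpu=True, text_use_small=False)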
pull/531/head
asterocean 2024-02-14 13:35:33 -08:00 committed by GitHub
parent 773624d26d
commit c18a120faf
1 changed file with 24 additions and 16 deletions


@@ -264,6 +264,8 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
         raise NotImplementedError()
     global models
     global models_devices
+    gc.collect()
+    torch.cuda.empty_cache()
     device = _grab_best_device(use_gpu=use_gpu)
     model_key = f"{model_type}"
     if OFFLOAD_CPU:
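The two lines added to load_model() follow a common PyTorch recipe for reclaiming VRAM before a large allocation. A standalone sketch of the pattern (not code from this commit):

    import gc

    import torch

    def reclaim_gpu_memory():
        # Collect unreachable Python objects first, so any CUDA tensors they
        # still hold are actually freed...
        gc.collect()
        # ...then return PyTorch's cached allocator blocks to the driver.
        # This only trims the cache; memory held by live tensors stays put.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

Note that empty_cache() alone cannot release memory still referenced from Python, which is why gc.collect() comes first.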
@@ -301,6 +303,7 @@ def load_codec_model(use_gpu=True, force_reload=False):
 def preload_models(
+    preload_type="text",
     text_use_gpu=True,
     text_use_small=False,
     coarse_use_gpu=True,
@@ -315,19 +318,24 @@ def preload_models(
         text_use_gpu or coarse_use_gpu or fine_use_gpu or codec_use_gpu
     ):
         logger.warning("No GPU being used. Careful, inference might be very slow!")
-    _ = load_model(
-        model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload
-    )
-    _ = load_model(
-        model_type="coarse",
-        use_gpu=coarse_use_gpu,
-        use_small=coarse_use_small,
-        force_reload=force_reload,
-    )
-    _ = load_model(
-        model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload
-    )
-    _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
+    models.clear()
+    if preload_type == "text":
+        _ = load_model(
+            model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload
+        )
+    elif preload_type == "coarse":
+        _ = load_model(
+            model_type="coarse",
+            use_gpu=coarse_use_gpu,
+            use_small=coarse_use_small,
+            force_reload=force_reload,
+        )
+    elif preload_type == "fine":
+        _ = load_model(
+            model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload
+        )
+    elif preload_type == "codec":
+        _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
 ####
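Because the new branch calls models.clear() before loading, at most one model is resident at a time, which is what keeps peak VRAM low. Warming a single stage explicitly would look like this (parameter names as in the diff; the call itself is a hypothetical example):

    # Warm-up for a low-VRAM setup: keep only the small coarse model on the GPU.
    preload_models(preload_type="coarse", coarse_use_gpu=True, coarse_use_small=True)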
@@ -593,7 +601,7 @@ def generate_coarse(
     global models
     global models_devices
     if "coarse" not in models:
-        preload_models()
+        preload_models(preload_type="coarse")
     model = models["coarse"]
     if OFFLOAD_CPU:
         model.to(models_devices["coarse"])
@@ -721,7 +729,7 @@ def generate_fine(
     global models
     global models_devices
     if "fine" not in models:
-        preload_models()
+        preload_models(preload_type="fine")
     model = models["fine"]
     if OFFLOAD_CPU:
         model.to(models_devices["fine"])
@@ -803,7 +811,7 @@ def codec_decode(fine_tokens):
     global models
     global models_devices
     if "codec" not in models:
-        preload_models()
+        preload_models(preload_type="codec")
     model = models["codec"]
     if OFFLOAD_CPU:
         model.to(models_devices["codec"])
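With these call sites, each stage pages its own model in and evicts the previous one. An end-to-end sketch using Bark's stage functions (signatures assumed from upstream; the text stage keeps calling preload_models() with its default preload_type="text", which is why no text hunk appears above):

    from bark.generation import (
        generate_text_semantic,
        generate_coarse,
        generate_fine,
        codec_decode,
    )

    semantic_tokens = generate_text_semantic("Hello, world!")  # loads only the text model
    coarse_tokens = generate_coarse(semantic_tokens)           # clears it, loads coarse
    fine_tokens = generate_fine(coarse_tokens)                 # clears coarse, loads fine
    audio_array = codec_decode(fine_tokens)                    # clears fine, loads codec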