Add missing global models_devices

pull/146/head
Jairo Correa 2023-04-25 22:42:21 -03:00
parent 8675c23a42
commit dfbe09f00e
1 changed files with 6 additions and 0 deletions

View File

@ -298,6 +298,7 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
if model_type not in ("text", "coarse", "fine"):
raise NotImplementedError()
global models
global models_devices
device = _grab_best_device(use_gpu=use_gpu)
model_key = f"{model_type}"
if OFFLOAD_CPU:
@ -317,6 +318,7 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
def load_codec_model(use_gpu=True, force_reload=False):
global models
global models_devices
device = _grab_best_device(use_gpu=use_gpu)
if device == "mps":
# encodec doesn't support mps
@ -421,6 +423,7 @@ def generate_text_semantic(
semantic_history = None
# load models if not yet exist
global models
global models_devices
if "text" not in models:
preload_models()
model_container = models["text"]
@ -616,6 +619,7 @@ def generate_coarse(
x_coarse_history = np.array([], dtype=np.int32)
# load models if not yet exist
global models
global models_devices
if "coarse" not in models:
preload_models()
model = models["coarse"]
@ -755,6 +759,7 @@ def generate_fine(
n_coarse = x_coarse_gen.shape[0]
# load models if not yet exist
global models
global models_devices
if "fine" not in models:
preload_models()
model = models["fine"]
@ -842,6 +847,7 @@ def codec_decode(fine_tokens):
"""Turn quantized audio codes into audio array using encodec."""
# load models if not yet exist
global models
global models_devices
if "codec" not in models:
preload_models()
model = models["codec"]