diff --git a/bark/generation.py b/bark/generation.py index e4a6b2c..a704820 100644 --- a/bark/generation.py +++ b/bark/generation.py @@ -137,7 +137,7 @@ def _grab_best_device(use_gpu=True): def _get_ckpt_path(model_type, use_small=False): key = model_type - if use_small: + if use_small or USE_SMALL_MODELS: key += "_small" return os.path.join(CACHE_DIR, REMOTE_MODEL_PATHS[key]["file_name"])