Compatiblity with PyTorch 2.4+ (weights_only default value changed)

Starting from version 2.4 PyTorch introduces a stricter check for the objects which can be loaded with torch.load(); to use bark successfully with PyTorch>=2.4 the weights_only attribute needs to be set explicitly
pull/619/head
Constantin 2024-12-30 23:31:29 +09:00 committed by GitHub
parent f4f32d4cd4
commit fc7e8e2854
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 1 additions and 1 deletions

View File

@ -209,7 +209,7 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
if not os.path.exists(ckpt_path):
logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
_download(model_info["repo_id"], model_info["file_name"])
checkpoint = torch.load(ckpt_path, map_location=device)
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
# this is a hack
model_args = checkpoint["model_args"]
if "input_vocab_size" not in model_args: