Download HF models

pull/2685/head
Eren Gölge 2023-06-19 14:14:04 +02:00
parent f59da4dba5
commit 5a31fad502
1 changed files with 27 additions and 6 deletions

View File

@ -245,6 +245,26 @@ class ModelManager(object):
else:
print(" > Model's license - No license information available")
def _download_github_model(self, model_item: Dict, output_path: str):
    """Download the model referenced by ``model_item["github_rls_url"]``.

    Args:
        model_item (Dict): Model entry; its ``"github_rls_url"`` value is
            either a list (handed to ``self._download_model_files``) or any
            other value (handed to ``self._download_zip_file``).
        output_path (str): Destination forwarded unchanged to the helper.

    Raises:
        KeyError: If ``model_item`` has no ``"github_rls_url"`` key.
    """
    release_url = model_item["github_rls_url"]
    # Anything that is not a list is treated as a single archive URL.
    if not isinstance(release_url, list):
        self._download_zip_file(release_url, output_path, self.progress_bar)
        return
    # A list is passed through as-is to the multi-file downloader.
    self._download_model_files(release_url, output_path, self.progress_bar)
def _download_hf_model(self, model_item: Dict, output_path: str):
    """Download the model referenced by ``model_item["hf_url"]``.

    Args:
        model_item (Dict): Model entry; its ``"hf_url"`` value is either a
            list (handed to ``self._download_model_files``) or any other
            value (handed to ``self._download_zip_file``).
        output_path (str): Destination forwarded unchanged to the helper.

    Raises:
        KeyError: If ``model_item`` has no ``"hf_url"`` key.
    """
    hub_url = model_item["hf_url"]
    # Anything that is not a list is treated as a single archive URL.
    if not isinstance(hub_url, list):
        self._download_zip_file(hub_url, output_path, self.progress_bar)
        return
    # A list is passed through as-is to the multi-file downloader.
    self._download_model_files(hub_url, output_path, self.progress_bar)
def set_model_url(self, model_item: Dict):
    """Record the model's download location under ``model_item["model_url"]``.

    The value is copied from ``"github_rls_url"`` when that key is present,
    otherwise from ``"hf_url"``; when neither key exists it is ``None``.

    Args:
        model_item (Dict): Model entry; it is modified in place.

    Returns:
        Dict: The same ``model_item`` object, now carrying ``"model_url"``.
    """
    resolved_url = None
    # Key order encodes precedence: the GitHub release entry wins over the HF one.
    for url_key in ("github_rls_url", "hf_url"):
        if url_key in model_item:
            resolved_url = model_item[url_key]
            break
    model_item["model_url"] = resolved_url
    return model_item
def download_model(self, model_name):
"""Download model files given the full model name.
Model name is in the format
@ -264,6 +284,7 @@ class ModelManager(object):
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
model_item = self.models_dict[model_type][lang][dataset][model]
model_item["model_type"] = model_type
model_item = self.set_model_url(model_item)
# set the model specific output path
output_path = os.path.join(self.output_prefix, model_full_name)
if os.path.exists(output_path):
@ -271,16 +292,16 @@ class ModelManager(object):
else:
os.makedirs(output_path, exist_ok=True)
print(f" > Downloading model to {output_path}")
# download from github release
if isinstance(model_item["github_rls_url"], list):
self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar)
else:
self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
if "github_rls_url" in model_item:
self._download_github_model(model_item, output_path)
elif "hf_url" in model_item:
self._download_hf_model(model_item, output_path)
self.print_model_license(model_item=model_item)
# find downloaded files
output_model_path = output_path
output_config_path = None
if model != "tortoise-v2":
if model not in ["tortoise-v2", "bark"]: # TODO:This is stupid but don't care for now.
output_model_path, output_config_path = self._find_files(output_path)
# update paths in the config.json
self._update_paths(output_path, output_config_path)