mirror of https://github.com/coqui-ai/TTS.git
Download HF models
parent
f59da4dba5
commit
5a31fad502
|
@ -245,6 +245,26 @@ class ModelManager(object):
|
|||
else:
|
||||
print(" > Model's license - No license information available")
|
||||
|
||||
def _download_github_model(self, model_item: Dict, output_path: str):
|
||||
if isinstance(model_item["github_rls_url"], list):
|
||||
self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar)
|
||||
else:
|
||||
self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
|
||||
|
||||
def _download_hf_model(self, model_item:Dict, output_path: str):
|
||||
if isinstance(model_item["hf_url"], list):
|
||||
self._download_model_files(model_item["hf_url"], output_path, self.progress_bar)
|
||||
else:
|
||||
self._download_zip_file(model_item["hf_url"], output_path, self.progress_bar)
|
||||
|
||||
def set_model_url(self, model_item: Dict):
|
||||
model_item["model_url"] = None
|
||||
if "github_rls_url" in model_item:
|
||||
model_item["model_url"] = model_item["github_rls_url"]
|
||||
elif "hf_url" in model_item:
|
||||
model_item["model_url"] = model_item["hf_url"]
|
||||
return model_item
|
||||
|
||||
def download_model(self, model_name):
|
||||
"""Download model files given the full model name.
|
||||
Model name is in the format
|
||||
|
@ -264,6 +284,7 @@ class ModelManager(object):
|
|||
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
|
||||
model_item = self.models_dict[model_type][lang][dataset][model]
|
||||
model_item["model_type"] = model_type
|
||||
model_item = self.set_model_url(model_item)
|
||||
# set the model specific output path
|
||||
output_path = os.path.join(self.output_prefix, model_full_name)
|
||||
if os.path.exists(output_path):
|
||||
|
@ -271,16 +292,16 @@ class ModelManager(object):
|
|||
else:
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
print(f" > Downloading model to {output_path}")
|
||||
# download from github release
|
||||
if isinstance(model_item["github_rls_url"], list):
|
||||
self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar)
|
||||
else:
|
||||
self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
|
||||
if "github_rls_url" in model_item:
|
||||
self._download_github_model(model_item, output_path)
|
||||
elif "hf_url" in model_item:
|
||||
self._download_hf_model(model_item, output_path)
|
||||
|
||||
self.print_model_license(model_item=model_item)
|
||||
# find downloaded files
|
||||
output_model_path = output_path
|
||||
output_config_path = None
|
||||
if model != "tortoise-v2":
|
||||
if model not in ["tortoise-v2", "bark"]: # TODO:This is stupid but don't care for now.
|
||||
output_model_path, output_config_path = self._find_files(output_path)
|
||||
# update paths in the config.json
|
||||
self._update_paths(output_path, output_config_path)
|
||||
|
|
Loading…
Reference in New Issue