diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 0d0b9064..dc0c7b68 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -245,6 +245,26 @@ class ModelManager(object): else: print(" > Model's license - No license information available") + def _download_github_model(self, model_item: Dict, output_path: str): + if isinstance(model_item["github_rls_url"], list): + self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar) + else: + self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar) + + def _download_hf_model(self, model_item:Dict, output_path: str): + if isinstance(model_item["hf_url"], list): + self._download_model_files(model_item["hf_url"], output_path, self.progress_bar) + else: + self._download_zip_file(model_item["hf_url"], output_path, self.progress_bar) + + def set_model_url(self, model_item: Dict): + model_item["model_url"] = None + if "github_rls_url" in model_item: + model_item["model_url"] = model_item["github_rls_url"] + elif "hf_url" in model_item: + model_item["model_url"] = model_item["hf_url"] + return model_item + def download_model(self, model_name): """Download model files given the full model name. Model name is in the format @@ -264,6 +284,7 @@ class ModelManager(object): model_full_name = f"{model_type}--{lang}--{dataset}--{model}" model_item = self.models_dict[model_type][lang][dataset][model] model_item["model_type"] = model_type + model_item = self.set_model_url(model_item) # set the model specific output path output_path = os.path.join(self.output_prefix, model_full_name) if os.path.exists(output_path): @@ -271,16 +292,16 @@ class ModelManager(object): else: os.makedirs(output_path, exist_ok=True) print(f" > Downloading model to {output_path}") - # download from github release - if isinstance(model_item["github_rls_url"], list): - self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar) - else: - self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar) + if "github_rls_url" in model_item: + self._download_github_model(model_item, output_path) + elif "hf_url" in model_item: + self._download_hf_model(model_item, output_path) + self.print_model_license(model_item=model_item) # find downloaded files output_model_path = output_path output_config_path = None - if model != "tortoise-v2": + if model not in ["tortoise-v2", "bark"]: # TODO:This is stupid but don't care for now. output_model_path, output_config_path = self._find_files(output_path) # update paths in the config.json self._update_paths(output_path, output_config_path)