diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 955eeb9b..b5c698f3 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -6,7 +6,6 @@ from pathlib import Path from shutil import copyfile, rmtree from typing import Dict, List, Tuple -import fsspec import requests from tqdm import tqdm @@ -321,46 +320,6 @@ class ModelManager(object): return False return True - def create_dir_and_download_model(self, model_name, model_item, output_path): - os.makedirs(output_path, exist_ok=True) - # handle TOS - if not self.tos_agreed(model_item, output_path): - if not self.ask_tos(output_path): - os.rmdir(output_path) - raise Exception(" [!] You must agree to the terms of service to use this model.") - print(f" > Downloading model to {output_path}") - try: - if "fairseq" in model_name: - self.download_fairseq_model(model_name, output_path) - elif "github_rls_url" in model_item: - self._download_github_model(model_item, output_path) - elif "hf_url" in model_item: - self._download_hf_model(model_item, output_path) - - except requests.RequestException as e: - print(f" > Failed to download the model file to {output_path}") - rmtree(output_path) - raise e - self.print_model_license(model_item=model_item) - - def check_if_configs_are_equal(self, model_name, model_item, output_path): - with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f: - config_local = json.load(f) - remote_url = None - for url in model_item["hf_url"]: - if "config.json" in url: - remote_url = url - break - - with fsspec.open(remote_url, "r", encoding="utf-8") as f: - config_remote = json.load(f) - - if not config_local == config_remote: - print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...") - self.create_dir_and_download_model(model_name, model_item, output_path) - else: - print(f" > {model_name} is already downloaded.") - def download_model(self, model_name): """Download model files given the full model name. Model name is in the format @@ -379,18 +338,28 @@ class ModelManager(object): # set the model specific output path output_path = os.path.join(self.output_prefix, model_full_name) if os.path.exists(output_path): - # if the configs are different, redownload it - # ToDo: we need a better way to handle it - if "xtts_v1" in model_name: - try: - self.check_if_configs_are_equal(model_name, model_item, output_path) - except: - pass - else: - print(f" > {model_name} is already downloaded.") + print(f" > {model_name} is already downloaded.") else: - self.create_dir_and_download_model(model_name, model_item, output_path) + os.makedirs(output_path, exist_ok=True) + # handle TOS + if not self.tos_agreed(model_item, output_path): + if not self.ask_tos(output_path): + os.rmdir(output_path) + raise Exception(" [!] You must agree to the terms of service to use this model.") + print(f" > Downloading model to {output_path}") + try: + if "fairseq" in model_name: + self.download_fairseq_model(model_name, output_path) + elif "github_rls_url" in model_item: + self._download_github_model(model_item, output_path) + elif "hf_url" in model_item: + self._download_hf_model(model_item, output_path) + except requests.RequestException as e: + print(f" > Failed to download the model file to {output_path}") + rmtree(output_path) + raise e + self.print_model_license(model_item=model_item) # find downloaded files output_model_path = output_path output_config_path = None