From 49771f2541d9a0a7fee64de38e3d4548a0734b47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 9 Feb 2021 14:24:14 +0000 Subject: [PATCH] download github model releases by model manager --- TTS/.models.json | 1 + TTS/utils/manage.py | 44 +++++++++++++++++++++++++++++++++++++------- 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/TTS/.models.json b/TTS/.models.json index 075861db..4805ddba 100644 --- a/TTS/.models.json +++ b/TTS/.models.json @@ -14,6 +14,7 @@ "model_file": "1CFoPDQBnhfBFu2Gc0TBSJn8o-TuNKQn7", "config_file": "1lWSscNfKet1zZSJCNirOn7v9bigUZ8C1", "stats_file": "1qevpGRVHPmzfiRBNuugLMX62x1k7B5vK", + "github_rls_url": null, "commit": "" }, "speedy-speech-wn":{ diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 524d8dbf..db62acd1 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -1,8 +1,11 @@ +import io import json import os +import zipfile from pathlib import Path import gdown +import requests from TTS.utils.generic_utils import get_user_data_dir from TTS.utils.io import load_config @@ -71,6 +74,11 @@ class ModelManager(object): 'type/language/dataset/model' e.g. 'tts_model/en/ljspeech/tacotron' + Every model must have the following files + - *.pth.tar : pytorch model checkpoint file. + - config.json : model config file. + - scale_stats.npy (if exist): scale values for preprocessing. + Args: model_name (str): model name as explained above. @@ -91,11 +99,17 @@ class ModelManager(object): print(f" > Downloading model to {output_path}") output_stats_path = None # download files to the output path - self._download_file(model_item['model_file'], output_model_path) - self._download_file(model_item['config_file'], output_config_path) - if model_item['stats_file'] is not None and len(model_item['stats_file']) > 1: + if self._check_dict_key(model_item, 'github_rls_url'): + # download from github release + # TODO: pass output_path + self._download_zip_file(model_item['github_rls_url'], output_path) + else: + # download from gdrive + self._download_gdrive_file(model_item['model_file'], output_model_path) + self._download_gdrive_file(model_item['config_file'], output_config_path) + if self._check_dict_key(model_item, 'scale_stats'): output_stats_path = os.path.join(output_path, 'scale_stats.npy') - self._download_file(model_item['stats_file'], output_stats_path) + self._download_gdrive_file(model_item['stats_file'], output_stats_path) # set scale stats path in config.json config_path = output_config_path config = load_config(config_path) @@ -104,9 +118,25 @@ class ModelManager(object): json.dump(config, jf) return output_model_path, output_config_path - def _download_file(self, idx, output): - gdown.download(f"{self.url_prefix}{idx}", output=output, quiet=False) - + def _download_gdrive_file(self, gdrive_idx, output): + """Download files from GDrive using their file ids""" + gdown.download(f"{self.url_prefix}{gdrive_idx}", output=output, quiet=False) + + def _download_zip_file(self, file_url, output): + """Download the target zip file and extract the files + to a folder with the same name as the zip file.""" + r = requests.get(file_url) + z = zipfile.ZipFile(io.BytesIO(r.content)) + z.extractall(output) + + @staticmethod + def _check_dict_key(my_dict, key): + if key in my_dict.keys() and my_dict[key] is not None: + if not isinstance(key, str): + return True + if isinstance(key, str) and len(my_dict[key]) > 0: + return True + return False