mirror of https://github.com/coqui-ai/TTS.git
download github model releases by model manager
parent
3c961370e7
commit
49771f2541
|
@ -14,6 +14,7 @@
|
||||||
"model_file": "1CFoPDQBnhfBFu2Gc0TBSJn8o-TuNKQn7",
|
"model_file": "1CFoPDQBnhfBFu2Gc0TBSJn8o-TuNKQn7",
|
||||||
"config_file": "1lWSscNfKet1zZSJCNirOn7v9bigUZ8C1",
|
"config_file": "1lWSscNfKet1zZSJCNirOn7v9bigUZ8C1",
|
||||||
"stats_file": "1qevpGRVHPmzfiRBNuugLMX62x1k7B5vK",
|
"stats_file": "1qevpGRVHPmzfiRBNuugLMX62x1k7B5vK",
|
||||||
|
"github_rls_url": null,
|
||||||
"commit": ""
|
"commit": ""
|
||||||
},
|
},
|
||||||
"speedy-speech-wn":{
|
"speedy-speech-wn":{
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
|
import io
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import zipfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import gdown
|
import gdown
|
||||||
|
import requests
|
||||||
from TTS.utils.generic_utils import get_user_data_dir
|
from TTS.utils.generic_utils import get_user_data_dir
|
||||||
from TTS.utils.io import load_config
|
from TTS.utils.io import load_config
|
||||||
|
|
||||||
|
@ -71,6 +74,11 @@ class ModelManager(object):
|
||||||
'type/language/dataset/model'
|
'type/language/dataset/model'
|
||||||
e.g. 'tts_model/en/ljspeech/tacotron'
|
e.g. 'tts_model/en/ljspeech/tacotron'
|
||||||
|
|
||||||
|
Every model must have the following files
|
||||||
|
- *.pth.tar : pytorch model checkpoint file.
|
||||||
|
- config.json : model config file.
|
||||||
|
- scale_stats.npy (if exist): scale values for preprocessing.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_name (str): model name as explained above.
|
model_name (str): model name as explained above.
|
||||||
|
|
||||||
|
@ -91,11 +99,17 @@ class ModelManager(object):
|
||||||
print(f" > Downloading model to {output_path}")
|
print(f" > Downloading model to {output_path}")
|
||||||
output_stats_path = None
|
output_stats_path = None
|
||||||
# download files to the output path
|
# download files to the output path
|
||||||
self._download_file(model_item['model_file'], output_model_path)
|
if self._check_dict_key(model_item, 'github_rls_url'):
|
||||||
self._download_file(model_item['config_file'], output_config_path)
|
# download from github release
|
||||||
if model_item['stats_file'] is not None and len(model_item['stats_file']) > 1:
|
# TODO: pass output_path
|
||||||
|
self._download_zip_file(model_item['github_rls_url'], output_path)
|
||||||
|
else:
|
||||||
|
# download from gdrive
|
||||||
|
self._download_gdrive_file(model_item['model_file'], output_model_path)
|
||||||
|
self._download_gdrive_file(model_item['config_file'], output_config_path)
|
||||||
|
if self._check_dict_key(model_item, 'scale_stats'):
|
||||||
output_stats_path = os.path.join(output_path, 'scale_stats.npy')
|
output_stats_path = os.path.join(output_path, 'scale_stats.npy')
|
||||||
self._download_file(model_item['stats_file'], output_stats_path)
|
self._download_gdrive_file(model_item['stats_file'], output_stats_path)
|
||||||
# set scale stats path in config.json
|
# set scale stats path in config.json
|
||||||
config_path = output_config_path
|
config_path = output_config_path
|
||||||
config = load_config(config_path)
|
config = load_config(config_path)
|
||||||
|
@ -104,9 +118,25 @@ class ModelManager(object):
|
||||||
json.dump(config, jf)
|
json.dump(config, jf)
|
||||||
return output_model_path, output_config_path
|
return output_model_path, output_config_path
|
||||||
|
|
||||||
def _download_file(self, idx, output):
|
def _download_gdrive_file(self, gdrive_idx, output):
|
||||||
gdown.download(f"{self.url_prefix}{idx}", output=output, quiet=False)
|
"""Download files from GDrive using their file ids"""
|
||||||
|
gdown.download(f"{self.url_prefix}{gdrive_idx}", output=output, quiet=False)
|
||||||
|
|
||||||
|
def _download_zip_file(self, file_url, output):
|
||||||
|
"""Download the target zip file and extract the files
|
||||||
|
to a folder with the same name as the zip file."""
|
||||||
|
r = requests.get(file_url)
|
||||||
|
z = zipfile.ZipFile(io.BytesIO(r.content))
|
||||||
|
z.extractall(output)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _check_dict_key(my_dict, key):
|
||||||
|
if key in my_dict.keys() and my_dict[key] is not None:
|
||||||
|
if not isinstance(key, str):
|
||||||
|
return True
|
||||||
|
if isinstance(key, str) and len(my_dict[key]) > 0:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue