mirror of https://github.com/coqui-ai/TTS.git
Model zoo tests (#900)
* Fix VITS model multi-speaker init * Remove gdrive support in model manager * Add model zoo testspull/901/head
parent
aaaa591485
commit
2df0752e73
|
@ -0,0 +1,49 @@
|
|||
name: zoo-tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
jobs:
|
||||
check_skip:
|
||||
runs-on: ubuntu-latest
|
||||
if: "! contains(github.event.head_commit.message, '[ci skip]')"
|
||||
steps:
|
||||
- run: echo "${{ github.event.head_commit.message }}"
|
||||
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||
experimental: [false]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/cache@v1
|
||||
with:
|
||||
path: ~/.cache/pip
|
||||
key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/setup.py') }}
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
architecture: x64
|
||||
- name: check OS
|
||||
run: cat /etc/os-release
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt update
|
||||
sudo apt install -y git make
|
||||
sudo apt install -y python3-wheel gcc
|
||||
make system-deps
|
||||
- name: Upgrade pip
|
||||
run: python3 -m pip install --upgrade pip
|
||||
- name: Install TTS
|
||||
run: |
|
||||
python3 -m pip install .[all]
|
||||
python3 setup.py egg_info
|
||||
- name: Unit tests
|
||||
run: make test_zoo
|
3
Makefile
3
Makefile
|
@ -23,6 +23,9 @@ test_aux: ## run aux tests.
|
|||
nosetests tests.aux_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.aux_tests --nologcapture --with-id
|
||||
./run_bash_tests.sh
|
||||
|
||||
test_zoo: ## run zoo tests.
|
||||
nosetests tests.zoo_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.zoo_tests --nologcapture --with-id
|
||||
|
||||
test_failed: ## only run tests failed the last time.
|
||||
nosetests -x --with-cov -cov --cover-erase --cover-package TTS tests --nologcapture --failed
|
||||
|
||||
|
|
|
@ -337,7 +337,7 @@ class Vits(BaseTTS):
|
|||
def _init_speaker_embedding(self, config):
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
if config.speakers_file is not None:
|
||||
self.speaker_manager = SpeakerManager(speaker_id_file_path=config.speakers_file_path)
|
||||
self.speaker_manager = SpeakerManager(speaker_id_file_path=config.speakers_file)
|
||||
|
||||
if self.num_speakers > 0:
|
||||
print(" > initialization of speaker-embedding layers.")
|
||||
|
|
|
@ -5,7 +5,6 @@ import zipfile
|
|||
from pathlib import Path
|
||||
from shutil import copyfile, rmtree
|
||||
|
||||
import gdown
|
||||
import requests
|
||||
|
||||
from TTS.config import load_config
|
||||
|
@ -30,7 +29,6 @@ class ModelManager(object):
|
|||
self.output_prefix = get_user_data_dir("tts")
|
||||
else:
|
||||
self.output_prefix = os.path.join(output_prefix, "tts")
|
||||
self.url_prefix = "https://drive.google.com/uc?id="
|
||||
self.models_dict = None
|
||||
if models_file is not None:
|
||||
self.read_models_file(models_file)
|
||||
|
@ -92,8 +90,6 @@ class ModelManager(object):
|
|||
|
||||
Args:
|
||||
model_name (str): model name as explained above.
|
||||
|
||||
TODO: support multi-speaker models
|
||||
"""
|
||||
# fetch model info from the dict
|
||||
model_type, lang, dataset, model = model_name.split("/")
|
||||
|
@ -109,19 +105,8 @@ class ModelManager(object):
|
|||
else:
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
print(f" > Downloading model to {output_path}")
|
||||
output_stats_path = os.path.join(output_path, "scale_stats.npy")
|
||||
|
||||
# download files to the output path
|
||||
if self._check_dict_key(model_item, "github_rls_url"):
|
||||
# download from github release
|
||||
self._download_zip_file(model_item["github_rls_url"], output_path)
|
||||
else:
|
||||
# download from gdrive
|
||||
self._download_gdrive_file(model_item["model_file"], output_model_path)
|
||||
self._download_gdrive_file(model_item["config_file"], output_config_path)
|
||||
if self._check_dict_key(model_item, "stats_file"):
|
||||
self._download_gdrive_file(model_item["stats_file"], output_stats_path)
|
||||
|
||||
# download from github release
|
||||
self._download_zip_file(model_item["github_rls_url"], output_path)
|
||||
# update paths in the config.json
|
||||
self._update_paths(output_path, output_config_path)
|
||||
return output_model_path, output_config_path, model_item
|
||||
|
@ -168,10 +153,6 @@ class ModelManager(object):
|
|||
config[field_name] = new_path
|
||||
config.save_json(config_path)
|
||||
|
||||
def _download_gdrive_file(self, gdrive_idx, output):
|
||||
"""Download files from GDrive using their file ids"""
|
||||
gdown.download(f"{self.url_prefix}{gdrive_idx}", output=output, quiet=False)
|
||||
|
||||
@staticmethod
|
||||
def _download_zip_file(file_url, output_folder):
|
||||
"""Download the github releases"""
|
||||
|
|
|
@ -1,21 +0,0 @@
|
|||
#!/usr/bin/env python3`
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from tests import get_tests_output_path
|
||||
from TTS.utils.manage import ModelManager
|
||||
|
||||
|
||||
def test_if_all_models_available():
|
||||
"""Check if all the models are downloadable."""
|
||||
print(" > Checking the availability of all the models under the ModelManager.")
|
||||
manager = ModelManager(output_prefix=get_tests_output_path())
|
||||
model_names = manager.list_models()
|
||||
for model_name in model_names:
|
||||
manager.download_model(model_name)
|
||||
print(f" | > OK: {model_name}")
|
||||
|
||||
folders = glob.glob(os.path.join(manager.output_prefix, "*"))
|
||||
assert len(folders) == len(model_names)
|
||||
shutil.rmtree(manager.output_prefix)
|
|
@ -0,0 +1,48 @@
|
|||
#!/usr/bin/env python3`
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from tests import get_tests_output_path, run_cli
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.utils.generic_utils import get_user_data_dir
|
||||
from TTS.utils.manage import ModelManager
|
||||
|
||||
|
||||
def test_run_all_models():
|
||||
"""Check if all the models are downloadable and tts models run correctly."""
|
||||
print(" > Run synthesizer with all the models.")
|
||||
download_dir = get_user_data_dir("tts")
|
||||
output_path = os.path.join(get_tests_output_path(), "output.wav")
|
||||
manager = ModelManager(output_prefix=get_tests_output_path())
|
||||
model_names = manager.list_models()
|
||||
for model_name in model_names:
|
||||
model_path, _, _ = manager.download_model(model_name)
|
||||
if "tts_models" in model_name:
|
||||
local_download_dir = os.path.dirname(model_path)
|
||||
# download and run the model
|
||||
speaker_files = glob.glob(local_download_dir + "/speaker*")
|
||||
if len(speaker_files) > 0:
|
||||
# multi-speaker model
|
||||
if "speaker_ids" in speaker_files[0]:
|
||||
speaker_manager = SpeakerManager(speaker_id_file_path=speaker_files[0])
|
||||
elif "speakers" in speaker_files[0]:
|
||||
speaker_manager = SpeakerManager(d_vectors_file_path=speaker_files[0])
|
||||
speaker_id = list(speaker_manager.speaker_ids.keys())[0]
|
||||
run_cli(
|
||||
f"tts --model_name {model_name} "
|
||||
f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}"'
|
||||
)
|
||||
else:
|
||||
# single-speaker model
|
||||
run_cli(f"tts --model_name {model_name} " f'--text "This is an example." --out_path "{output_path}"')
|
||||
# remove downloaded models
|
||||
shutil.rmtree(download_dir)
|
||||
else:
|
||||
# only download the model
|
||||
manager.download_model(model_name)
|
||||
print(f" | > OK: {model_name}")
|
||||
|
||||
folders = glob.glob(os.path.join(manager.output_prefix, "*"))
|
||||
assert len(folders) == len(model_names)
|
||||
shutil.rmtree(manager.output_prefix)
|
Loading…
Reference in New Issue