mirror of https://github.com/coqui-ai/TTS.git
linter fixes
parent
4244096ccb
commit
420901f4c2
|
@ -7,16 +7,15 @@ from TTS.tts.datasets.preprocess import get_preprocessor_by_name
|
|||
|
||||
|
||||
def main():
|
||||
# pylint: disable=bad-continuation
|
||||
parser = argparse.ArgumentParser(description='''Find all the unique characters or phonemes in a dataset.\n\n'''
|
||||
|
||||
'''Target dataset must be defined in TTS.tts.datasets.preprocess\n\n'''\
|
||||
|
||||
'''
|
||||
Example runs:
|
||||
|
||||
python TTS/bin/find_unique_chars.py --dataset ljspeech --meta_file /path/to/LJSpeech/metadata.csv
|
||||
''',
|
||||
formatter_class=RawTextHelpFormatter)
|
||||
''', formatter_class=RawTextHelpFormatter)
|
||||
|
||||
parser.add_argument(
|
||||
'--dataset',
|
||||
|
|
|
@ -6,15 +6,13 @@ import argparse
|
|||
import glob
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
|
||||
from TTS.tts.utils.generic_utils import check_config_tts
|
||||
from TTS.tts.utils.text.symbols import parse_symbols
|
||||
from TTS.utils.console_logger import ConsoleLogger
|
||||
from TTS.utils.generic_utils import create_experiment_folder, get_git_branch
|
||||
from TTS.utils.io import (copy_model_files, load_config,
|
||||
save_characters_to_config)
|
||||
from TTS.utils.io import copy_model_files, load_config
|
||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||
from TTS.tts.utils.text.symbols import parse_symbols
|
||||
|
||||
|
||||
def parse_arguments(argv):
|
||||
|
|
|
@ -122,7 +122,8 @@ class ModelManager(object):
|
|||
"""Download files from GDrive using their file ids"""
|
||||
gdown.download(f"{self.url_prefix}{gdrive_idx}", output=output, quiet=False)
|
||||
|
||||
def _download_zip_file(self, file_url, output):
|
||||
@staticmethod
|
||||
def _download_zip_file(file_url, output):
|
||||
"""Download the target zip file and extract the files
|
||||
to a folder with the same name as the zip file."""
|
||||
r = requests.get(file_url)
|
||||
|
|
11
hubconf.py
11
hubconf.py
|
@ -1,11 +1,11 @@
|
|||
dependencies = ['torch', 'gdown', 'pysbd', 'phonemizer', 'unidecode'] # apt install espeak
|
||||
dependencies = ['torch', 'gdown', 'pysbd', 'phonemizer', 'unidecode'] # apt install espeak-ng
|
||||
import torch
|
||||
|
||||
from TTS.utils.synthesizer import Synthesizer
|
||||
from TTS.utils.manage import ModelManager
|
||||
|
||||
|
||||
def tts(model_name='tts_models/en/ljspeech/tacotron2-DCA', vocoder_name='vocoder_models/en/ljspeech/mulitband-melgan', use_cuda=False):
|
||||
def tts(model_name='tts_models/en/ljspeech/tacotron2-DCA', vocoder_name=None, use_cuda=False):
|
||||
"""TTS entry point for PyTorch Hub that provides a Synthesizer object to synthesize speech from a give text.
|
||||
|
||||
Example:
|
||||
|
@ -15,7 +15,7 @@ def tts(model_name='tts_models/en/ljspeech/tacotron2-DCA', vocoder_name='vocoder
|
|||
|
||||
Args:
|
||||
model_name (str, optional): One of the model names from .model.json. Defaults to 'tts_models/en/ljspeech/tacotron2-DCA'.
|
||||
vocoder_name (str, optional): One of the model names from .model.json. Defaults to 'vocoder_models/en/ljspeech/mulitband-melgan'.
|
||||
vocoder_name (str, optional): One of the model names from .model.json. Defaults to 'vocoder_models/en/ljspeech/multiband-melgan'.
|
||||
pretrained (bool, optional): [description]. Defaults to True.
|
||||
|
||||
Returns:
|
||||
|
@ -23,8 +23,9 @@ def tts(model_name='tts_models/en/ljspeech/tacotron2-DCA', vocoder_name='vocoder
|
|||
"""
|
||||
manager = ModelManager()
|
||||
|
||||
model_path, config_path = manager.download_model(model_name)
|
||||
vocoder_path, vocoder_config_path = manager.download_model(vocoder_name)
|
||||
model_path, config_path, model_item = manager.download_model(model_name)
|
||||
vocoder_name = model_item['default_vocoder'] if vocoder_name is None else vocoder_name
|
||||
vocoder_path, vocoder_config_path, _ = manager.download_model(vocoder_name)
|
||||
|
||||
# create synthesizer
|
||||
synt = Synthesizer(model_path, config_path, vocoder_path, vocoder_config_path, use_cuda)
|
||||
|
|
|
@ -21,7 +21,7 @@ class DemoServerTest(unittest.TestCase):
|
|||
num_chars = len(phonemes) if config.use_phonemes else len(symbols)
|
||||
model = setup_model(num_chars, 0, config)
|
||||
output_path = os.path.join(get_tests_output_path())
|
||||
save_checkpoint(model, None, 10, 10, 1, output_path)
|
||||
save_checkpoint(model, None, 10, 10, 1, output_path, None)
|
||||
|
||||
def test_in_out(self):
|
||||
self._create_random_model()
|
||||
|
|
Loading…
Reference in New Issue