TTS/tests/test_demo_server.py

35 lines
1.4 KiB
Python
Raw Normal View History

2019-07-22 13:46:26 +00:00
import os
import unittest
import torch as T
from TTS.server.synthesizer import Synthesizer
from TTS.tests import get_tests_input_path, get_tests_output_path
2020-03-01 18:47:08 +00:00
from TTS.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.utils.generic_utils import setup_model
from TTS.utils.io import load_config, save_checkpoint
2019-07-22 13:46:26 +00:00
class DemoServerTest(unittest.TestCase):
    """Smoke test for the demo server's ``Synthesizer``.

    Builds an untrained model from the dummy test config, saves it as a
    checkpoint in the tests output directory, then loads it through the
    server config and synthesizes one sentence.
    """

    # pylint: disable=R0201
    def _create_random_model(self):
        """Build a randomly initialised model and save it as a checkpoint.

        Reads ``dummy_model_config.json`` from the tests output directory and
        sizes the model's character embedding from either the phoneme set or
        the symbol set, depending on ``config.use_phonemes``. The checkpoint
        is written to the tests output directory.
        """
        config = load_config(os.path.join(get_tests_output_path(), 'dummy_model_config.json'))

        # Work on local copies instead of rebinding the module-level
        # ``symbols`` / ``phonemes`` (previously done via ``global``), so this
        # helper does not leak state into the rest of the module.
        char_set, phoneme_set = symbols, phonemes
        if 'characters' in config:
            # A custom character set in the config overrides the defaults.
            char_set, phoneme_set = make_symbols(**config.characters)

        num_chars = len(phoneme_set) if config.use_phonemes else len(char_set)
        model = setup_model(num_chars, 0, config)
        output_path = get_tests_output_path()
        # No optimizer / loss state is saved; the trailing 10, 10 are passed
        # positionally (presumably step and epoch in TTS.utils.io's
        # save_checkpoint signature — TODO confirm the argument order).
        save_checkpoint(model, None, None, None, output_path, 10, 10)

    def test_in_out(self):
        """Synthesize a sentence end-to-end from a freshly saved checkpoint."""
        self._create_random_model()
        config = load_config(os.path.join(get_tests_input_path(), 'server_config.json'))
        # The server config stores checkpoint/config paths relative to the
        # tests output directory; make them absolute before loading.
        tts_root_path = get_tests_output_path()
        config['tts_checkpoint'] = os.path.join(tts_root_path, config['tts_checkpoint'])
        config['tts_config'] = os.path.join(tts_root_path, config['tts_config'])

        synthesizer = Synthesizer(config)
        synthesizer.tts("Better this test works!!")