update server.py

pull/441/head
Eren Gölge 2021-04-22 12:36:46 +02:00
parent 32e6afc009
commit 10c988ac8c
1 changed files with 66 additions and 24 deletions

View File

@ -1,13 +1,15 @@
#!flask/bin/python #!flask/bin/python
from typing import Union
import argparse import argparse
import io import io
import os import os
import sys import sys
from pathlib import Path from pathlib import Path
import json
from flask import Flask, render_template, request, send_file from flask import Flask, render_template, request, send_file
from TTS.utils.generic_utils import style_wav_uri_to_dict
from TTS.utils.io import load_config from TTS.utils.io import load_config
from TTS.utils.manage import ModelManager from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer from TTS.utils.synthesizer import Synthesizer
@ -29,19 +31,27 @@ def create_argparser():
parser.add_argument( parser.add_argument(
"--model_name", "--model_name",
type=str, type=str,
default="tts_models/en/ljspeech/speedy-speech-wn", default="tts_models/en/ljspeech/tacotron2-DDC",
help="name of one of the released tts models.", help="Name of one of the pre-trained tts models in format <language>/<dataset>/<model_name>",
) )
parser.add_argument("--vocoder_name", type=str, default=None, help="name of one of the released vocoder models.") parser.add_argument("--vocoder_name", type=str, default=None, help="name of one of the released vocoder models.")
parser.add_argument("--tts_checkpoint", type=str, help="path to custom tts checkpoint file")
parser.add_argument("--tts_config", type=str, help="path to custom tts config.json file") # Args for running custom models
parser.add_argument("--config_path", default=None, type=str, help="Path to model config file.")
parser.add_argument( parser.add_argument(
"--tts_speakers", "--model_path",
type=str, type=str,
help="path to JSON file containing speaker ids, if speaker ids are used in the model", default=None,
help="Path to model file.",
) )
parser.add_argument("--vocoder_config", type=str, default=None, help="path to vocoder config file.") parser.add_argument(
parser.add_argument("--vocoder_checkpoint", type=str, default=None, help="path to vocoder checkpoint file.") "--vocoder_path",
type=str,
help="Path to vocoder model file. If it is not defined, model uses GL as vocoder. Please make sure that you installed vocoder library before (WaveRNN).",
default=None,
)
parser.add_argument("--vocoder_config_path", type=str, help="Path to vocoder model config file.", default=None)
parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
parser.add_argument("--port", type=int, default=5002, help="port to listen on.") parser.add_argument("--port", type=int, default=5002, help="port to listen on.")
parser.add_argument("--use_cuda", type=convert_boolean, default=False, help="true to use CUDA.") parser.add_argument("--use_cuda", type=convert_boolean, default=False, help="true to use CUDA.")
parser.add_argument("--debug", type=convert_boolean, default=False, help="true to enable Flask debug mode.") parser.add_argument("--debug", type=convert_boolean, default=False, help="true to enable Flask debug mode.")
@ -60,26 +70,38 @@ if args.list_models:
sys.exit() sys.exit()
# update in-use models to the specified released models. # update in-use models to the specified released models.
if args.model_name is not None: model_path = None
tts_checkpoint_file, tts_config_file, tts_json_dict = manager.download_model(args.model_name) config_path = None
args.vocoder_name = tts_json_dict["default_vocoder"] if args.vocoder_name is None else args.vocoder_name speakers_file_path = None
vocoder_path = None
vocoder_config_path = None
if args.vocoder_name is not None: # CASE1: list pre-trained TTS models
vocoder_checkpoint_file, vocoder_config_file, vocoder_json_dict = manager.download_model(args.vocoder_name) if args.list_models:
manager.list_models()
sys.exit()
# If these were not specified in the CLI args, use default values with embedded model files # CASE2: load pre-trained model paths
if not args.tts_checkpoint and os.path.isfile(tts_checkpoint_file): if args.model_name is not None and not args.model_path:
args.tts_checkpoint = tts_checkpoint_file model_path, config_path, model_item = manager.download_model(args.model_name)
if not args.tts_config and os.path.isfile(tts_config_file): args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
args.tts_config = tts_config_file
if not args.vocoder_checkpoint and os.path.isfile(vocoder_checkpoint_file): if args.vocoder_name is not None and not args.vocoder_path:
args.vocoder_checkpoint = vocoder_checkpoint_file vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)
if not args.vocoder_config and os.path.isfile(vocoder_config_file):
args.vocoder_config = vocoder_config_file
# CASE3: set custome model paths
if args.model_path is not None:
model_path = args.model_path
config_path = args.config_path
speakers_file_path = args.speakers_file_path
if args.vocoder_path is not None:
vocoder_path = args.vocoder_path
vocoder_config_path = args.vocoder_config_path
# load models
synthesizer = Synthesizer( synthesizer = Synthesizer(
args.tts_checkpoint, args.tts_config, args.vocoder_checkpoint, args.vocoder_config, args.use_cuda model_path, config_path, speakers_file_path, vocoder_path, vocoder_config_path, args.use_cuda
) )
use_speaker_embedding = synthesizer.tts_config.get("use_external_speaker_embedding_file", False) use_speaker_embedding = synthesizer.tts_config.get("use_external_speaker_embedding_file", False)
@ -87,6 +109,26 @@ use_gst = synthesizer.tts_config.get("use_gst", False)
app = Flask(__name__) app = Flask(__name__)
def style_wav_uri_to_dict(style_wav: str) -> Union[str, dict]:
"""Transform an uri style_wav, in either a string (path to wav file to be use for style transfer)
or a dict (gst tokens/values to be use for styling)
Args:
style_wav (str): uri
Returns:
Union[str, dict]: path to file (str) or gst style (dict)
"""
if style_wav:
if os.path.isfile(style_wav) and style_wav.endswith(".wav"):
return style_wav # style_wav is a .wav file located on the server
style_wav = json.loads(style_wav)
return style_wav # style_wav is a gst dictionary with {token1_id : token1_weigth, ...}
else:
return None
@app.route("/") @app.route("/")
def index(): def index():
return render_template( return render_template(