handle multi speaker and gst in Synthetizer class

pull/441/head
kirianguiller 2021-03-01 15:17:15 +01:00 committed by Eren Gölge
parent cc4efb437b
commit 7dccbfdcd5
5 changed files with 150 additions and 83 deletions

View File

@ -10,6 +10,7 @@ from flask import Flask, render_template, request, send_file
from TTS.utils.io import load_config
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer
from TTS.utils.generic_utils import style_wav_uri_to_dict
def create_argparser():
@ -81,13 +82,19 @@ synthesizer = Synthesizer(
args.tts_checkpoint, args.tts_config, args.vocoder_checkpoint, args.vocoder_config, args.use_cuda
)
use_speaker_embedding = synthesizer.tts_config.get("use_external_speaker_embedding_file", False)
use_gst = synthesizer.tts_config.get("use_gst", False)
app = Flask(__name__)
@app.route("/")
def index():
return render_template("index.html", show_details=args.show_details)
return render_template(
'index.html',
show_details=args.show_details,
use_speaker_embedding=use_speaker_embedding,
use_gst = use_gst
)
@app.route("/details")
def details():
@ -108,9 +115,13 @@ def details():
@app.route("/api/tts", methods=["GET"])
def tts():
text = request.args.get("text")
text = request.args.get('text')
speaker_json_key = request.args.get('speaker', "")
style_wav = request.args.get('style-wav', "")
style_wav = style_wav_uri_to_dict(style_wav)
print(" > Model input: {}".format(text))
wavs = synthesizer.tts(text)
wavs = synthesizer.tts(text, speaker_json_key=speaker_json_key, style_wav=style_wav)
out = io.BytesIO()
synthesizer.save_wav(wavs, out)
return send_file(out, mimetype="audio/wav")

View File

@ -60,6 +60,14 @@
<ul class="list-unstyled">
</ul>
{%if use_speaker_embedding%}
<input id="speaker-json-key" placeholder="speaker json key.." size=45 type="text" name="speaker-json-key">
{%endif%}
{%if use_gst%}
<input value='{"0": 0.1}' id="style-wav" placeholder="style wav (dict or path ot wav).." size=45 type="text" name="style-wav">
{%endif%}
<input id="text" placeholder="Type here..." size=45 type="text" name="text">
<button id="speak-button" name="speak">Speak</button><br/><br/>
{%if show_details%}
@ -73,15 +81,24 @@
<!-- Bootstrap core JavaScript -->
<script>
function getTextValue(textId) {
const container = q(textId)
if (container) {
return container.value
}
return ""
}
function q(selector) {return document.querySelector(selector)}
q('#text').focus()
function do_tts(e) {
text = q('#text').value
const text = q('#text').value
const speakerJsonKey = getTextValue('#speaker-json-key')
const styleWav = getTextValue('#style-wav')
if (text) {
q('#message').textContent = 'Synthesizing...'
q('#speak-button').disabled = true
q('#audio').hidden = true
synthesize(text)
synthesize(text, speakerJsonKey, styleWav)
}
e.preventDefault()
return false
@ -92,8 +109,8 @@
do_tts(e)
}
})
function synthesize(text) {
fetch('/api/tts?text=' + encodeURIComponent(text), {cache: 'no-cache'})
function synthesize(text, speakerJsonKey="", styleWav="") {
fetch(`/api/tts?text=${encodeURIComponent(text)}&speaker=${encodeURIComponent(speakerJsonKey)}&style-wav=${encodeURIComponent(styleWav)}` , {cache: 'no-cache'})
.then(function(res) {
if (!res.ok) throw Error(res.statusText)
return res.blob()

View File

@ -65,30 +65,27 @@ def compute_style_mel(style_wav, ap, cuda=False):
def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None):
if "tacotron" in CONFIG.model.lower():
if CONFIG.use_gst:
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
)
speaker_embedding_g = speaker_id if speaker_id is not None else speaker_embeddings
if 'tacotron' in CONFIG.model.lower():
if not CONFIG.use_gst:
style_mel = None
if truncated:
decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
else:
if truncated:
decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
)
else:
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
)
elif "glow" in CONFIG.model.lower():
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
elif 'glow' in CONFIG.model.lower():
inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) # pylint: disable=not-callable
if hasattr(model, "module"):
# distributed model
postnet_output, _, _, _, alignments, _, _ = model.module.inference(
inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
inputs, inputs_lengths, g=speaker_embedding_g
)
else:
postnet_output, _, _, _, alignments, _, _ = model.inference(
inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
inputs, inputs_lengths, g=speaker_embedding_g
)
postnet_output = postnet_output.permute(0, 2, 1)
# these only belong to tacotron models.
@ -99,11 +96,11 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel
if hasattr(model, "module"):
# distributed model
postnet_output, alignments = model.module.inference(
inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
inputs, inputs_lengths, g=speaker_embedding_g
)
else:
postnet_output, alignments = model.inference(
inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
inputs, inputs_lengths, g=speaker_embedding_g
)
postnet_output = postnet_output.permute(0, 2, 1)
# these only belong to tacotron models.

View File

@ -1,10 +1,12 @@
import datetime
import glob
import json
import os
import shutil
import subprocess
import sys
from pathlib import Path
from typing import Union
def get_git_branch():
@ -160,6 +162,21 @@ def check_argument(
is_valid = True
assert is_valid or c[name] is None, f" [!] {name} has wrong type - {type(c[name])} vs {val_type}"
elif val_type:
assert (
isinstance(c[name], val_type) or c[name] is None
), f" [!] {name} has wrong type - {type(c[name])} vs {val_type}"
assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}'
def style_wav_uri_to_dict(style_wav: str) -> Union[str, dict]:
"""Transform an uri style_wav, in either a string (path to wav file to be use for style transfer)
or a dict (gst tokens/values to be use for styling)
Args:
style_wav (str): uri
Returns:
Union[str, dict]: path to file (str) or gst style (dict)
"""
if os.path.isfile(style_wav) and style_wav.endswith(".wav"):
return style_wav # style_wav is a .wav file located on the server
style_wav = json.loads(style_wav)
return style_wav # style_wav is a gst dictionary with {token1_id : token1_weigth, ...}

View File

@ -1,4 +1,5 @@
import time
from typing import List
import numpy as np
import pysbd
@ -17,7 +18,14 @@ from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input, setup_gen
class Synthesizer(object):
def __init__(self, tts_checkpoint, tts_config, vocoder_checkpoint=None, vocoder_config=None, use_cuda=False):
def __init__(
self,
tts_checkpoint: str,
tts_config_path: str,
vocoder_checkpoint: str = "",
vocoder_config: str = "",
use_cuda: bool = False,
) -> None:
"""General 🐸 TTS interface for inference. It takes a tts and a vocoder
model and synthesize speech from the provided text.
@ -27,27 +35,25 @@ class Synthesizer(object):
If you have certain special characters in your text, you need to handle
them before providing the text to Synthesizer.
TODO: handle multi-speaker and GST inference.
Args:
tts_checkpoint (str): path to the tts model file.
tts_config (str): path to the tts config file.
tts_config_path (str): path to the tts config file.
vocoder_checkpoint (str, optional): path to the vocoder model file. Defaults to None.
vocoder_config (str, optional): path to the vocoder config file. Defaults to None.
use_cuda (bool, optional): enable/disable cuda. Defaults to False.
"""
self.tts_checkpoint = tts_checkpoint
self.tts_config = tts_config
self.tts_config_path = tts_config_path
self.vocoder_checkpoint = vocoder_checkpoint
self.vocoder_config = vocoder_config
self.use_cuda = use_cuda
self.wavernn = None
self.vocoder_model = None
self.num_speakers = 0
self.tts_speakers = None
self.speaker_embedding_dim = None
self.seg = self.get_segmenter("en")
self.tts_speakers = {}
self.speaker_embedding_dim = 0
self.seg = self._get_segmenter("en")
self.use_cuda = use_cuda
if self.use_cuda:
assert torch.cuda.is_available(), "CUDA is not availabe on this machine."
self.load_tts(tts_checkpoint, tts_config, use_cuda)
@ -57,38 +63,40 @@ class Synthesizer(object):
self.output_sample_rate = self.vocoder_config.audio["sample_rate"]
@staticmethod
def get_segmenter(lang):
def _get_segmenter(lang: str):
return pysbd.Segmenter(language=lang, clean=True)
def load_speakers(self):
# load speakers
if self.model_config.use_speaker_embedding is not None:
self.tts_speakers = load_speaker_mapping(self.tts_config.tts_speakers_json)
self.num_speakers = len(self.tts_speakers)
else:
self.num_speakers = 0
# set external speaker embedding
if self.tts_config.use_external_speaker_embedding_file:
speaker_embedding = self.tts_speakers[list(self.tts_speakers.keys())[0]]["embedding"]
self.speaker_embedding_dim = len(speaker_embedding)
def init_speaker(self, speaker_idx):
# load speakers
def _load_speakers(self) -> None:
print("Loading speakers ...")
self.tts_speakers = load_speaker_mapping(self.tts_config.external_speaker_embedding_file)
self.num_speakers = len(self.tts_speakers)
self.speaker_embedding_dim = len(self.tts_speakers[list(self.tts_speakers.keys())[0]][
"embedding"
])
def _load_speaker_embedding(self, speaker_json_key: str = ""):
speaker_embedding = None
if hasattr(self, "tts_speakers") and speaker_idx is not None:
assert speaker_idx < len(
self.tts_speakers
), f" [!] speaker_idx is out of the range. {speaker_idx} vs {len(self.tts_speakers)}"
if self.tts_config.use_external_speaker_embedding_file:
speaker_embedding = self.tts_speakers[speaker_idx]["embedding"]
if self.tts_config.get("use_external_speaker_embedding_file") and not speaker_json_key:
raise ValueError("While 'use_external_speaker_embedding_file', you must pass a 'speaker_json_key'")
if speaker_json_key != "":
assert self.tts_speakers
assert speaker_json_key in self.tts_speakers, f"speaker_json_key is not in self.tts_speakers keys : '{speaker_idx}'"
speaker_embedding = self.tts_speakers[speaker_json_key]["embedding"]
return speaker_embedding
def load_tts(self, tts_checkpoint, tts_config, use_cuda):
def _load_tts(
self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool
) -> None:
# pylint: disable=global-statement
global symbols, phonemes
self.tts_config = load_config(tts_config)
self.tts_config = load_config(tts_config_path)
self.use_phonemes = self.tts_config.use_phonemes
self.ap = AudioProcessor(verbose=False, **self.tts_config.audio)
@ -100,12 +108,22 @@ class Synthesizer(object):
else:
self.input_size = len(symbols)
self.tts_model = setup_model(self.input_size, num_speakers=self.num_speakers, c=self.tts_config)
self.tts_model.load_checkpoint(tts_config, tts_checkpoint, eval=True)
if self.tts_config.use_speaker_embedding is True:
self._load_speakers()
self.tts_model = setup_model(
self.input_size,
num_speakers=self.num_speakers,
c=self.tts_config,
speaker_embedding_dim=self.speaker_embedding_dim)
self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
if use_cuda:
self.tts_model.cuda()
def load_vocoder(self, model_file, model_config, use_cuda):
def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None:
self.vocoder_config = load_config(model_config)
self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config["audio"])
self.vocoder_model = setup_generator(self.vocoder_config)
@ -113,36 +131,36 @@ class Synthesizer(object):
if use_cuda:
self.vocoder_model.cuda()
def save_wav(self, wav, path):
def _split_into_sentences(self, text) -> List[str]:
return self.seg.segment(text)
def save_wav(self, wav: List[int], path: str) -> None:
wav = np.array(wav)
self.ap.save_wav(wav, path, self.output_sample_rate)
def split_into_sentences(self, text):
return self.seg.segment(text)
def tts(self, text, speaker_idx=None):
def tts(self, text: str, speaker_json_key: str = "", style_wav = None) -> List[int]:
start_time = time.time()
wavs = []
sens = self.split_into_sentences(text)
sens = self._split_into_sentences(text)
print(" > Text splitted to sentences.")
print(sens)
speaker_embedding = self._load_speaker_embedding(speaker_json_key)
speaker_embedding = self.init_speaker(speaker_idx)
use_gl = self.vocoder_model is None
for sen in sens:
# synthesize voice
waveform, _, _, mel_postnet_spec, _, _ = synthesis(
self.tts_model,
sen,
self.tts_config,
self.use_cuda,
self.ap,
speaker_idx,
None,
False,
self.tts_config.enable_eos_bos_chars,
use_gl,
model=self.tts_model,
text=sen,
CONFIG=self.tts_config,
use_cuda=self.use_cuda,
ap=self.ap,
speaker_id=None,
style_wav=style_wav,
truncated=False,
enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars,
use_griffin_lim=use_gl,
speaker_embedding=speaker_embedding,
)
if not use_gl:
@ -152,12 +170,19 @@ class Synthesizer(object):
# renormalize spectrogram based on vocoder config
vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
# compute scale factor for possible sample rate mismatch
scale_factor = [1, self.vocoder_config["audio"]["sample_rate"] / self.ap.sample_rate]
scale_factor = [
1,
self.vocoder_config["audio"]["sample_rate"] / self.ap.sample_rate,
]
if scale_factor[1] != 1:
print(" > interpolating tts model output.")
vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
vocoder_input = interpolate_vocoder_input(
scale_factor, vocoder_input
)
else:
vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable
vocoder_input = torch.tensor(vocoder_input).unsqueeze(
0
) # pylint: disable=not-callable
# run vocoder model
# [1, T, C]
waveform = self.vocoder_model.inference(vocoder_input.to(device_type))