mirror of https://github.com/coqui-ai/TTS.git
handle multi-speaker and GST in Synthesizer class
parent: cc4efb437b
commit: 7dccbfdcd5
@@ -10,6 +10,7 @@ from flask import Flask, render_template, request, send_file
 from TTS.utils.io import load_config
 from TTS.utils.manage import ModelManager
 from TTS.utils.synthesizer import Synthesizer
+from TTS.utils.generic_utils import style_wav_uri_to_dict


 def create_argparser():
@@ -81,13 +82,19 @@ synthesizer = Synthesizer(
     args.tts_checkpoint, args.tts_config, args.vocoder_checkpoint, args.vocoder_config, args.use_cuda
 )

+use_speaker_embedding = synthesizer.tts_config.get("use_external_speaker_embedding_file", False)
+use_gst = synthesizer.tts_config.get("use_gst", False)
 app = Flask(__name__)


 @app.route("/")
 def index():
-    return render_template("index.html", show_details=args.show_details)
-
+    return render_template(
+        'index.html',
+        show_details=args.show_details,
+        use_speaker_embedding=use_speaker_embedding,
+        use_gst = use_gst
+    )

 @app.route("/details")
 def details():
@@ -108,9 +115,13 @@ def details():

 @app.route("/api/tts", methods=["GET"])
 def tts():
-    text = request.args.get("text")
+    text = request.args.get('text')
+    speaker_json_key = request.args.get('speaker', "")
+    style_wav = request.args.get('style-wav', "")
+
+    style_wav = style_wav_uri_to_dict(style_wav)
     print(" > Model input: {}".format(text))
-    wavs = synthesizer.tts(text)
+    wavs = synthesizer.tts(text, speaker_json_key=speaker_json_key, style_wav=style_wav)
     out = io.BytesIO()
     synthesizer.save_wav(wavs, out)
     return send_file(out, mimetype="audio/wav")
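Below is a minimal client sketch for the extended endpoint (not part of the diff). The host/port and the speaker key are placeholders; only the query parameter names (text, speaker, style-wav) come from the route above.

import requests

params = {
    "text": "Hello world.",
    "speaker": "p225",          # key into the external speaker embedding file (placeholder)
    "style-wav": '{"0": 0.1}',  # GST token weights as JSON, or a server-side path to a .wav
}
# Port 5002 is only an assumption about how the demo server is launched.
response = requests.get("http://localhost:5002/api/tts", params=params)
response.raise_for_status()
with open("out.wav", "wb") as f:
    f.write(response.content)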
@@ -60,6 +60,14 @@

     <ul class="list-unstyled">
     </ul>
+    {%if use_speaker_embedding%}
+    <input id="speaker-json-key" placeholder="speaker json key.." size=45 type="text" name="speaker-json-key">
+    {%endif%}
+
+    {%if use_gst%}
+    <input value='{"0": 0.1}' id="style-wav" placeholder="style wav (dict or path ot wav).." size=45 type="text" name="style-wav">
+    {%endif%}
+
     <input id="text" placeholder="Type here..." size=45 type="text" name="text">
     <button id="speak-button" name="speak">Speak</button><br/><br/>
     {%if show_details%}
@@ -73,15 +81,24 @@

    <!-- Bootstrap core JavaScript -->
    <script>
+      function getTextValue(textId) {
+        const container = q(textId)
+        if (container) {
+          return container.value
+        }
+        return ""
+      }
       function q(selector) {return document.querySelector(selector)}
       q('#text').focus()
       function do_tts(e) {
-          text = q('#text').value
+          const text = q('#text').value
+          const speakerJsonKey = getTextValue('#speaker-json-key')
+          const styleWav = getTextValue('#style-wav')
           if (text) {
              q('#message').textContent = 'Synthesizing...'
              q('#speak-button').disabled = true
              q('#audio').hidden = true
-             synthesize(text)
+             synthesize(text, speakerJsonKey, styleWav)
          }
          e.preventDefault()
          return false
@@ -92,8 +109,8 @@
              do_tts(e)
          }
      })
-      function synthesize(text) {
-          fetch('/api/tts?text=' + encodeURIComponent(text), {cache: 'no-cache'})
+      function synthesize(text, speakerJsonKey="", styleWav="") {
+          fetch(`/api/tts?text=${encodeURIComponent(text)}&speaker=${encodeURIComponent(speakerJsonKey)}&style-wav=${encodeURIComponent(styleWav)}` , {cache: 'no-cache'})
              .then(function(res) {
                  if (!res.ok) throw Error(res.statusText)
                  return res.blob()
@@ -65,30 +65,27 @@ def compute_style_mel(style_wav, ap, cuda=False):


 def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None):
-    if "tacotron" in CONFIG.model.lower():
-        if CONFIG.use_gst:
-            decoder_output, postnet_output, alignments, stop_tokens = model.inference(
-                inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
-            )
+    speaker_embedding_g = speaker_id if speaker_id is not None else speaker_embeddings
+    if 'tacotron' in CONFIG.model.lower():
+        if not CONFIG.use_gst:
+            style_mel = None
+
+        if truncated:
+            decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
+                inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
         else:
-            if truncated:
-                decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
-                    inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
-                )
-            else:
-                decoder_output, postnet_output, alignments, stop_tokens = model.inference(
-                    inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
-                )
-    elif "glow" in CONFIG.model.lower():
+            decoder_output, postnet_output, alignments, stop_tokens = model.inference(
+                inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
+    elif 'glow' in CONFIG.model.lower():
         inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
         if hasattr(model, "module"):
             # distributed model
             postnet_output, _, _, _, alignments, _, _ = model.module.inference(
-                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
+                inputs, inputs_lengths, g=speaker_embedding_g
             )
         else:
             postnet_output, _, _, _, alignments, _, _ = model.inference(
-                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
+                inputs, inputs_lengths, g=speaker_embedding_g
             )
         postnet_output = postnet_output.permute(0, 2, 1)
         # these only belong to tacotron models.
@@ -99,11 +96,11 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel
         if hasattr(model, "module"):
             # distributed model
             postnet_output, alignments = model.module.inference(
-                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
+                inputs, inputs_lengths, g=speaker_embedding_g
             )
         else:
             postnet_output, alignments = model.inference(
-                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
+                inputs, inputs_lengths, g=speaker_embedding_g
             )
         postnet_output = postnet_output.permute(0, 2, 1)
         # these only belong to tacotron models.
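The refactor above boils down to two small rules; the stand-alone sketch below restates them with invented names (resolve_conditioning and gate_style_mel are not functions in the repo), so it can be read or run in isolation.

def resolve_conditioning(speaker_id=None, speaker_embeddings=None):
    # mirrors: speaker_embedding_g = speaker_id if speaker_id is not None else speaker_embeddings
    return speaker_id if speaker_id is not None else speaker_embeddings

def gate_style_mel(style_mel, use_gst):
    # mirrors: if not CONFIG.use_gst: style_mel = None
    return style_mel if use_gst else None

assert resolve_conditioning(3, None) == 3
assert resolve_conditioning(None, [0.1, 0.2]) == [0.1, 0.2]
assert gate_style_mel("style-mel", use_gst=False) is None
assert gate_style_mel("style-mel", use_gst=True) == "style-mel"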
@@ -1,10 +1,12 @@
 import datetime
 import glob
+import json
 import os
 import shutil
 import subprocess
 import sys
 from pathlib import Path
+from typing import Union


 def get_git_branch():
@@ -160,6 +162,21 @@ def check_argument(
             is_valid = True
         assert is_valid or c[name] is None, f" [!] {name} has wrong type - {type(c[name])} vs {val_type}"
     elif val_type:
-        assert (
-            isinstance(c[name], val_type) or c[name] is None
-        ), f" [!] {name} has wrong type - {type(c[name])} vs {val_type}"
+        assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}'
+
+
+def style_wav_uri_to_dict(style_wav: str) -> Union[str, dict]:
+    """Transform an uri style_wav, in either a string (path to wav file to be use for style transfer)
+    or a dict (gst tokens/values to be use for styling)
+
+    Args:
+        style_wav (str): uri
+
+    Returns:
+        Union[str, dict]: path to file (str) or gst style (dict)
+    """
+    if os.path.isfile(style_wav) and style_wav.endswith(".wav"):
+        return style_wav  # style_wav is a .wav file located on the server
+
+    style_wav = json.loads(style_wav)
+    return style_wav  # style_wav is a gst dictionary with {token1_id : token1_weigth, ...}
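A short usage sketch of the new helper (not part of the diff); the wav path mentioned in the comments is hypothetical and only illustrates the first branch.

from TTS.utils.generic_utils import style_wav_uri_to_dict

# A GST token/weight dictionary passed as a JSON string is decoded:
style = style_wav_uri_to_dict('{"0": 0.15, "3": -0.1}')
print(style)  # {'0': 0.15, '3': -0.1}

# A string that points at an existing .wav file on the server is returned unchanged,
# e.g. style_wav_uri_to_dict("/data/styles/happy.wav") -> "/data/styles/happy.wav"
# ("/data/styles/happy.wav" is a made-up path; the file must exist and end in ".wav").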
@@ -1,4 +1,5 @@
 import time
+from typing import List

 import numpy as np
 import pysbd
@@ -17,7 +18,14 @@ from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input, setup_gen


 class Synthesizer(object):
-    def __init__(self, tts_checkpoint, tts_config, vocoder_checkpoint=None, vocoder_config=None, use_cuda=False):
+    def __init__(
+        self,
+        tts_checkpoint: str,
+        tts_config_path: str,
+        vocoder_checkpoint: str = "",
+        vocoder_config: str = "",
+        use_cuda: bool = False,
+    ) -> None:
         """General 🐸 TTS interface for inference. It takes a tts and a vocoder
         model and synthesize speech from the provided text.

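A construction sketch for the new keyword-annotated signature (not part of the diff); every path below is a placeholder.

from TTS.utils.synthesizer import Synthesizer

synthesizer = Synthesizer(
    tts_checkpoint="path/to/tts_model.pth.tar",
    tts_config_path="path/to/tts_config.json",
    vocoder_checkpoint="path/to/vocoder_model.pth.tar",
    vocoder_config="path/to/vocoder_config.json",
    use_cuda=False,
)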
@@ -27,27 +35,25 @@ class Synthesizer(object):
         If you have certain special characters in your text, you need to handle
         them before providing the text to Synthesizer.

-        TODO: handle multi-speaker and GST inference.

         Args:
             tts_checkpoint (str): path to the tts model file.
-            tts_config (str): path to the tts config file.
+            tts_config_path (str): path to the tts config file.
             vocoder_checkpoint (str, optional): path to the vocoder model file. Defaults to None.
             vocoder_config (str, optional): path to the vocoder config file. Defaults to None.
             use_cuda (bool, optional): enable/disable cuda. Defaults to False.
         """
         self.tts_checkpoint = tts_checkpoint
-        self.tts_config = tts_config
+        self.tts_config_path = tts_config_path
         self.vocoder_checkpoint = vocoder_checkpoint
         self.vocoder_config = vocoder_config
-        self.use_cuda = use_cuda
-        self.wavernn = None
         self.vocoder_model = None
         self.num_speakers = 0
-        self.tts_speakers = None
-        self.speaker_embedding_dim = None
-        self.seg = self.get_segmenter("en")
+        self.tts_speakers = {}
+        self.speaker_embedding_dim = 0
+        self.seg = self._get_segmenter("en")
+        self.use_cuda = use_cuda

         if self.use_cuda:
             assert torch.cuda.is_available(), "CUDA is not availabe on this machine."
         self.load_tts(tts_checkpoint, tts_config, use_cuda)
@@ -57,38 +63,40 @@ class Synthesizer(object):
         self.output_sample_rate = self.vocoder_config.audio["sample_rate"]

     @staticmethod
-    def get_segmenter(lang):
+    def _get_segmenter(lang: str):
         return pysbd.Segmenter(language=lang, clean=True)

-    def load_speakers(self):
-        # load speakers
-        if self.model_config.use_speaker_embedding is not None:
-            self.tts_speakers = load_speaker_mapping(self.tts_config.tts_speakers_json)
-            self.num_speakers = len(self.tts_speakers)
-        else:
-            self.num_speakers = 0
-        # set external speaker embedding
-        if self.tts_config.use_external_speaker_embedding_file:
-            speaker_embedding = self.tts_speakers[list(self.tts_speakers.keys())[0]]["embedding"]
-            self.speaker_embedding_dim = len(speaker_embedding)

-    def init_speaker(self, speaker_idx):
-        # load speakers
+    def _load_speakers(self) -> None:
+        print("Loading speakers ...")
+        self.tts_speakers = load_speaker_mapping(self.tts_config.external_speaker_embedding_file)
+        self.num_speakers = len(self.tts_speakers)
+        self.speaker_embedding_dim = len(self.tts_speakers[list(self.tts_speakers.keys())[0]][
+            "embedding"
+        ])
+
+    def _load_speaker_embedding(self, speaker_json_key: str = ""):
+
         speaker_embedding = None
-        if hasattr(self, "tts_speakers") and speaker_idx is not None:
-            assert speaker_idx < len(
-                self.tts_speakers
-            ), f" [!] speaker_idx is out of the range. {speaker_idx} vs {len(self.tts_speakers)}"
-            if self.tts_config.use_external_speaker_embedding_file:
-                speaker_embedding = self.tts_speakers[speaker_idx]["embedding"]
+
+        if self.tts_config.get("use_external_speaker_embedding_file") and not speaker_json_key:
+            raise ValueError("While 'use_external_speaker_embedding_file', you must pass a 'speaker_json_key'")
+
+        if speaker_json_key != "":
+            assert self.tts_speakers
+            assert speaker_json_key in self.tts_speakers, f"speaker_json_key is not in self.tts_speakers keys : '{speaker_idx}'"
+            speaker_embedding = self.tts_speakers[speaker_json_key]["embedding"]
+
         return speaker_embedding

-    def load_tts(self, tts_checkpoint, tts_config, use_cuda):
+    def _load_tts(
+        self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool
+    ) -> None:
         # pylint: disable=global-statement

         global symbols, phonemes

-        self.tts_config = load_config(tts_config)
+        self.tts_config = load_config(tts_config_path)
         self.use_phonemes = self.tts_config.use_phonemes
         self.ap = AudioProcessor(verbose=False, **self.tts_config.audio)

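For reference, the speaker mapping these helpers index into appears to have the shape sketched below; it is inferred from self.tts_speakers[key]["embedding"] above, and the keys and embedding values are invented.

speakers_json_example = {
    "p225": {"embedding": [0.012, -0.034, 0.551]},  # embedding truncated; real vectors are much longer
    "p226": {"embedding": [0.101, 0.007, -0.220]},
}
# speaker_json_key passed to _load_speaker_embedding() (and to tts()) should be one of these keys.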
@@ -100,12 +108,22 @@ class Synthesizer(object):
         else:
             self.input_size = len(symbols)

-        self.tts_model = setup_model(self.input_size, num_speakers=self.num_speakers, c=self.tts_config)
-        self.tts_model.load_checkpoint(tts_config, tts_checkpoint, eval=True)
+        if self.tts_config.use_speaker_embedding is True:
+            self._load_speakers()
+
+        self.tts_model = setup_model(
+            self.input_size,
+            num_speakers=self.num_speakers,
+            c=self.tts_config,
+            speaker_embedding_dim=self.speaker_embedding_dim)
+
+        self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
         if use_cuda:
             self.tts_model.cuda()

-    def load_vocoder(self, model_file, model_config, use_cuda):
+
+
+    def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None:
         self.vocoder_config = load_config(model_config)
         self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config["audio"])
         self.vocoder_model = setup_generator(self.vocoder_config)
@@ -113,36 +131,36 @@ class Synthesizer(object):
         if use_cuda:
             self.vocoder_model.cuda()

-    def save_wav(self, wav, path):
+    def _split_into_sentences(self, text) -> List[str]:
+        return self.seg.segment(text)
+
+    def save_wav(self, wav: List[int], path: str) -> None:
         wav = np.array(wav)
         self.ap.save_wav(wav, path, self.output_sample_rate)

-    def split_into_sentences(self, text):
-        return self.seg.segment(text)
-
-    def tts(self, text, speaker_idx=None):
+    def tts(self, text: str, speaker_json_key: str = "", style_wav = None) -> List[int]:
         start_time = time.time()
         wavs = []
-        sens = self.split_into_sentences(text)
+        sens = self._split_into_sentences(text)
         print(" > Text splitted to sentences.")
         print(sens)
+        speaker_embedding = self._load_speaker_embedding(speaker_json_key)

-        speaker_embedding = self.init_speaker(speaker_idx)
         use_gl = self.vocoder_model is None

         for sen in sens:
             # synthesize voice
             waveform, _, _, mel_postnet_spec, _, _ = synthesis(
-                self.tts_model,
-                sen,
-                self.tts_config,
-                self.use_cuda,
-                self.ap,
-                speaker_idx,
-                None,
-                False,
-                self.tts_config.enable_eos_bos_chars,
-                use_gl,
+                model=self.tts_model,
+                text=sen,
+                CONFIG=self.tts_config,
+                use_cuda=self.use_cuda,
+                ap=self.ap,
+                speaker_id=None,
+                style_wav=style_wav,
+                truncated=False,
+                enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars,
+                use_griffin_lim=use_gl,
                 speaker_embedding=speaker_embedding,
             )
             if not use_gl:
@@ -152,12 +170,19 @@ class Synthesizer(object):
                 # renormalize spectrogram based on vocoder config
                 vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
                 # compute scale factor for possible sample rate mismatch
-                scale_factor = [1, self.vocoder_config["audio"]["sample_rate"] / self.ap.sample_rate]
+                scale_factor = [
+                    1,
+                    self.vocoder_config["audio"]["sample_rate"] / self.ap.sample_rate,
+                ]
                 if scale_factor[1] != 1:
                     print(" > interpolating tts model output.")
-                    vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
+                    vocoder_input = interpolate_vocoder_input(
+                        scale_factor, vocoder_input
+                    )
                 else:
-                    vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)  # pylint: disable=not-callable
+                    vocoder_input = torch.tensor(vocoder_input).unsqueeze(
+                        0
+                    )  # pylint: disable=not-callable
                 # run vocoder model
                 # [1, T, C]
                 waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
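Taken together, the new call path can be exercised roughly as sketched below (not part of the diff). Paths and the speaker key are placeholders, and style_wav may be either a token-weight dict or a path to a reference wav.

from TTS.utils.synthesizer import Synthesizer

synth = Synthesizer(
    tts_checkpoint="path/to/tts_model.pth.tar",
    tts_config_path="path/to/tts_config.json",
    vocoder_checkpoint="path/to/vocoder_model.pth.tar",
    vocoder_config="path/to/vocoder_config.json",
)
wav = synth.tts(
    "This is a multi-speaker sentence with a transferred style.",
    speaker_json_key="p225",   # key into the external speaker embedding file
    style_wav={"0": 0.2},      # or "path/to/reference.wav"
)
synth.save_wav(wav, "output.wav")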