2018-06-05 12:15:48 +00:00
import io
import os
2019-03-11 20:56:40 +00:00
2018-06-05 12:15:48 +00:00
import numpy as np
2019-03-11 20:56:40 +00:00
import torch
2019-04-15 14:13:33 +00:00
import sys
from utils.audio import AudioProcessor
from utils.generic_utils import load_config, setup_model
from utils.text import phoneme_to_sequence, phonemes, symbols, text_to_sequence, sequence_to_phoneme
2018-06-05 12:15:48 +00:00
2019-04-18 15:35:20 +00:00
import re
2019-07-19 06:46:23 +00:00
# Regex fragments combined by Synthesizer.split_into_sentences to decide which
# dots end a sentence and which belong to an abbreviation.
# One ASCII letter, captured.
alphabets = r"([A-Za-z])"
# Title abbreviations immediately followed by a literal dot.
prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]"
# Name/company suffixes; the dot is appended where the pattern is used.
suffixes = r"(Inc|Ltd|Jr|Sr|Co)"
# Words that, after an acronym or suffix, mark the start of a new sentence.
starters = r"(Mr|Mrs|Ms|Dr|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)"
# Two or three capital letters, each followed by a dot (e.g. "U.S." / "U.S.A.").
acronyms = r"([A-Z][.][A-Z][.](?:[A-Z][.])?)"
# A dot followed by one of a fixed set of domain endings.
websites = r"[.](com|net|org|io|gov)"
2018-06-05 12:15:48 +00:00
class Synthesizer(object):
2019-04-15 14:13:33 +00:00
def __init__(self, config):
    """Store the server configuration and load the configured models.

    Args:
        config: object exposing ``use_cuda``, ``tts_path``, ``tts_file``,
            ``tts_config`` and ``wavernn_lib_path``; when the latter is
            truthy, ``wavernn_path``, ``wavernn_file`` and
            ``wavernn_config`` are read as well.

    Raises:
        AssertionError: if ``config.use_cuda`` is truthy but
            ``torch.cuda.is_available()`` reports no CUDA device.
    """
    self.config = config
    self.use_cuda = config.use_cuda
    # Stays None unless load_wavernn() below installs a vocoder.
    self.wavernn = None

    if self.use_cuda:
        assert torch.cuda.is_available(), "CUDA is not availabe on this machine."

    self.load_tts(
        config.tts_path,
        config.tts_file,
        config.tts_config,
        config.use_cuda,
    )

    # The WaveRNN vocoder is optional and only loaded when a library path is set.
    if config.wavernn_lib_path:
        self.load_wavernn(
            config.wavernn_lib_path,
            config.wavernn_path,
            config.wavernn_file,
            config.wavernn_config,
            config.use_cuda,
        )
2019-03-11 20:56:40 +00:00
2019-04-15 14:13:33 +00:00
# Load the TTS model: read its config, build the AudioProcessor, choose the
# text-to-id adapter, construct the network with setup_model() and read the
# checkpoint file with torch.load().
#
# Arguments:
#   model_path   -- directory joined in front of model_file and model_config
#   model_file   -- checkpoint file name
#   model_config -- config file name
#   use_cuda     -- selects which torch.load() call is used for the checkpoint
#
# Sets on self: model_file, tts_config, use_phonemes, ap, input_size,
# input_adapter, tts_model.
#
# NOTE(review): the bare timestamp lines are not Python and the indentation is
# gone; this listing also appears to have lost some source lines (see the
# notes below). Confirm against the original file before relying on it.
def load_tts(self, model_path, model_file, model_config, use_cuda):
tts_config = os.path.join(model_path, model_config)
self.model_file = os.path.join(model_path, model_file)
print(" > Loading TTS model ...")
print(" | > model config: ", tts_config)
# NOTE(review): this prints the bare `model_file` argument, not the joined
# self.model_file path (load_wavernn prints the joined path) -- confirm intent.
print(" | > model file: ", model_file)
self.tts_config = load_config(tts_config)
self.use_phonemes = self.tts_config.use_phonemes
# The `audio` section of the config is expanded into AudioProcessor keyword arguments.
self.ap = AudioProcessor(**self.tts_config.audio)
2019-03-11 20:56:40 +00:00
# Phoneme input: vocabulary size is len(phonemes) and text goes through
# phoneme_to_sequence().
if self.use_phonemes:
self.input_size = len(phonemes)
2019-04-15 14:13:33 +00:00
self.input_adapter = lambda sen: phoneme_to_sequence(sen, [self.tts_config.text_cleaner], self.tts_config.phoneme_language, self.tts_config.enable_eos_bos_chars)
2019-03-11 20:56:40 +00:00
# Character input: vocabulary size is len(symbols) and text goes through
# text_to_sequence().
# NOTE(review): presumably an `else:` preceded these two assignments; as
# listed they would always overwrite the phoneme settings -- TODO confirm.
self.input_size = len(symbols)
2019-04-15 14:13:33 +00:00
self.input_adapter = lambda sen: text_to_sequence(sen, [self.tts_config.text_cleaner])
2019-07-19 06:46:23 +00:00
self.tts_model = setup_model(self.input_size, c=self.tts_config) #FIXME: missing num_speakers argument to setup_model
2018-06-05 12:15:48 +00:00
# load model state
if use_cuda:
cp = torch.load(self.model_file)
2019-03-11 20:56:40 +00:00
# This variant passes a map_location callable that returns the storage unchanged.
# NOTE(review): presumably the `else:` branch of the `if use_cuda:` above -- TODO confirm.
cp = torch.load(self.model_file, map_location=lambda storage, loc: storage)
2018-06-05 12:15:48 +00:00
# load the model
# NOTE(review): no visible line consumes `cp`, and the `if use_cuda:` below has
# no visible body; the statements that did so seem to be missing here.
2019-04-15 14:13:33 +00:00
2018-06-05 12:15:48 +00:00
if use_cuda:
2019-04-15 14:13:33 +00:00
2019-04-18 15:35:20 +00:00
# Sets the decoder's max_decoder_steps attribute to 3000.
self.tts_model.decoder.max_decoder_steps = 3000
2019-04-15 14:13:33 +00:00
# Load the optional WaveRNN vocoder: make its library importable, read its
# config, construct the model and read its checkpoint file with torch.load().
#
# Arguments:
#   lib_path     -- directory appended to sys.path so `WaveRNN` can be imported
#   model_path   -- directory joined in front of model_file and model_config
#   model_file   -- checkpoint file name
#   model_config -- config file name
#   use_cuda     -- tested at the end of the method (body not visible here)
#
# Sets on self: wavernn_config, wavernn.
#
# NOTE(review): the bare timestamp lines are not Python and the indentation is
# gone; the `Model(` call is never closed and the final `if use_cuda:` has no
# body in this listing, so source lines appear to be missing -- confirm against
# the original file.
def load_wavernn(self, lib_path, model_path, model_file, model_config, use_cuda):
sys.path.append(lib_path) # set this if TTS is not installed globally
# Imported here rather than at module top, after sys.path has been extended.
from WaveRNN.models.wavernn import Model
wavernn_config = os.path.join(model_path, model_config)
# Rebinds the `model_file` argument to the full checkpoint path.
model_file = os.path.join(model_path, model_file)
print(" > Loading WaveRNN model ...")
print(" | > model config: ", wavernn_config)
print(" | > model file: ", model_file)
self.wavernn_config = load_config(wavernn_config)
# Only the upsample_factors keyword of the constructor call is visible here.
self.wavernn = Model(
2019-07-19 06:46:23 +00:00
upsample_factors=self.wavernn_config.upsample_factors, # set this depending on dataset
2019-04-15 14:13:33 +00:00
# NOTE(review): unlike load_tts, no map_location is passed to torch.load here,
# and no visible line consumes `check` -- TODO confirm against the original.
check = torch.load(model_file)
if use_cuda:
2018-08-02 14:34:17 +00:00
2018-06-05 12:15:48 +00:00
def save_wav(self, wav, path):
    """Write ``wav`` to ``path`` through the audio processor.

    Args:
        wav: sequence of samples; it is converted with ``np.array`` first.
        path: destination handed unchanged to ``self.ap.save_wav``.
    """
    samples = np.array(wav)
    self.ap.save_wav(samples, path)
2018-06-05 12:15:48 +00:00
2019-04-18 15:35:20 +00:00
def split_into_sentences(self, text):
    """Split ``text`` into a list of whitespace-stripped sentences.

    Dots that do not end a sentence (titles, web addresses, "Ph.D.",
    initials, acronyms, company suffixes and dots between two digits) are
    temporarily replaced by a ``<prd>`` marker. Every remaining ``.``,
    ``?`` and ``!`` is then followed by a ``<stop>`` marker and the text is
    cut at those markers.

    Args:
        text: the raw input string; newlines are treated as spaces.

    Returns:
        The sentences in order, each ending with its terminator. Text
        after the last terminator is discarded, so input without any
        ``.``, ``?`` or ``!`` yields an empty list.
    """
    # Pad so the patterns anchored on a leading/trailing space also match
    # at the very beginning and end of the input.
    text = " " + text + " "
    text = text.replace("\n", " ")

    # --- protect dots that are not sentence ends -------------------------
    text = re.sub(prefixes, r"\1<prd>", text)
    text = re.sub(websites, r"<prd>\1", text)
    # A dot between two digits is a decimal point or version separator
    # ("3.50", "1.2.3"), not a sentence end. Lookarounds leave the digits
    # unconsumed, so every dot in a chain like "1.2.3" is protected.
    text = re.sub(r"(?<=[0-9])[.](?=[0-9])", "<prd>", text)
    text = text.replace("Ph.D.", "Ph<prd>D<prd>")
    text = re.sub(r"\s" + alphabets + "[.] ", r" \1<prd> ", text)
    # An acronym followed by a sentence starter does end the sentence.
    text = re.sub(acronyms + " " + starters, r"\1<stop> \2", text)
    text = re.sub(
        alphabets + "[.]" + alphabets + "[.]" + alphabets + "[.]",
        r"\1<prd>\2<prd>\3<prd>",
        text,
    )
    text = re.sub(alphabets + "[.]" + alphabets + "[.]", r"\1<prd>\2<prd>", text)
    # Same for a suffix such as "Inc." followed by a sentence starter.
    text = re.sub(" " + suffixes + "[.] " + starters, r" \1<stop> \2", text)
    text = re.sub(" " + suffixes + "[.]", r" \1<prd>", text)
    text = re.sub(" " + alphabets + "[.]", r" \1<prd>", text)

    # --- move terminators outside closing quotes -------------------------
    # so the quote stays attached to the sentence it closes.
    text = text.replace(".”", "”.")
    text = text.replace(".\"", "\".")
    text = text.replace("!\"", "\"!")
    text = text.replace("?\"", "\"?")

    # --- mark sentence ends, then restore the protected dots -------------
    text = text.replace(".", ".<stop>")
    text = text.replace("?", "?<stop>")
    text = text.replace("!", "!<stop>")
    text = text.replace("<prd>", ".")

    # The piece after the last <stop> is unterminated text (or just the
    # padding space) and is dropped.
    return [piece.strip() for piece in text.split("<stop>")[:-1]]
2018-06-05 12:15:48 +00:00
# Synthesize `text` sentence by sentence and return an io.BytesIO object that
# save_wav() was asked to write the concatenated samples into.
#
# Arguments:
#   text -- the input string; it is split with split_into_sentences().
#
# NOTE(review): the bare timestamp lines are not Python and the indentation is
# gone, so the nesting described below is inferred; several source lines also
# appear to be missing (see the notes) -- confirm against the original file.
def tts(self, text):
wavs = []
2019-04-18 15:35:20 +00:00
sens = self.split_into_sentences(text)
2019-07-19 06:46:23 +00:00
# The splitter returns nothing for text without a terminator; fall back to
# the whole text with a '.' appended.
if not sens:
2019-04-18 15:35:20 +00:00
sens = [text+'.']
for sen in sens:
2018-06-05 12:15:48 +00:00
# NOTE(review): the body of this check is not visible here; presumably very
# short fragments are skipped -- TODO confirm.
if len(sen) < 3:
2018-06-06 14:30:45 +00:00
sen = sen.strip()
2018-06-05 14:15:57 +00:00
2019-04-15 14:13:33 +00:00
# input_adapter is the lambda installed by load_tts (phoneme or character ids).
seq = np.array(self.input_adapter(sen))
# text_hat is not used by any other visible line of this method.
text_hat = sequence_to_phoneme(seq)
2019-03-11 20:56:40 +00:00
2018-08-02 14:34:17 +00:00
# unsqueeze(0) adds a leading dimension; long() casts the ids to int64.
chars_var = torch.from_numpy(seq).unsqueeze(0).long()
2019-04-15 14:13:33 +00:00
2018-06-05 12:15:48 +00:00
if self.use_cuda:
chars_var = chars_var.cuda()
2019-04-15 14:13:33 +00:00
# NOTE(review): the argument list of inference( is missing from this listing;
# presumably chars_var is what gets passed -- TODO confirm.
decoder_out, postnet_out, alignments, stop_tokens = self.tts_model.inference(
2018-08-02 14:34:17 +00:00
2019-04-15 14:13:33 +00:00
# Take the first item of the postnet output and move it to a NumPy array.
postnet_out = postnet_out[0].data.cpu().numpy()
# The spectrogram is passed transposed to whichever inverter is used.
if self.tts_config.model == "Tacotron":
wav = self.ap.inv_spectrogram(postnet_out.T)
elif self.tts_config.model == "Tacotron2":
# NOTE(review): the WaveRNN call below uses .cuda() regardless of
# self.use_cuda. An `else:` before the inv_mel_spectrogram line is presumably
# missing; as listed it would overwrite the WaveRNN result -- TODO confirm.
if self.wavernn:
wav = self.wavernn.generate(torch.FloatTensor(postnet_out.T).unsqueeze(0).cuda(), batched=self.config.is_wavernn_batched, target=11000, overlap=550)
wav = self.ap.inv_mel_spectrogram(postnet_out.T)
2018-12-11 14:10:56 +00:00
# Append this sentence's samples followed by 10000 zero samples
# (assumes both lines sit inside the sentence loop -- TODO confirm).
wavs += list(wav)
wavs += [0] * 10000
2019-03-11 20:56:40 +00:00
# The in-memory buffer is passed to save_wav as its `path` argument.
out = io.BytesIO()
2018-12-11 14:10:56 +00:00
self.save_wav(wavs, out)
2019-03-11 20:56:40 +00:00
return out