2018-06-05 12:15:48 +00:00
|
|
|
import io
|
|
|
|
import os
|
2019-04-15 14:13:33 +00:00
|
|
|
import sys
|
|
|
|
|
2018-06-05 12:15:48 +00:00
|
|
|
import numpy as np
|
2019-04-15 14:13:33 +00:00
|
|
|
import torch
|
|
|
|
|
2018-11-19 14:27:22 +00:00
|
|
|
from models.tacotron import Tacotron
|
2019-04-15 14:13:33 +00:00
|
|
|
from utils.audio import AudioProcessor
|
|
|
|
from utils.generic_utils import load_config, setup_model
|
|
|
|
from utils.text import phoneme_to_sequence, phonemes, symbols, text_to_sequence, sequence_to_phoneme
|
2018-06-05 12:15:48 +00:00
|
|
|
|
|
|
|
|
|
|
|
class Synthesizer(object):
    """Turn text into a waveform using a TTS model and an optional WaveRNN.

    The constructor loads the TTS model described by ``config`` and, when
    ``config.wavernn_lib_path`` is set, a WaveRNN vocoder as well.
    ``tts()`` then synthesizes text sentence by sentence and returns the
    audio written into an in-memory buffer.
    """

    # Number of zero samples appended after every synthesized sentence.
    _SENTENCE_PAD_SAMPLES = 10000

    def __init__(self, config):
        """Load the models named in ``config``.

        Args:
            config: object exposing ``use_cuda``, ``tts_path``, ``tts_file``,
                ``tts_config``, ``wavernn_lib_path`` and, when the latter is
                truthy, ``wavernn_path``, ``wavernn_file`` and
                ``wavernn_config``. ``tts()`` additionally reads
                ``is_wavernn_batched`` when a WaveRNN model is loaded.

        Raises:
            AssertionError: if ``config.use_cuda`` is set but CUDA is not
                available.
        """
        self.wavernn = None
        self.config = config
        self.use_cuda = config.use_cuda
        if self.use_cuda:
            assert torch.cuda.is_available(), "CUDA is not availabe on this machine."
        self.load_tts(config.tts_path, config.tts_file, config.tts_config,
                      config.use_cuda)
        if config.wavernn_lib_path:
            self.load_wavernn(config.wavernn_lib_path, config.wavernn_path,
                              config.wavernn_file, config.wavernn_config,
                              config.use_cuda)

    @staticmethod
    def _load_checkpoint(path, use_cuda):
        """Load a checkpoint, remapping storages to CPU when CUDA is off."""
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        return torch.load(path, map_location=None if use_cuda else "cpu")

    @staticmethod
    def _split_sentences(text):
        """Split ``text`` on '.' and return the stripped, non-empty sentences.

        As before, fragments whose raw length is below 3 are dropped. In
        addition, fragments that are empty after stripping are dropped so
        that whitespace-only fragments never reach the model.
        """
        sentences = []
        for fragment in text.split('.'):
            if len(fragment) < 3:
                continue
            sentence = fragment.strip()
            if sentence:
                sentences.append(sentence)
        return sentences

    def load_tts(self, model_path, model_file, model_config, use_cuda):
        """Load the TTS config, audio processor, text adapter and model.

        Sets ``self.model_file``, ``self.tts_config``, ``self.use_phonemes``,
        ``self.ap``, ``self.input_size``, ``self.input_adapter`` and
        ``self.tts_model`` (left in eval mode).

        Args:
            model_path: directory holding the model file and its config.
            model_file: checkpoint file name inside ``model_path``.
            model_config: config file name inside ``model_path``.
            use_cuda: move the model to the GPU when true; otherwise the
                checkpoint is remapped to CPU while loading.
        """
        tts_config = os.path.join(model_path, model_config)
        self.model_file = os.path.join(model_path, model_file)
        print(" > Loading TTS model ...")
        print(" | > model config: ", tts_config)
        # Report the path that is actually loaded, not the bare file name.
        print(" | > model file: ", self.model_file)
        self.tts_config = load_config(tts_config)
        self.use_phonemes = self.tts_config.use_phonemes
        self.ap = AudioProcessor(**self.tts_config.audio)
        if self.use_phonemes:
            self.input_size = len(phonemes)
            self.input_adapter = lambda sen: phoneme_to_sequence(
                sen, [self.tts_config.text_cleaner],
                self.tts_config.phoneme_language,
                self.tts_config.enable_eos_bos_chars)
        else:
            self.input_size = len(symbols)
            self.input_adapter = lambda sen: text_to_sequence(
                sen, [self.tts_config.text_cleaner])
        self.tts_model = setup_model(self.input_size, self.tts_config)
        # load model state
        cp = self._load_checkpoint(self.model_file, use_cuda)
        # load the model
        self.tts_model.load_state_dict(cp['model'])
        if use_cuda:
            self.tts_model.cuda()
        self.tts_model.eval()

    def load_wavernn(self, lib_path, model_path, model_file, model_config, use_cuda):
        """Load the WaveRNN vocoder into ``self.wavernn`` (eval mode).

        Must be called after ``load_tts``: the model is built with
        ``self.ap.hop_length`` and ``self.ap.sample_rate``.

        Args:
            lib_path: directory appended to ``sys.path`` so that the
                ``WaveRNN`` package can be imported.
            model_path: directory holding the model file and its config.
            model_file: checkpoint file name inside ``model_path``.
            model_config: config file name inside ``model_path``.
            use_cuda: move the model to the GPU when true; otherwise the
                model stays on CPU and the checkpoint is remapped to CPU.
        """
        sys.path.append(lib_path)  # set this if TTS is not installed globally
        from WaveRNN.models.wavernn import Model
        wavernn_config = os.path.join(model_path, model_config)
        model_file = os.path.join(model_path, model_file)
        print(" > Loading WaveRNN model ...")
        print(" | > model config: ", wavernn_config)
        print(" | > model file: ", model_file)
        self.wavernn_config = load_config(wavernn_config)
        # Build on CPU first; the device move below honours ``use_cuda``.
        self.wavernn = Model(
            rnn_dims=512,
            fc_dims=512,
            mode=self.wavernn_config.mode,
            pad=2,
            upsample_factors=self.wavernn_config.upsample_factors,  # set this depending on dataset
            feat_dims=80,
            compute_dims=128,
            res_out_dims=128,
            res_blocks=10,
            hop_length=self.ap.hop_length,
            sample_rate=self.ap.sample_rate,
        )

        check = self._load_checkpoint(model_file, use_cuda)
        self.wavernn.load_state_dict(check['model'])
        if use_cuda:
            self.wavernn.cuda()
        self.wavernn.eval()

    def save_wav(self, wav, path):
        """Convert ``wav`` to a numpy array and hand it to ``self.ap.save_wav``.

        Args:
            wav: sequence of audio samples.
            path: destination passed through to the audio processor
                (``tts()`` passes an ``io.BytesIO``).
        """
        wav = np.array(wav)
        self.ap.save_wav(wav, path)

    def tts(self, text):
        """Synthesize ``text`` and return the audio in an ``io.BytesIO``.

        The text is split on '.', each sentence is synthesized separately,
        and ``_SENTENCE_PAD_SAMPLES`` zero samples are appended after each
        one. The concatenated samples are written through ``save_wav``.

        Args:
            text: input text.

        Returns:
            io.BytesIO: buffer that ``self.ap.save_wav`` wrote the audio to.

        Raises:
            ValueError: if ``self.tts_config.model`` is neither "Tacotron"
                nor "Tacotron2".
        """
        wavs = []
        for sen in self._split_sentences(text):
            print(sen)

            seq = np.array(self.input_adapter(sen))
            # NOTE(review): this runs even when use_phonemes is false;
            # presumably only meaningful for phoneme ids -- confirm.
            text_hat = sequence_to_phoneme(seq)
            print(text_hat)

            chars_var = torch.from_numpy(seq).unsqueeze(0).long()
            if self.use_cuda:
                chars_var = chars_var.cuda()

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                _, postnet_out, _, _ = self.tts_model.inference(chars_var)
            postnet_out = postnet_out[0].detach().cpu().numpy()

            model_name = self.tts_config.model
            if model_name == "Tacotron":
                wav = self.ap.inv_spectrogram(postnet_out.T)
            elif model_name == "Tacotron2":
                if self.wavernn is not None:
                    mel = torch.FloatTensor(postnet_out.T).unsqueeze(0)
                    # Follow the configured device instead of forcing CUDA.
                    if self.use_cuda:
                        mel = mel.cuda()
                    wav = self.wavernn.generate(
                        mel, batched=self.config.is_wavernn_batched,
                        target=11000, overlap=550)
                else:
                    wav = self.ap.inv_mel_spectrogram(postnet_out.T)
            else:
                # Previously ``wav`` was left unbound (NameError) or silently
                # reused from the previous sentence.
                raise ValueError(
                    "Unsupported TTS model: {!r}".format(model_name))

            wavs.extend(wav)
            wavs.extend([0] * self._SENTENCE_PAD_SAMPLES)

        out = io.BytesIO()
        self.save_wav(wavs, out)
        return out
|