mirror of https://github.com/coqui-ai/TTS.git

server update for changing r value

parent 97ffa2b44e
commit e02fc51fde
@@ -9,6 +9,7 @@ from utils.audio import AudioProcessor
 from utils.generic_utils import load_config, setup_model
 from utils.text import phoneme_to_sequence, phonemes, symbols, text_to_sequence, sequence_to_phoneme
 from utils.speakers import load_speaker_mapping
+from utils.synthesis import *

 import re
 alphabets = r"([A-Za-z])"
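The new wildcard import is what brings the synthesis helpers used by the rewritten tts() method (below) into scope. For reference, an explicit form would need at least the five names this commit actually calls; a sketch, assuming utils.synthesis exports nothing else the server relies on:

from utils.synthesis import (
    text_to_seqvec,    # text -> input id tensor
    run_model,         # model forward pass
    parse_outputs,     # torch outputs -> numpy
    inv_spectrogram,   # spectrogram -> waveform
    trim_silence,      # cut trailing silence via the AudioProcessor
)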
@@ -41,28 +42,25 @@ class Synthesizer(object):
         self.ap = AudioProcessor(**self.tts_config.audio)
         if self.use_phonemes:
             self.input_size = len(phonemes)
-            self.input_adapter = lambda sen: phoneme_to_sequence(sen, [self.tts_config.text_cleaner], self.tts_config.phoneme_language, self.tts_config.enable_eos_bos_chars)
         else:
             self.input_size = len(symbols)
-            self.input_adapter = lambda sen: text_to_sequence(sen, [self.tts_config.text_cleaner])
         # load speakers
         if self.config.tts_speakers is not None:
             self.tts_speakers = load_speaker_mapping(os.path.join(model_path, self.config.tts_speakers))
             num_speakers = len(self.tts_speakers)
         else:
             num_speakers = 0
-        self.tts_model = setup_model(self.input_size, num_speakers=num_speakers , c=self.tts_config)
+        self.tts_model = setup_model(self.input_size, num_speakers=num_speakers, c=self.tts_config)
         # load model state
-        if use_cuda:
-            cp = torch.load(self.model_file)
-        else:
-            cp = torch.load(self.model_file, map_location=lambda storage, loc: storage)
+        cp = torch.load(self.model_file)
         # load the model
         self.tts_model.load_state_dict(cp['model'])
         if use_cuda:
             self.tts_model.cuda()
         self.tts_model.eval()
         self.tts_model.decoder.max_decoder_steps = 3000
+        if 'r' in cp:
+            self.tts_model.decoder.set_r(cp['r'])

     def load_wavernn(self, lib_path, model_path, model_file, model_config, use_cuda):
         # TODO: set a function in wavernn code base for model setup and call it here.
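This hunk is the change the commit title refers to. Tacotron-style decoders predict r frames per step (the reduction factor), and a checkpoint trained with one r will not decode correctly under another, so when the checkpoint carries its r the server now restores it through decoder.set_r(). A minimal sketch of the round-trip; the saving side is not part of this commit, and the attribute name model.decoder.r is an assumption:

import torch

# training side (assumed): persist the current reduction factor with the weights
torch.save({'model': model.state_dict(), 'r': model.decoder.r}, checkpoint_path)

# serving side, mirroring the hunk above: restore r when the key is present,
# so checkpoints whose r changed during training decode with the right frame count
cp = torch.load(checkpoint_path)
model.load_state_dict(cp['model'])
if 'r' in cp:
    model.decoder.set_r(cp['r'])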
@@ -136,33 +134,27 @@ class Synthesizer(object):
     def tts(self, text):
         wavs = []
         sens = self.split_into_sentences(text)
+        print(sens)
         if not sens:
             sens = [text+'.']
         for sen in sens:
-            if len(sen) < 3:
-                continue
-            sen = sen.strip()
-            print(sen)
-            seq = np.array(self.input_adapter(sen))
-            if self.use_phonemes:
-                text_hat = sequence_to_phoneme(seq)
-                print(text_hat)
-            chars_var = torch.from_numpy(seq).unsqueeze(0).long()
+            # preprocess the given text
+            inputs = text_to_seqvec(text, self.tts_config, self.use_cuda)
+            # synthesize voice
+            decoder_output, postnet_output, alignments, stop_tokens = run_model(
+                self.tts_model, inputs, self.tts_config, False, None, None)
+            # convert outputs to numpy
+            postnet_output, decoder_output, alignment = parse_outputs(
+                postnet_output, decoder_output, alignments)

-            if self.use_cuda:
-                chars_var = chars_var.cuda()
-            decoder_out, postnet_out, alignments, stop_tokens = self.tts_model.inference(
-                chars_var)
-            postnet_out = postnet_out[0].data.cpu().numpy()
-            if self.tts_config.model == "Tacotron":
-                wav = self.ap.inv_spectrogram(postnet_out.T)
-            elif self.tts_config.model == "Tacotron2":
-                if self.wavernn:
-                    wav = self.wavernn.generate(torch.FloatTensor(postnet_out.T).unsqueeze(0).cuda(), batched=self.config.is_wavernn_batched, target=11000, overlap=550)
-                else:
-                    wav = self.ap.inv_mel_spectrogram(postnet_out.T)
+            if self.wavernn:
+                postnet_output = postnet_output[0].data.cpu().numpy()
+                wav = self.wavernn.generate(torch.FloatTensor(postnet_output.T).unsqueeze(0).cuda(), batched=self.config.is_wavernn_batched, target=11000, overlap=550)
+            else:
+                wav = inv_spectrogram(postnet_output, self.ap, self.tts_config)
+            # trim silence
+            wav = trim_silence(wav, self.ap)

             wavs += list(wav)
             wavs += [0] * 10000
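Condensed, the per-sentence path that tts() now delegates to utils.synthesis looks as follows; a sketch with the signatures from the hunk above, assuming a loaded tts_model, ap, and tts_config (note the committed code passes the full text, not the current sen, to text_to_seqvec):

# preprocess -> forward pass -> numpy -> vocode -> trim
inputs = text_to_seqvec(sen, tts_config, use_cuda)
decoder_output, postnet_output, alignments, stop_tokens = run_model(
    tts_model, inputs, tts_config, False, None, None)  # trailing args: truncation/speaker/style disabled (assumed meaning)
postnet_output, decoder_output, alignment = parse_outputs(
    postnet_output, decoder_output, alignments)
wav = inv_spectrogram(postnet_output, ap, tts_config)  # used when no WaveRNN vocoder is loaded
wav = trim_silence(wav, ap)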
@@ -50,7 +50,7 @@ def parse_outputs(postnet_output, decoder_output, alignments):
     return postnet_output, decoder_output, alignment


-def trim_silence(wav):
+def trim_silence(wav, ap):
     return wav[:ap.find_endpoint(wav)]
@@ -114,5 +114,5 @@ def synthesis(model,
     wav = inv_spectrogram(postnet_output, ap, CONFIG)
     # trim silence
     if do_trim_silence:
-        wav = trim_silence(wav)
+        wav = trim_silence(wav, ap)
     return wav, alignment, decoder_output, postnet_output, stop_tokens
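The last two hunks make trim_silence's dependency on the audio processor explicit instead of relying on an ap name in the enclosing scope: find_endpoint is an AudioProcessor method, so every caller now passes the instance through. A call-site sketch, constructing ap the same way the Synthesizer does above:

from utils.audio import AudioProcessor

ap = AudioProcessor(**tts_config.audio)  # same construction as in Synthesizer.load_tts
wav = trim_silence(wav, ap)              # equivalent to wav[:ap.find_endpoint(wav)]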