mirror of https://github.com/MycroftAI/mimic2.git
abstracted synthesize_helper out
parent 3ea41b8820
commit 126ef34d44
@@ -3,13 +3,19 @@ from flask.views import MethodView
from hparams import hparams, hparams_debug_string
import argparse
import os
from util import audio
from synthesizer import Synthesizer
from flask_cors import CORS
import io
import numpy as np
import math
from synthesize_helper import synthesize_helper

app = Flask(__name__)
CORS(app)

use_synthesize_helper = False

html_body = '''<html><title>Demo</title>
<style>
body {padding: 16px; font-family: sans-serif; font-size: 14px; color: #444}
@@ -63,14 +69,18 @@ function synthesize(text) {

synthesizer = Synthesizer()


class Mimic2(MethodView):
    def get(self):
        text = request.args.get('text')
        if text:
            wav, _ = synthesizer.synthesize(text)
            audio = io.BytesIO(wav)
            return send_file(audio, mimetype="audio/wav")
            if use_synthesize_helper:
                wav = synthesize_helper(text, synthesizer)
                # wav, _ = synthesizer.synthesize(text)
                audio = io.BytesIO(wav)
                return send_file(audio, mimetype="audio/wav")
            else:
                wav, _ = synthesizer.synthesize(text)
                return send_file(wav, mimetype="audio/wav")


class UI(MethodView):
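For reference, a minimal client-side sketch of calling this route; the URL rule, host, and port are assumptions, since the route registration is not shown in this hunk:

# Hypothetical client sketch: the /synthesize path and port 3000 are assumptions.
import requests

resp = requests.get("http://localhost:3000/synthesize", params={"text": "Hello world"})
with open("hello.wav", "wb") as f:
    f.write(resp.content)  # the view streams back audio/wav bytes via send_file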
@@ -95,7 +105,14 @@ if __name__ == '__main__':
    parser.add_argument(
        '--gpu_assignment', default='0',
        help='Set the gpu the model should run on')
    parser.add_argument(
        '--synthezier_helper', default=False, action="store_true",
        help='uses the synthesize helper during synthesis'
    )
    args = parser.parse_args()

    use_synthesize_helper = args.synthezier_helper

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_assignment
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    hparams.parse(args.hparams)
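A quick sketch of how the flag parses with a store_true action; the explicit argument list is supplied only for illustration:

# Hypothetical parse check: passing the flag enables the helper.
args = parser.parse_args(['--synthezier_helper'])
assert args.synthezier_helper is True  # omitting the flag leaves the default False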
@@ -27,7 +27,7 @@ if __name__ == "__main__":
    )

    wav_output = tf.saved_model.utils.build_tensor_info(synth.wav_output)
    alignment = tf.saved_model.utils.build_tensor_info(synth.alignment)
    # alignment = tf.saved_model.utils.build_tensor_info(synth.alignment)

    prediction_signature = (
        tf.saved_model.signature_def_utils.build_signature_def(

@@ -36,8 +36,8 @@ if __name__ == "__main__":
                "input_lengths": input_lengths
            },
            outputs={
                'wav_output': wav_output,
                'alignment': alignment
                'wav_output': wav_output #,
                # 'alignment': alignment
            },
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
    )
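For context, a sketch of loading the exported model and resolving only the remaining wav_output tensor; the export directory and the signature key are assumptions, since neither appears in these hunks:

# Hypothetical consumer sketch: export_dir and the 'predict' signature key are assumptions.
import tensorflow as tf

export_dir = "exported_model"  # placeholder path
with tf.Session(graph=tf.Graph()) as sess:
    meta_graph = tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], export_dir)
    signature = meta_graph.signature_def['predict']
    wav_tensor_name = signature.outputs['wav_output'].name
    # 'alignment' is no longer listed in the signature outputs after this change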
@@ -0,0 +1,133 @@
"""An optional module to post-process text before synthesizing."""
import io
import math
from util import audio
import numpy as np

# Sentence-ending punctuation used by add_punctuation.
punctuations = ['.', '?', '!']

# Punctuation marks used to split long inputs into smaller chunks.
split_punctuations = [',', '.', '-', '?', '!', ':', ';']

# Phonetic spellings used to read all-uppercase acronyms letter by letter.
letter_lookup = {
    'A': 'ayy',
    'B': 'bee',
    'C': 'see',
    'D': 'dee',
    'E': 'eee',
    'F': 'eff',
    'G': 'jee',
    'H': 'aitch',
    'I': 'eye',
    'J': 'jay',
    'K': 'kay',
    'L': 'el',
    'M': 'em',
    'N': 'en',
    'O': 'oow',
    'P': 'pee',
    'Q': 'queue',
    'R': 'are',
    'S': 'es',
    'T': 'tee',
    'U': 'you',
    'V': 'vee',
    'W': 'double you',
    'X': 'ex',
    'Y': 'why',
    'Z': 'zee'
}


def replace_acronym(text):
    """Replace all-uppercase words in a token list with letter-by-letter spellings."""
    for idx, word in enumerate(text):
        if len(word) == 1:
            continue
        if word.isupper():
            sound = ""
            for letter in word.strip():
                if letter_lookup.get(letter):
                    sound += letter_lookup.get(letter) + " "
            text[idx] = sound
    return text


def add_punctuation(text):
    """Ensure the text ends with sentence-final punctuation, preferring a period."""
    if len(text) < 1:
        return text
    if len(text) < 10:
        if text[-1] in punctuations:
            if text[-1] != ".":
                return text[:-1] + "."
    if text[-1] not in punctuations:
        text += '.'
    return text


def break_chunks(l, n):
    """Yield successive n-sized chunks from l."""
    for i in range(0, len(l), n):
        yield " ".join(l[i:i + n])


def split_by_threshold(text, threshold):
    """Split text into two to four roughly even word chunks once it exceeds the threshold."""
    text_list = text.split()

    if len(text_list) <= threshold:
        return [text]

    if threshold < len(text_list) < (threshold * 2):
        return list(break_chunks(
            text_list,
            int(math.ceil(len(text_list) / 2))
        ))
    elif (threshold * 2) < len(text_list) < (threshold * 3):
        return list(break_chunks(
            text_list,
            int(math.ceil(len(text_list) / 3))
        ))
    elif (threshold * 3) < len(text_list) < (threshold * 4):
        return list(break_chunks(
            text_list,
            int(math.ceil(len(text_list) / 4))
        ))
    else:
        return list(break_chunks(
            text_list,
            int(math.ceil(len(text_list) / 4))
        ))


def synthesize_helper(text, synthesizer, threshold=10):
    """Synthesize text, chunking long inputs and concatenating the resulting audio."""
    text_list = text.split()
    if len(text_list) <= threshold * 1.3:
        text = " ".join(replace_acronym(text_list))
        print(text.encode('utf-8'))
        wav, _ = synthesizer.synthesize(add_punctuation(text), return_wav=True)
        out = io.BytesIO()
        audio.save_wav(wav, out)
        return out.getvalue()

    # Prefer splitting on punctuation; fall back to word-count chunking.
    split_by_punc = None
    if len(text_list) >= threshold:
        for punc in split_punctuations:
            if punc in text:
                split_by_punc = text.split(punc)
                break

    chunks = []
    if split_by_punc:
        for sentence in split_by_punc:
            sentence = sentence.strip()
            chunks += split_by_threshold(sentence, threshold)
    else:
        chunks += split_by_threshold(text, threshold)

    # Synthesize each chunk and concatenate the waveforms, dropping the last
    # 880*6 samples of each chunk.
    combined_wav = np.array([])
    for idx, chunk in enumerate(chunks):
        if len(chunk) > 0:
            text = add_punctuation(chunk)
            text = " ".join(replace_acronym(text.split()))
            print(text.encode('utf-8'))
            wav, _ = synthesizer.synthesize(text, return_wav=True)
            combined_wav = np.concatenate((combined_wav, wav[:-880*6]))

    out = io.BytesIO()
    audio.save_wav(combined_wav, out)
    return out.getvalue()
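A small usage sketch of the pure text helpers above; no synthesizer is needed, and the sample strings are illustrative only:

# Illustrative only: exercises the text-handling helpers from synthesize_helper.py.
from synthesize_helper import replace_acronym, add_punctuation, split_by_threshold

print(replace_acronym("the NASA probe".split()))   # ['the', 'en ayy es ayy ', 'probe']
print(add_punctuation("hello there"))              # 'hello there.'
print(split_by_threshold("one two three four five six seven eight nine ten eleven twelve", 10))
# twelve words with threshold 10 -> two chunks of six words each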
@@ -31,7 +31,7 @@ class Synthesizer:
        saver = tf.train.Saver()
        saver.restore(self.session, checkpoint_path)

    def synthesize(self, text):
    def synthesize(self, text, return_wav=False):
        cleaner_names = [x.strip() for x in hparams.cleaners.split(',')]
        seq = text_to_sequence(text, cleaner_names)
        feed_dict = {

@@ -51,6 +51,9 @@ class Synthesizer:
        wav = wav[:audio_endpoint]
        alignment = alignment[:, :alignment_endpoint]

        if return_wav:
            return wav, alignment

        out = io.BytesIO()
        audio.save_wav(wav, out)
        return out.getvalue(), alignment
        return out, alignment
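A short sketch of how the two return modes differ after this change; the checkpoint path and input text are placeholders, and the load call assumes the Synthesizer.load method defined elsewhere in this file:

# Illustrative only: checkpoint path and text are placeholders.
synth = Synthesizer()
synth.load("logs-tacotron/model.ckpt")  # assumes Synthesizer.load(checkpoint_path)
wav_array, alignment = synth.synthesize("Hello there.", return_wav=True)  # raw trimmed waveform
wav_buffer, alignment = synth.synthesize("Hello there.")  # in-memory wav data suitable for send_file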