diff --git a/demo_server.py b/demo_server.py index baac6b4..4ba3918 100644 --- a/demo_server.py +++ b/demo_server.py @@ -9,7 +9,7 @@ from flask_cors import CORS import io import numpy as np import math -from synthesize_helper import synthesize_helper +from synthesize_helper import synthesize_helper, replace_acronym, custom_splitter app = Flask(__name__) CORS(app) @@ -72,6 +72,7 @@ synthesizer = Synthesizer() class Mimic2(MethodView): def get(self): text = request.args.get('text') + text = " ".join(replace_acronym(custom_splitter(text))) if text: if use_synthesize_helper: wav = synthesize_helper(text, synthesizer) diff --git a/export.py b/export.py index 37ddcb0..46e2337 100644 --- a/export.py +++ b/export.py @@ -26,8 +26,12 @@ if __name__ == "__main__": synth.model.input_lengths ) - wav_output = tf.saved_model.utils.build_tensor_info(synth.wav_output) - # alignment = tf.saved_model.utils.build_tensor_info(synth.alignment) + w_o = audio.inv_spectrogram_tensorflow( + synth.model.linear_outputs + ) + + wav_output = tf.saved_model.utils.build_tensor_info(w_o) + alignment = tf.saved_model.utils.build_tensor_info(synth.model.alignments) prediction_signature = ( tf.saved_model.signature_def_utils.build_signature_def( @@ -36,8 +40,8 @@ if __name__ == "__main__": "input_lengths": input_lengths }, outputs={ - 'wav_output': wav_output #, - # 'alignment': alignment + 'wav_output': wav_output, + 'alignment': alignment }, method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME) ) diff --git a/synthesize_helper.py b/synthesize_helper.py index 1b743c2..a2175d1 100644 --- a/synthesize_helper.py +++ b/synthesize_helper.py @@ -40,6 +40,8 @@ letter_lookup = { def replace_acronym(text): for idx, word in enumerate(text): + if "{" in word and "}" in word: + continue if len(word) == 1: continue if word.isupper(): @@ -50,6 +52,19 @@ def replace_acronym(text): text[idx] = sound return text +def custom_splitter(text): + if "{" in text and "}" in text: + acc = [] + split = text.split("}") + for word in split: + if "{" in word: + acc.append(word + "}") + else: + acc.append(word) + return acc + else: + return text.split() + def add_punctuation(text): if len(text) < 1: return text