Merge pull request #10 from MycroftAI/demo_server

change demo server, return alignment with synthesis, plot alignment o…
pull/15/head
Michael Nguyen 2018-07-03 18:54:33 -05:00 committed by GitHub
commit bda7cdad7a
4 changed files with 88 additions and 55 deletions

demo_server.py

@@ -1,9 +1,14 @@
-import argparse
-import falcon
+from flask import Flask, request, send_file
+from flask.views import MethodView
 from hparams import hparams, hparams_debug_string
+import argparse
 import os
 from synthesizer import Synthesizer
+from flask_cors import CORS
+import io
+
+app = Flask(__name__)
+CORS(app)

 html_body = '''<html><title>Demo</title>
 <style>
@@ -56,40 +61,44 @@ function synthesize(text) {
 </script></body></html>
 '''

-class UIResource:
-  def on_get(self, req, res):
-    res.content_type = 'text/html'
-    res.body = html_body
-
-class SynthesisResource:
-  def on_get(self, req, res):
-    if not req.params.get('text'):
-      raise falcon.HTTPBadRequest()
-    res.data = synthesizer.synthesize(req.params.get('text'))
-    res.content_type = 'audio/wav'
-
 synthesizer = Synthesizer()
-api = falcon.API()
-api.add_route('/synthesize', SynthesisResource())
-api.add_route('/', UIResource())
+
+class Mimic2(MethodView):
+  def get(self):
+    text = request.args.get('text')
+    if text:
+      wav, _ = synthesizer.synthesize(text)
+      audio = io.BytesIO(wav)
+      return send_file(audio, mimetype="audio/wav")
+
+class UI(MethodView):
+  def get(self):
+    return html_body
+
+ui_view = UI.as_view('ui_view')
+app.add_url_rule('/', view_func=ui_view, methods=['GET'])
+
+mimic2_api = Mimic2.as_view('mimic2_api')
+app.add_url_rule('/synthesize', view_func=mimic2_api, methods=['GET'])

 if __name__ == '__main__':
-  from wsgiref import simple_server
   parser = argparse.ArgumentParser()
-  parser.add_argument('--checkpoint', required=True, help='Full path to model checkpoint')
-  parser.add_argument('--port', type=int, default=9000)
+  parser.add_argument('--checkpoint', required=True,
+                      help='Full path to model checkpoint')
+  parser.add_argument('--port', type=int, default=3000)
   parser.add_argument('--hparams', default='',
     help='Hyperparameter overrides as a comma-separated list of name=value pairs')
+  parser.add_argument(
+    '--gpu_assignment', default='0',
+    help='Set the gpu the model should run on')
   args = parser.parse_args()
+  os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_assignment
   os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
   hparams.parse(args.hparams)
   print(hparams_debug_string())
   synthesizer.load(args.checkpoint)
-  print('Serving on port %d' % args.port)
-  simple_server.make_server('0.0.0.0', args.port, api).serve_forever()
-else:
-  synthesizer.load(os.environ['CHECKPOINT'])
+  app.run(host='0.0.0.0', port=3000)
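For reference, a minimal client-side sketch of the new /synthesize route (assumptions: the demo server is running locally on its default port 3000, and the output filename is illustrative):

import urllib.parse
import urllib.request

# Query the Flask endpoint added in this PR; it streams back a WAV file.
text = 'hello my name is mycroft.'
url = 'http://localhost:3000/synthesize?' + urllib.parse.urlencode({'text': text})

with urllib.request.urlopen(url) as response:
  wav_bytes = response.read()

# Save the returned audio; 'demo_output.wav' is just an example path.
with open('demo_output.wav', 'wb') as f:
  f.write(wav_bytes)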

eval.py

@@ -3,19 +3,30 @@ import os
 import re
 from hparams import hparams, hparams_debug_string
 from synthesizer import Synthesizer
+from util import plot

 sentences = [
   # From July 8, 2017 New York Times:
-  'Scientists at the CERN laboratory say they have discovered a new particle.',
-  'Theres a way to measure the acute emotional intelligence that has never gone out of style.',
-  'President Trump met with other leaders at the Group of 20 conference.',
-  'The Senate\'s bill to repeal and replace the Affordable Care Act is now imperiled.',
-  # From Google's Tacotron example page:
-  'Generative adversarial network or variational auto-encoder.',
-  'The buses aren\'t the problem, they actually provide a solution.',
-  'Does the quick brown fox jump over the lazy dog?',
-  'Talib Kweli confirmed to AllHipHop that he will be releasing an album in the next year.',
+  # 'Scientists at the CERN laboratory say they have discovered a new particle.',
+  # 'Theres a way to measure the acute emotional intelligence that has never gone out of style.',
+  # 'President Trump met with other leaders at the Group of 20 conference.',
+  # 'The Senate\'s bill to repeal and replace the Affordable Care Act is now imperiled.',
+  # # From Google's Tacotron example page:
+  # 'Generative adversarial network or variational auto-encoder.',
+  # 'The buses aren\'t the problem, they actually provide a solution.',
+  # 'Does the quick brown fox jump over the lazy dog?',
+  # 'Talib Kweli confirmed to AllHipHop that he will be releasing an album in the next year.',
+  "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
+  "Be a voice, not an echo.",
+  "The human voice is the most perfect instrument of all.",
+  "I'm sorry Dave, I'm afraid I can't do that.",
+  "This cake is great, It's so delicious and moist.",
+  "hello my name is mycroft.",
+  "hi.",
+  "wow.",
+  "cool.",
+  "great.",
 ]
@@ -32,18 +43,26 @@ def run_eval(args):
   synth.load(args.checkpoint)
   base_path = get_output_base_path(args.checkpoint)
   for i, text in enumerate(sentences):
-    path = '%s-%d.wav' % (base_path, i)
-    print('Synthesizing: %s' % path)
-    with open(path, 'wb') as f:
-      f.write(synth.synthesize(text))
+    wav_path = '%s-%d.wav' % (base_path, i)
+    align_path = '%s-%d.png' % (base_path, i)
+    print('Synthesizing and plotting: %s' % wav_path)
+    wav, alignment = synth.synthesize(text)
+    with open(wav_path, 'wb') as f:
+      f.write(wav)
+    plot.plot_alignment(
+      alignment, align_path,
+      info='%s' % (text)
+    )

 def main():
   parser = argparse.ArgumentParser()
-  parser.add_argument('--checkpoint', required=True, help='Path to model checkpoint')
+  parser.add_argument('--checkpoint', required=True,
+                      help='Path to model checkpoint')
   parser.add_argument('--hparams', default='',
     help='Hyperparameter overrides as a comma-separated list of name=value pairs')
-  parser.add_argument('--force_cpu', default=False, help='Force synthesize with cpu')
+  parser.add_argument('--force_cpu', default=False,
+                      help='Force synthesize with cpu')
   args = parser.parse_args()
   if args.force_cpu:
     os.environ['CUDA_VISIBLE_DEVICES'] = ''
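util/plot.py is not part of this diff, but the new plot.plot_alignment(alignment, align_path, info=...) call implies a matplotlib helper roughly like the sketch below (names and styling are assumptions, not the repository's actual implementation):

import matplotlib
matplotlib.use('Agg')  # assume headless rendering during evaluation
import matplotlib.pyplot as plt


def plot_alignment(alignment, path, info=None):
  # alignment is the 2-D attention matrix returned by Synthesizer.synthesize().
  fig, ax = plt.subplots()
  im = ax.imshow(alignment, aspect='auto', origin='lower', interpolation='none')
  fig.colorbar(im, ax=ax)
  ax.set_xlabel('Decoder timestep')
  ax.set_ylabel('Encoder timestep')
  if info is not None:
    ax.set_title(info)
  fig.savefig(path, format='png')
  plt.close(fig)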

requirements.txt

@@ -1,8 +1,9 @@
 # Note: this doesn't include tensorflow or tensorflow-gpu because the package you need to install
 # depends on your platform. It is assumed you have already installed tensorflow.
-falcon==1.2.0
 librosa==0.5.1
 matplotlib==2.0.2
 numpy==1.13.0
 scipy==0.19.0
 tqdm==4.11.2
+flask_cors
+flask

synthesizer.py

@@ -17,6 +17,7 @@ class Synthesizer:
       self.model = create_model(model_name, hparams)
       self.model.initialize(inputs, input_lengths)
       self.wav_output = audio.inv_spectrogram_tensorflow(self.model.linear_outputs[0])
+      self.alignment = self.model.alignments[0]

     print('Loading checkpoint: %s' % checkpoint_path)
     self.session = tf.Session()
@@ -32,8 +33,11 @@ class Synthesizer:
       self.model.inputs: [np.asarray(seq, dtype=np.int32)],
       self.model.input_lengths: np.asarray([len(seq)], dtype=np.int32)
     }
-    wav = self.session.run(self.wav_output, feed_dict=feed_dict)
+    wav, alignment = self.session.run(
+      [self.wav_output, self.alignment],
+      feed_dict=feed_dict)
     wav = wav[:audio.find_endpoint(wav)]
     out = io.BytesIO()
     audio.save_wav(wav, out)
-    return out.getvalue()
+    return out.getvalue(), alignment
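With this change, Synthesizer.synthesize() returns a (wav_bytes, alignment) pair instead of bare WAV bytes, so callers must unpack both values. A minimal sketch (the checkpoint path and output filename are placeholders):

from synthesizer import Synthesizer

synth = Synthesizer()
synth.load('/path/to/model.ckpt')  # placeholder checkpoint path

# synthesize() now yields both the encoded WAV and the attention alignment.
wav_bytes, alignment = synth.synthesize('hi.')

with open('sample.wav', 'wb') as f:  # illustrative output file
  f.write(wav_bytes)

print('alignment shape:', alignment.shape)  # decoder/encoder attention matrix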