change demo server, return alignment with synthesis, plot alignment on eval

pull/10/head
Michael Nguyen 2018-07-03 18:53:07 -05:00
parent 8f1f0155c5
commit 5afe3ce027
4 changed files with 88 additions and 55 deletions

View File

@ -1,9 +1,14 @@
import argparse
import falcon
from flask import Flask, request, send_file
from flask.views import MethodView
from hparams import hparams, hparams_debug_string
import argparse
import os
from synthesizer import Synthesizer
from flask_cors import CORS
import io
# Flask application object; the URL rules further down are registered on it.
app = Flask(__name__)
# Wrap the app with flask_cors' CORS extension.
CORS(app)
html_body = '''<html><title>Demo</title>
<style>
@ -56,40 +61,44 @@ function synthesize(text) {
</script></body></html>
'''
class UIResource:
    """Falcon resource that serves the demo HTML page."""

    def on_get(self, req, res):
        """Fill the response with the module-level ``html_body`` as HTML."""
        res.body = html_body
        res.content_type = 'text/html'
class SynthesisResource:
    """Falcon resource that synthesizes the ``text`` query parameter."""

    def on_get(self, req, res):
        """Respond with the synthesized audio for ``text``.

        Raises:
            falcon.HTTPBadRequest: if ``text`` is missing or empty.
        """
        text = req.params.get('text')
        if not text:
            raise falcon.HTTPBadRequest()
        res.data = synthesizer.synthesize(text)
        res.content_type = 'audio/wav'
# Module-level synthesizer shared by the request handlers; its checkpoint is
# loaded later by the startup code.
synthesizer = Synthesizer()
# Falcon application: '/synthesize' returns audio, '/' returns the demo page.
api = falcon.API()
api.add_route('/synthesize', SynthesisResource())
api.add_route('/', UIResource())
class Mimic2(MethodView):
    """Flask view that synthesizes the ``text`` query parameter to WAV audio."""

    def get(self):
        """Return synthesized audio for ``?text=...``.

        Returns:
            An ``audio/wav`` response built from the bytes returned by
            ``synthesizer.synthesize``, or a 400 response when ``text`` is
            missing or empty.
        """
        text = request.args.get('text')
        if not text:
            # Falling through would return None, which Flask rejects with a
            # 500 error; report the client error explicitly instead.
            return 'Missing required query parameter: text', 400
        # synthesize() returns (wav_bytes, alignment); only the audio is served.
        wav, _ = synthesizer.synthesize(text)
        audio = io.BytesIO(wav)
        return send_file(audio, mimetype="audio/wav")
class UI(MethodView):
    """Flask view that serves the demo HTML page."""

    def get(self):
        """Return the module-level ``html_body`` string as the response body."""
        return html_body
# Register the class-based views on the Flask app: '/' serves the demo page
# and '/synthesize' serves synthesized audio. Both accept GET only.
ui_view = UI.as_view('ui_view')
app.add_url_rule('/', view_func=ui_view, methods=['GET'])
mimic2_api = Mimic2.as_view('mimic2_api')
app.add_url_rule('/synthesize', view_func=mimic2_api, methods=['GET'])
if __name__ == '__main__':
    # Script mode: parse the command line, load the checkpoint and serve the
    # falcon app with the standard-library WSGI server.
    from wsgiref import simple_server

    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', required=True,
                        help='Full path to model checkpoint')
    parser.add_argument('--port', type=int, default=9000)
    parser.add_argument(
        '--hparams', default='',
        help='Hyperparameter overrides as a comma-separated list of name=value pairs')
    args = parser.parse_args()

    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    hparams.parse(args.hparams)
    print(hparams_debug_string())
    synthesizer.load(args.checkpoint)

    print('Serving on port %d' % args.port)
    server = simple_server.make_server('0.0.0.0', args.port, api)
    server.serve_forever()
else:
    # Imported as a module: the checkpoint path comes from the CHECKPOINT
    # environment variable.
    synthesizer.load(os.environ['CHECKPOINT'])
# Startup for the Flask demo server: parse options, configure the environment,
# load the model checkpoint and start serving.
# NOTE(review): indentation is lost in this rendering; these statements
# presumably sit under the `if __name__ == '__main__':` guard - confirm.
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', required=True,
                    help='Full path to model checkpoint')
parser.add_argument('--port', type=int, default=3000)
parser.add_argument('--hparams', default='',
                    help='Hyperparameter overrides as a comma-separated list of name=value pairs')
parser.add_argument(
    '--gpu_assignment', default='0',
    help='Set the gpu the model should run on')
args = parser.parse_args()
# Environment variables are set before the checkpoint is loaded.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_assignment
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
hparams.parse(args.hparams)
print(hparams_debug_string())
synthesizer.load(args.checkpoint)
# Serve on the requested port; previously 3000 was hard-coded here, so the
# --port option was silently ignored.
app.run(host='0.0.0.0', port=args.port)

53
eval.py
View File

@ -3,19 +3,30 @@ import os
import re
from hparams import hparams, hparams_debug_string
from synthesizer import Synthesizer
from util import plot
# Evaluation sentences; run_eval enumerates this list and synthesizes each one.
sentences = [
# From July 8, 2017 New York Times:
'Scientists at the CERN laboratory say they have discovered a new particle.',
'Theres a way to measure the acute emotional intelligence that has never gone out of style.',
'President Trump met with other leaders at the Group of 20 conference.',
'The Senate\'s bill to repeal and replace the Affordable Care Act is now imperiled.',
# From Google's Tacotron example page:
'Generative adversarial network or variational auto-encoder.',
'The buses aren\'t the problem, they actually provide a solution.',
'Does the quick brown fox jump over the lazy dog?',
'Talib Kweli confirmed to AllHipHop that he will be releasing an album in the next year.',
# From July 8, 2017 New York Times:
# 'Scientists at the CERN laboratory say they have discovered a new particle.',
# 'Theres a way to measure the acute emotional intelligence that has never gone out of style.',
# 'President Trump met with other leaders at the Group of 20 conference.',
# 'The Senate\'s bill to repeal and replace the Affordable Care Act is now imperiled.',
# # From Google's Tacotron example page:
# 'Generative adversarial network or variational auto-encoder.',
# 'The buses aren\'t the problem, they actually provide a solution.',
# 'Does the quick brown fox jump over the lazy dog?',
# 'Talib Kweli confirmed to AllHipHop that he will be releasing an album in the next year.',
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
"The human voice is the most perfect instrument of all.",
"I'm sorry Dave, I'm afraid I can't do that.",
"This cake is great, It's so delicious and moist.",
"hello my name is mycroft.",
"hi.",
"wow.",
"cool.",
"great.",
]
@ -32,18 +43,26 @@ def run_eval(args):
synth.load(args.checkpoint)
base_path = get_output_base_path(args.checkpoint)
for i, text in enumerate(sentences):
path = '%s-%d.wav' % (base_path, i)
print('Synthesizing: %s' % path)
with open(path, 'wb') as f:
f.write(synth.synthesize(text))
wav_path = '%s-%d.wav' % (base_path, i)
align_path = '%s-%d.png' % (base_path, i)
print('Synthesizing and plotting: %s' % wav_path)
wav, alignment = synth.synthesize(text)
with open(wav_path, 'wb') as f:
f.write(wav)
plot.plot_alignment(
alignment, align_path,
info='%s' % (text)
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', required=True, help='Path to model checkpoint')
parser.add_argument('--checkpoint', required=True,
help='Path to model checkpoint')
parser.add_argument('--hparams', default='',
help='Hyperparameter overrides as a comma-separated list of name=value pairs')
parser.add_argument('--force_cpu', default=False, help='Force synthesize with cpu')
help='Hyperparameter overrides as a comma-separated list of name=value pairs')
parser.add_argument('--force_cpu', default=False,
help='Force synthesize with cpu')
args = parser.parse_args()
if args.force_cpu:
os.environ['CUDA_VISIBLE_DEVICES'] = ''

View File

@ -1,8 +1,9 @@
# Note: this doesn't include tensorflow or tensorflow-gpu because the package you need to install
# depends on your platform. It is assumed you have already installed tensorflow.
falcon==1.2.0
librosa==0.5.1
matplotlib==2.0.2
numpy==1.13.0
scipy==0.19.0
tqdm==4.11.2
flask_cors
flask

View File

@ -17,6 +17,7 @@ class Synthesizer:
self.model = create_model(model_name, hparams)
self.model.initialize(inputs, input_lengths)
self.wav_output = audio.inv_spectrogram_tensorflow(self.model.linear_outputs[0])
self.alignment = self.model.alignments[0]
print('Loading checkpoint: %s' % checkpoint_path)
self.session = tf.Session()
@ -32,8 +33,11 @@ class Synthesizer:
self.model.inputs: [np.asarray(seq, dtype=np.int32)],
self.model.input_lengths: np.asarray([len(seq)], dtype=np.int32)
}
wav = self.session.run(self.wav_output, feed_dict=feed_dict)
wav, alignment = self.session.run(
[self.wav_output, self.alignment],
feed_dict=feed_dict)
wav = wav[:audio.find_endpoint(wav)]
out = io.BytesIO()
audio.save_wav(wav, out)
return out.getvalue()
return out.getvalue(), alignment