Added web view for TTS to get url (#13882)

* Added web view for to get url

* Added web view for TTS to get url

* Added web view for TTS to get url

* Added web view for TTS to get url

* Fixed test

* added auth

* Update __init__.py
pull/13960/head
Tod Schmidt 2018-04-17 09:24:54 -04:00 committed by Pascal Vizeli
parent 3b44f91395
commit f4b1a8e42d
2 changed files with 117 additions and 44 deletions

View File

@ -37,6 +37,7 @@ ATTR_CACHE = 'cache'
ATTR_LANGUAGE = 'language'
ATTR_MESSAGE = 'message'
ATTR_OPTIONS = 'options'
ATTR_PLATFORM = 'platform'
CONF_CACHE = 'cache'
CONF_CACHE_DIR = 'cache_dir'
@ -77,8 +78,7 @@ SCHEMA_SERVICE_SAY = vol.Schema({
SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})
@asyncio.coroutine
def async_setup(hass, config):
async def async_setup(hass, config):
"""Set up TTS."""
tts = SpeechManager(hass)
@ -88,27 +88,27 @@ def async_setup(hass, config):
cache_dir = conf.get(CONF_CACHE_DIR, DEFAULT_CACHE_DIR)
time_memory = conf.get(CONF_TIME_MEMORY, DEFAULT_TIME_MEMORY)
yield from tts.async_init_cache(use_cache, cache_dir, time_memory)
await tts.async_init_cache(use_cache, cache_dir, time_memory)
except (HomeAssistantError, KeyError) as err:
_LOGGER.error("Error on cache init %s", err)
return False
hass.http.register_view(TextToSpeechView(tts))
hass.http.register_view(TextToSpeechUrlView(tts))
@asyncio.coroutine
def async_setup_platform(p_type, p_config, disc_info=None):
async def async_setup_platform(p_type, p_config, disc_info=None):
"""Set up a TTS platform."""
platform = yield from async_prepare_setup_platform(
platform = await async_prepare_setup_platform(
hass, config, DOMAIN, p_type)
if platform is None:
return
try:
if hasattr(platform, 'async_get_engine'):
provider = yield from platform.async_get_engine(
provider = await platform.async_get_engine(
hass, p_config)
else:
provider = yield from hass.async_add_job(
provider = await hass.async_add_job(
platform.get_engine, hass, p_config)
if provider is None:
@ -120,8 +120,7 @@ def async_setup(hass, config):
_LOGGER.exception("Error setting up platform %s", p_type)
return
@asyncio.coroutine
def async_say_handle(service):
async def async_say_handle(service):
"""Service handle for say."""
entity_ids = service.data.get(ATTR_ENTITY_ID)
message = service.data.get(ATTR_MESSAGE)
@ -130,7 +129,7 @@ def async_setup(hass, config):
options = service.data.get(ATTR_OPTIONS)
try:
url = yield from tts.async_get_url(
url = await tts.async_get_url(
p_type, message, cache=cache, language=language,
options=options
)
@ -146,7 +145,7 @@ def async_setup(hass, config):
if entity_ids:
data[ATTR_ENTITY_ID] = entity_ids
yield from hass.services.async_call(
await hass.services.async_call(
DOMAIN_MP, SERVICE_PLAY_MEDIA, data, blocking=True)
hass.services.async_register(
@ -157,12 +156,11 @@ def async_setup(hass, config):
in config_per_platform(config, DOMAIN)]
if setup_tasks:
yield from asyncio.wait(setup_tasks, loop=hass.loop)
await asyncio.wait(setup_tasks, loop=hass.loop)
@asyncio.coroutine
def async_clear_cache_handle(service):
async def async_clear_cache_handle(service):
"""Handle clear cache service call."""
yield from tts.async_clear_cache()
await tts.async_clear_cache()
hass.services.async_register(
DOMAIN, SERVICE_CLEAR_CACHE, async_clear_cache_handle,
@ -185,8 +183,7 @@ class SpeechManager(object):
self.file_cache = {}
self.mem_cache = {}
@asyncio.coroutine
def async_init_cache(self, use_cache, cache_dir, time_memory):
async def async_init_cache(self, use_cache, cache_dir, time_memory):
"""Init config folder and load file cache."""
self.use_cache = use_cache
self.time_memory = time_memory
@ -201,7 +198,7 @@ class SpeechManager(object):
return cache_dir
try:
self.cache_dir = yield from self.hass.async_add_job(
self.cache_dir = await self.hass.async_add_job(
init_tts_cache_dir, cache_dir)
except OSError as err:
raise HomeAssistantError("Can't init cache dir {}".format(err))
@ -222,15 +219,14 @@ class SpeechManager(object):
return cache
try:
cache_files = yield from self.hass.async_add_job(get_cache_files)
cache_files = await self.hass.async_add_job(get_cache_files)
except OSError as err:
raise HomeAssistantError("Can't read cache dir {}".format(err))
if cache_files:
self.file_cache.update(cache_files)
@asyncio.coroutine
def async_clear_cache(self):
async def async_clear_cache(self):
"""Read file cache and delete files."""
self.mem_cache = {}
@ -243,7 +239,7 @@ class SpeechManager(object):
_LOGGER.warning(
"Can't remove cache file '%s': %s", filename, err)
yield from self.hass.async_add_job(remove_files)
await self.hass.async_add_job(remove_files)
self.file_cache = {}
@callback
@ -254,9 +250,8 @@ class SpeechManager(object):
provider.name = engine
self.providers[engine] = provider
@asyncio.coroutine
def async_get_url(self, engine, message, cache=None, language=None,
options=None):
async def async_get_url(self, engine, message, cache=None, language=None,
options=None):
"""Get URL for play message.
This method is a coroutine.
@ -301,21 +296,20 @@ class SpeechManager(object):
self.hass.async_add_job(self.async_file_to_mem(key))
# Load speech from provider into memory
else:
filename = yield from self.async_get_tts_audio(
filename = await self.async_get_tts_audio(
engine, key, message, use_cache, language, options)
return "{}/api/tts_proxy/{}".format(
self.hass.config.api.base_url, filename)
@asyncio.coroutine
def async_get_tts_audio(self, engine, key, message, cache, language,
options):
async def async_get_tts_audio(self, engine, key, message, cache, language,
options):
"""Receive TTS and store for view in cache.
This method is a coroutine.
"""
provider = self.providers[engine]
extension, data = yield from provider.async_get_tts_audio(
extension, data = await provider.async_get_tts_audio(
message, language, options)
if data is None or extension is None:
@ -337,8 +331,7 @@ class SpeechManager(object):
return filename
@asyncio.coroutine
def async_save_tts_audio(self, key, filename, data):
async def async_save_tts_audio(self, key, filename, data):
"""Store voice data to file and file_cache.
This method is a coroutine.
@ -351,13 +344,12 @@ class SpeechManager(object):
speech.write(data)
try:
yield from self.hass.async_add_job(save_speech)
await self.hass.async_add_job(save_speech)
self.file_cache[key] = filename
except OSError:
_LOGGER.error("Can't write %s", filename)
@asyncio.coroutine
def async_file_to_mem(self, key):
async def async_file_to_mem(self, key):
"""Load voice from file cache into memory.
This method is a coroutine.
@ -374,7 +366,7 @@ class SpeechManager(object):
return speech.read()
try:
data = yield from self.hass.async_add_job(load_speech)
data = await self.hass.async_add_job(load_speech)
except OSError:
del self.file_cache[key]
raise HomeAssistantError("Can't read {}".format(voice_file))
@ -396,8 +388,7 @@ class SpeechManager(object):
self.hass.loop.call_later(self.time_memory, async_remove_from_mem)
@asyncio.coroutine
def async_read_tts(self, filename):
async def async_read_tts(self, filename):
"""Read a voice file and return binary.
This method is a coroutine.
@ -412,7 +403,7 @@ class SpeechManager(object):
if key not in self.mem_cache:
if key not in self.file_cache:
raise HomeAssistantError("{} not in cache!".format(key))
yield from self.async_file_to_mem(key)
await self.async_file_to_mem(key)
content, _ = mimetypes.guess_type(filename)
return (content, self.mem_cache[key][MEM_CACHE_VOICE])
@ -490,6 +481,45 @@ class Provider(object):
ft.partial(self.get_tts_audio, message, language, options=options))
class TextToSpeechUrlView(HomeAssistantView):
"""TTS view to get a url to a generated speech file."""
requires_auth = True
url = '/api/tts_get_url'
name = 'api:tts:geturl'
def __init__(self, tts):
"""Initialize a tts view."""
self.tts = tts
async def post(self, request):
"""Generate speech and provide url."""
try:
data = await request.json()
except ValueError:
return self.json_message('Invalid JSON specified', 400)
if not data.get(ATTR_PLATFORM) and data.get(ATTR_MESSAGE):
return self.json_message('Must specify platform and message', 400)
p_type = data[ATTR_PLATFORM]
message = data[ATTR_MESSAGE]
cache = data.get(ATTR_CACHE)
language = data.get(ATTR_LANGUAGE)
options = data.get(ATTR_OPTIONS)
try:
url = await self.tts.async_get_url(
p_type, message, cache=cache, language=language,
options=options
)
resp = self.json({'url': url}, 200)
except HomeAssistantError as err:
_LOGGER.error("Error on init tts: %s", err)
resp = self.json({'error': err}, 400)
return resp
class TextToSpeechView(HomeAssistantView):
"""TTS view to serve a speech audio."""
@ -501,11 +531,10 @@ class TextToSpeechView(HomeAssistantView):
"""Initialize a tts view."""
self.tts = tts
@asyncio.coroutine
def get(self, request, filename):
async def get(self, request, filename):
"""Start a get request."""
try:
content, data = yield from self.tts.async_read_tts(filename)
content, data = await self.tts.async_read_tts(filename)
except HomeAssistantError as err:
_LOGGER.error("Error on load tts: %s", err)
return web.Response(status=404)

View File

@ -2,6 +2,7 @@
import ctypes
import os
import shutil
import json
from unittest.mock import patch, PropertyMock
import pytest
@ -353,7 +354,7 @@ class TestTTS(object):
demo_data = tts.SpeechManager.write_tags(
"265944c108cbb00b2a621be5930513e03a0bb2cd_en_-_demo.mp3",
demo_data, self.demo_provider,
"I person is on front of your door.", 'en', None)
"AI person is in front of your door.", 'en', None)
assert req.status_code == 200
assert req.content == demo_data
@ -562,3 +563,46 @@ class TestTTS(object):
req = requests.get(url)
assert req.status_code == 200
assert req.content == demo_data
def test_setup_component_and_web_get_url(self):
"""Setup the demo platform and receive wrong file from web."""
config = {
tts.DOMAIN: {
'platform': 'demo',
}
}
with assert_setup_component(1, tts.DOMAIN):
setup_component(self.hass, tts.DOMAIN, config)
self.hass.start()
url = ("{}/api/tts_get_url").format(self.hass.config.api.base_url)
data = {'platform': 'demo',
'message': "I person is on front of your door."}
req = requests.post(url, data=json.dumps(data))
assert req.status_code == 200
response = json.loads(req.text)
assert response.get('url') == (("{}/api/tts_proxy/265944c108cbb00b2a62"
"1be5930513e03a0bb2cd_en_-_demo.mp3")
.format(self.hass.config.api.base_url))
def test_setup_component_and_web_get_url_bad_config(self):
"""Setup the demo platform and receive wrong file from web."""
config = {
tts.DOMAIN: {
'platform': 'demo',
}
}
with assert_setup_component(1, tts.DOMAIN):
setup_component(self.hass, tts.DOMAIN, config)
self.hass.start()
url = ("{}/api/tts_get_url").format(self.hass.config.api.base_url)
data = {'message': "I person is on front of your door."}
req = requests.post(url, data=data)
assert req.status_code == 400