Added web view for TTS to get url (#13882)
* Added web view for to get url * Added web view for TTS to get url * Added web view for TTS to get url * Added web view for TTS to get url * Fixed test * added auth * Update __init__.pypull/13960/head
parent
3b44f91395
commit
f4b1a8e42d
|
@ -37,6 +37,7 @@ ATTR_CACHE = 'cache'
|
|||
ATTR_LANGUAGE = 'language'
|
||||
ATTR_MESSAGE = 'message'
|
||||
ATTR_OPTIONS = 'options'
|
||||
ATTR_PLATFORM = 'platform'
|
||||
|
||||
CONF_CACHE = 'cache'
|
||||
CONF_CACHE_DIR = 'cache_dir'
|
||||
|
@ -77,8 +78,7 @@ SCHEMA_SERVICE_SAY = vol.Schema({
|
|||
SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_setup(hass, config):
|
||||
async def async_setup(hass, config):
|
||||
"""Set up TTS."""
|
||||
tts = SpeechManager(hass)
|
||||
|
||||
|
@ -88,27 +88,27 @@ def async_setup(hass, config):
|
|||
cache_dir = conf.get(CONF_CACHE_DIR, DEFAULT_CACHE_DIR)
|
||||
time_memory = conf.get(CONF_TIME_MEMORY, DEFAULT_TIME_MEMORY)
|
||||
|
||||
yield from tts.async_init_cache(use_cache, cache_dir, time_memory)
|
||||
await tts.async_init_cache(use_cache, cache_dir, time_memory)
|
||||
except (HomeAssistantError, KeyError) as err:
|
||||
_LOGGER.error("Error on cache init %s", err)
|
||||
return False
|
||||
|
||||
hass.http.register_view(TextToSpeechView(tts))
|
||||
hass.http.register_view(TextToSpeechUrlView(tts))
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_setup_platform(p_type, p_config, disc_info=None):
|
||||
async def async_setup_platform(p_type, p_config, disc_info=None):
|
||||
"""Set up a TTS platform."""
|
||||
platform = yield from async_prepare_setup_platform(
|
||||
platform = await async_prepare_setup_platform(
|
||||
hass, config, DOMAIN, p_type)
|
||||
if platform is None:
|
||||
return
|
||||
|
||||
try:
|
||||
if hasattr(platform, 'async_get_engine'):
|
||||
provider = yield from platform.async_get_engine(
|
||||
provider = await platform.async_get_engine(
|
||||
hass, p_config)
|
||||
else:
|
||||
provider = yield from hass.async_add_job(
|
||||
provider = await hass.async_add_job(
|
||||
platform.get_engine, hass, p_config)
|
||||
|
||||
if provider is None:
|
||||
|
@ -120,8 +120,7 @@ def async_setup(hass, config):
|
|||
_LOGGER.exception("Error setting up platform %s", p_type)
|
||||
return
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_say_handle(service):
|
||||
async def async_say_handle(service):
|
||||
"""Service handle for say."""
|
||||
entity_ids = service.data.get(ATTR_ENTITY_ID)
|
||||
message = service.data.get(ATTR_MESSAGE)
|
||||
|
@ -130,7 +129,7 @@ def async_setup(hass, config):
|
|||
options = service.data.get(ATTR_OPTIONS)
|
||||
|
||||
try:
|
||||
url = yield from tts.async_get_url(
|
||||
url = await tts.async_get_url(
|
||||
p_type, message, cache=cache, language=language,
|
||||
options=options
|
||||
)
|
||||
|
@ -146,7 +145,7 @@ def async_setup(hass, config):
|
|||
if entity_ids:
|
||||
data[ATTR_ENTITY_ID] = entity_ids
|
||||
|
||||
yield from hass.services.async_call(
|
||||
await hass.services.async_call(
|
||||
DOMAIN_MP, SERVICE_PLAY_MEDIA, data, blocking=True)
|
||||
|
||||
hass.services.async_register(
|
||||
|
@ -157,12 +156,11 @@ def async_setup(hass, config):
|
|||
in config_per_platform(config, DOMAIN)]
|
||||
|
||||
if setup_tasks:
|
||||
yield from asyncio.wait(setup_tasks, loop=hass.loop)
|
||||
await asyncio.wait(setup_tasks, loop=hass.loop)
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_clear_cache_handle(service):
|
||||
async def async_clear_cache_handle(service):
|
||||
"""Handle clear cache service call."""
|
||||
yield from tts.async_clear_cache()
|
||||
await tts.async_clear_cache()
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_CLEAR_CACHE, async_clear_cache_handle,
|
||||
|
@ -185,8 +183,7 @@ class SpeechManager(object):
|
|||
self.file_cache = {}
|
||||
self.mem_cache = {}
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_init_cache(self, use_cache, cache_dir, time_memory):
|
||||
async def async_init_cache(self, use_cache, cache_dir, time_memory):
|
||||
"""Init config folder and load file cache."""
|
||||
self.use_cache = use_cache
|
||||
self.time_memory = time_memory
|
||||
|
@ -201,7 +198,7 @@ class SpeechManager(object):
|
|||
return cache_dir
|
||||
|
||||
try:
|
||||
self.cache_dir = yield from self.hass.async_add_job(
|
||||
self.cache_dir = await self.hass.async_add_job(
|
||||
init_tts_cache_dir, cache_dir)
|
||||
except OSError as err:
|
||||
raise HomeAssistantError("Can't init cache dir {}".format(err))
|
||||
|
@ -222,15 +219,14 @@ class SpeechManager(object):
|
|||
return cache
|
||||
|
||||
try:
|
||||
cache_files = yield from self.hass.async_add_job(get_cache_files)
|
||||
cache_files = await self.hass.async_add_job(get_cache_files)
|
||||
except OSError as err:
|
||||
raise HomeAssistantError("Can't read cache dir {}".format(err))
|
||||
|
||||
if cache_files:
|
||||
self.file_cache.update(cache_files)
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_clear_cache(self):
|
||||
async def async_clear_cache(self):
|
||||
"""Read file cache and delete files."""
|
||||
self.mem_cache = {}
|
||||
|
||||
|
@ -243,7 +239,7 @@ class SpeechManager(object):
|
|||
_LOGGER.warning(
|
||||
"Can't remove cache file '%s': %s", filename, err)
|
||||
|
||||
yield from self.hass.async_add_job(remove_files)
|
||||
await self.hass.async_add_job(remove_files)
|
||||
self.file_cache = {}
|
||||
|
||||
@callback
|
||||
|
@ -254,9 +250,8 @@ class SpeechManager(object):
|
|||
provider.name = engine
|
||||
self.providers[engine] = provider
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_get_url(self, engine, message, cache=None, language=None,
|
||||
options=None):
|
||||
async def async_get_url(self, engine, message, cache=None, language=None,
|
||||
options=None):
|
||||
"""Get URL for play message.
|
||||
|
||||
This method is a coroutine.
|
||||
|
@ -301,21 +296,20 @@ class SpeechManager(object):
|
|||
self.hass.async_add_job(self.async_file_to_mem(key))
|
||||
# Load speech from provider into memory
|
||||
else:
|
||||
filename = yield from self.async_get_tts_audio(
|
||||
filename = await self.async_get_tts_audio(
|
||||
engine, key, message, use_cache, language, options)
|
||||
|
||||
return "{}/api/tts_proxy/{}".format(
|
||||
self.hass.config.api.base_url, filename)
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_get_tts_audio(self, engine, key, message, cache, language,
|
||||
options):
|
||||
async def async_get_tts_audio(self, engine, key, message, cache, language,
|
||||
options):
|
||||
"""Receive TTS and store for view in cache.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
provider = self.providers[engine]
|
||||
extension, data = yield from provider.async_get_tts_audio(
|
||||
extension, data = await provider.async_get_tts_audio(
|
||||
message, language, options)
|
||||
|
||||
if data is None or extension is None:
|
||||
|
@ -337,8 +331,7 @@ class SpeechManager(object):
|
|||
|
||||
return filename
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_save_tts_audio(self, key, filename, data):
|
||||
async def async_save_tts_audio(self, key, filename, data):
|
||||
"""Store voice data to file and file_cache.
|
||||
|
||||
This method is a coroutine.
|
||||
|
@ -351,13 +344,12 @@ class SpeechManager(object):
|
|||
speech.write(data)
|
||||
|
||||
try:
|
||||
yield from self.hass.async_add_job(save_speech)
|
||||
await self.hass.async_add_job(save_speech)
|
||||
self.file_cache[key] = filename
|
||||
except OSError:
|
||||
_LOGGER.error("Can't write %s", filename)
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_file_to_mem(self, key):
|
||||
async def async_file_to_mem(self, key):
|
||||
"""Load voice from file cache into memory.
|
||||
|
||||
This method is a coroutine.
|
||||
|
@ -374,7 +366,7 @@ class SpeechManager(object):
|
|||
return speech.read()
|
||||
|
||||
try:
|
||||
data = yield from self.hass.async_add_job(load_speech)
|
||||
data = await self.hass.async_add_job(load_speech)
|
||||
except OSError:
|
||||
del self.file_cache[key]
|
||||
raise HomeAssistantError("Can't read {}".format(voice_file))
|
||||
|
@ -396,8 +388,7 @@ class SpeechManager(object):
|
|||
|
||||
self.hass.loop.call_later(self.time_memory, async_remove_from_mem)
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_read_tts(self, filename):
|
||||
async def async_read_tts(self, filename):
|
||||
"""Read a voice file and return binary.
|
||||
|
||||
This method is a coroutine.
|
||||
|
@ -412,7 +403,7 @@ class SpeechManager(object):
|
|||
if key not in self.mem_cache:
|
||||
if key not in self.file_cache:
|
||||
raise HomeAssistantError("{} not in cache!".format(key))
|
||||
yield from self.async_file_to_mem(key)
|
||||
await self.async_file_to_mem(key)
|
||||
|
||||
content, _ = mimetypes.guess_type(filename)
|
||||
return (content, self.mem_cache[key][MEM_CACHE_VOICE])
|
||||
|
@ -490,6 +481,45 @@ class Provider(object):
|
|||
ft.partial(self.get_tts_audio, message, language, options=options))
|
||||
|
||||
|
||||
class TextToSpeechUrlView(HomeAssistantView):
|
||||
"""TTS view to get a url to a generated speech file."""
|
||||
|
||||
requires_auth = True
|
||||
url = '/api/tts_get_url'
|
||||
name = 'api:tts:geturl'
|
||||
|
||||
def __init__(self, tts):
|
||||
"""Initialize a tts view."""
|
||||
self.tts = tts
|
||||
|
||||
async def post(self, request):
|
||||
"""Generate speech and provide url."""
|
||||
try:
|
||||
data = await request.json()
|
||||
except ValueError:
|
||||
return self.json_message('Invalid JSON specified', 400)
|
||||
if not data.get(ATTR_PLATFORM) and data.get(ATTR_MESSAGE):
|
||||
return self.json_message('Must specify platform and message', 400)
|
||||
|
||||
p_type = data[ATTR_PLATFORM]
|
||||
message = data[ATTR_MESSAGE]
|
||||
cache = data.get(ATTR_CACHE)
|
||||
language = data.get(ATTR_LANGUAGE)
|
||||
options = data.get(ATTR_OPTIONS)
|
||||
|
||||
try:
|
||||
url = await self.tts.async_get_url(
|
||||
p_type, message, cache=cache, language=language,
|
||||
options=options
|
||||
)
|
||||
resp = self.json({'url': url}, 200)
|
||||
except HomeAssistantError as err:
|
||||
_LOGGER.error("Error on init tts: %s", err)
|
||||
resp = self.json({'error': err}, 400)
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
class TextToSpeechView(HomeAssistantView):
|
||||
"""TTS view to serve a speech audio."""
|
||||
|
||||
|
@ -501,11 +531,10 @@ class TextToSpeechView(HomeAssistantView):
|
|||
"""Initialize a tts view."""
|
||||
self.tts = tts
|
||||
|
||||
@asyncio.coroutine
|
||||
def get(self, request, filename):
|
||||
async def get(self, request, filename):
|
||||
"""Start a get request."""
|
||||
try:
|
||||
content, data = yield from self.tts.async_read_tts(filename)
|
||||
content, data = await self.tts.async_read_tts(filename)
|
||||
except HomeAssistantError as err:
|
||||
_LOGGER.error("Error on load tts: %s", err)
|
||||
return web.Response(status=404)
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
import ctypes
|
||||
import os
|
||||
import shutil
|
||||
import json
|
||||
from unittest.mock import patch, PropertyMock
|
||||
|
||||
import pytest
|
||||
|
@ -353,7 +354,7 @@ class TestTTS(object):
|
|||
demo_data = tts.SpeechManager.write_tags(
|
||||
"265944c108cbb00b2a621be5930513e03a0bb2cd_en_-_demo.mp3",
|
||||
demo_data, self.demo_provider,
|
||||
"I person is on front of your door.", 'en', None)
|
||||
"AI person is in front of your door.", 'en', None)
|
||||
assert req.status_code == 200
|
||||
assert req.content == demo_data
|
||||
|
||||
|
@ -562,3 +563,46 @@ class TestTTS(object):
|
|||
req = requests.get(url)
|
||||
assert req.status_code == 200
|
||||
assert req.content == demo_data
|
||||
|
||||
def test_setup_component_and_web_get_url(self):
|
||||
"""Setup the demo platform and receive wrong file from web."""
|
||||
config = {
|
||||
tts.DOMAIN: {
|
||||
'platform': 'demo',
|
||||
}
|
||||
}
|
||||
|
||||
with assert_setup_component(1, tts.DOMAIN):
|
||||
setup_component(self.hass, tts.DOMAIN, config)
|
||||
|
||||
self.hass.start()
|
||||
|
||||
url = ("{}/api/tts_get_url").format(self.hass.config.api.base_url)
|
||||
data = {'platform': 'demo',
|
||||
'message': "I person is on front of your door."}
|
||||
|
||||
req = requests.post(url, data=json.dumps(data))
|
||||
assert req.status_code == 200
|
||||
response = json.loads(req.text)
|
||||
assert response.get('url') == (("{}/api/tts_proxy/265944c108cbb00b2a62"
|
||||
"1be5930513e03a0bb2cd_en_-_demo.mp3")
|
||||
.format(self.hass.config.api.base_url))
|
||||
|
||||
def test_setup_component_and_web_get_url_bad_config(self):
|
||||
"""Setup the demo platform and receive wrong file from web."""
|
||||
config = {
|
||||
tts.DOMAIN: {
|
||||
'platform': 'demo',
|
||||
}
|
||||
}
|
||||
|
||||
with assert_setup_component(1, tts.DOMAIN):
|
||||
setup_component(self.hass, tts.DOMAIN, config)
|
||||
|
||||
self.hass.start()
|
||||
|
||||
url = ("{}/api/tts_get_url").format(self.hass.config.api.base_url)
|
||||
data = {'message': "I person is on front of your door."}
|
||||
|
||||
req = requests.post(url, data=data)
|
||||
assert req.status_code == 400
|
||||
|
|
Loading…
Reference in New Issue