TTS Component / Google speech platform (#4837)
* TTS Component / Google speech platform * Change file backend handling / cache * Use mimetype / rename Provider function / allow cache on service call * Add a memcache for faster response * Add demo platform * First version of unittest * Address comments * improve error handling / address comments * Add google unittest & check http response code * Change url param handling * add test for other language * Change hash to sha256 for same hash on every os/hardware * add unittest for receive demo data * add test for error cases * Test case load from file to mem over aiohttp server * Use cache SpeechManager level, address other comments * Add service for clear cache * Update service.yaml * add support for spliting google messagepull/4750/merge
parent
acb841a1f4
commit
2dec38d8d4
|
@ -162,7 +162,7 @@ MEDIA_PLAYER_MEDIA_SEEK_SCHEMA = MEDIA_PLAYER_SCHEMA.extend({
|
|||
MEDIA_PLAYER_PLAY_MEDIA_SCHEMA = MEDIA_PLAYER_SCHEMA.extend({
|
||||
vol.Required(ATTR_MEDIA_CONTENT_TYPE): cv.string,
|
||||
vol.Required(ATTR_MEDIA_CONTENT_ID): cv.string,
|
||||
ATTR_MEDIA_ENQUEUE: cv.boolean,
|
||||
vol.Optional(ATTR_MEDIA_ENQUEUE): cv.boolean,
|
||||
})
|
||||
|
||||
MEDIA_PLAYER_SELECT_SOURCE_SCHEMA = MEDIA_PLAYER_SCHEMA.extend({
|
||||
|
|
|
@ -0,0 +1,421 @@
|
|||
"""
|
||||
Provide functionality to TTS.
|
||||
|
||||
For more details about this component, please refer to the documentation at
|
||||
https://home-assistant.io/components/tts/
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
import hashlib
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
|
||||
from aiohttp import web
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import ATTR_ENTITY_ID
|
||||
from homeassistant.bootstrap import async_prepare_setup_platform
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.config import load_yaml_config_file
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.components.media_player import (
|
||||
SERVICE_PLAY_MEDIA, MEDIA_TYPE_MUSIC, ATTR_MEDIA_CONTENT_ID,
|
||||
ATTR_MEDIA_CONTENT_TYPE, DOMAIN as DOMAIN_MP)
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import config_per_platform
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
|
||||
DOMAIN = 'tts'
|
||||
DEPENDENCIES = ['http']
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
MEM_CACHE_FILENAME = 'filename'
|
||||
MEM_CACHE_VOICE = 'voice'
|
||||
|
||||
CONF_LANG = 'language'
|
||||
CONF_CACHE = 'cache'
|
||||
CONF_CACHE_DIR = 'cache_dir'
|
||||
CONF_TIME_MEMORY = 'time_memory'
|
||||
|
||||
DEFAULT_CACHE = True
|
||||
DEFAULT_CACHE_DIR = "tts"
|
||||
DEFAULT_LANG = 'en'
|
||||
DEFAULT_TIME_MEMORY = 300
|
||||
|
||||
SERVICE_SAY = 'say'
|
||||
SERVICE_CLEAR_CACHE = 'clear_cache'
|
||||
|
||||
ATTR_MESSAGE = 'message'
|
||||
ATTR_CACHE = 'cache'
|
||||
|
||||
_RE_VOICE_FILE = re.compile(r"([a-f0-9]{40})_([a-z]+)\.[a-z0-9]{3,4}")
|
||||
|
||||
PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend({
|
||||
vol.Optional(CONF_LANG, default=DEFAULT_LANG): cv.string,
|
||||
vol.Optional(CONF_CACHE, default=DEFAULT_CACHE): cv.boolean,
|
||||
vol.Optional(CONF_CACHE_DIR, default=DEFAULT_CACHE_DIR): cv.string,
|
||||
vol.Optional(CONF_TIME_MEMORY, default=DEFAULT_TIME_MEMORY):
|
||||
vol.All(vol.Coerce(int), vol.Range(min=60, max=57600)),
|
||||
})
|
||||
|
||||
|
||||
SCHEMA_SERVICE_SAY = vol.Schema({
|
||||
vol.Required(ATTR_MESSAGE): cv.string,
|
||||
vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
|
||||
vol.Optional(ATTR_CACHE): cv.boolean,
|
||||
})
|
||||
|
||||
SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_setup(hass, config):
|
||||
"""Setup TTS."""
|
||||
tts = SpeechManager(hass)
|
||||
|
||||
try:
|
||||
conf = config[DOMAIN][0] if len(config.get(DOMAIN, [])) > 0 else {}
|
||||
use_cache = conf.get(CONF_CACHE, DEFAULT_CACHE)
|
||||
cache_dir = conf.get(CONF_CACHE_DIR, DEFAULT_CACHE_DIR)
|
||||
time_memory = conf.get(CONF_TIME_MEMORY, DEFAULT_LANG)
|
||||
|
||||
yield from tts.async_init_cache(use_cache, cache_dir, time_memory)
|
||||
except (HomeAssistantError, KeyError) as err:
|
||||
_LOGGER.error("Error on cache init %s", err)
|
||||
return False
|
||||
|
||||
hass.http.register_view(TextToSpeechView(tts))
|
||||
|
||||
descriptions = yield from hass.loop.run_in_executor(
|
||||
None, load_yaml_config_file,
|
||||
os.path.join(os.path.dirname(__file__), 'services.yaml'))
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_setup_platform(p_type, p_config, disc_info=None):
|
||||
"""Setup a tts platform."""
|
||||
platform = yield from async_prepare_setup_platform(
|
||||
hass, config, DOMAIN, p_type)
|
||||
if platform is None:
|
||||
return
|
||||
|
||||
try:
|
||||
if hasattr(platform, 'async_get_engine'):
|
||||
provider = yield from platform.async_get_engine(
|
||||
hass, p_config)
|
||||
else:
|
||||
provider = yield from hass.loop.run_in_executor(
|
||||
None, platform.get_engine, hass, p_config)
|
||||
|
||||
if provider is None:
|
||||
_LOGGER.error('Error setting up platform %s', p_type)
|
||||
return
|
||||
|
||||
tts.async_register_engine(p_type, provider, p_config)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception('Error setting up platform %s', p_type)
|
||||
return
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_say_handle(service):
|
||||
"""Service handle for say."""
|
||||
entity_ids = service.data.get(ATTR_ENTITY_ID)
|
||||
message = service.data.get(ATTR_MESSAGE)
|
||||
cache = service.data.get(ATTR_CACHE)
|
||||
|
||||
try:
|
||||
url = yield from tts.async_get_url(
|
||||
p_type, message, cache=cache)
|
||||
except HomeAssistantError as err:
|
||||
_LOGGER.error("Error on init tts: %s", err)
|
||||
return
|
||||
|
||||
data = {
|
||||
ATTR_MEDIA_CONTENT_ID: url,
|
||||
ATTR_MEDIA_CONTENT_TYPE: MEDIA_TYPE_MUSIC,
|
||||
}
|
||||
|
||||
if entity_ids:
|
||||
data[ATTR_ENTITY_ID] = entity_ids
|
||||
|
||||
yield from hass.services.async_call(
|
||||
DOMAIN_MP, SERVICE_PLAY_MEDIA, data, blocking=True)
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, "{}_{}".format(p_type, SERVICE_SAY), async_say_handle,
|
||||
descriptions.get(SERVICE_SAY), schema=SCHEMA_SERVICE_SAY)
|
||||
|
||||
setup_tasks = [async_setup_platform(p_type, p_config) for p_type, p_config
|
||||
in config_per_platform(config, DOMAIN)]
|
||||
|
||||
if setup_tasks:
|
||||
yield from asyncio.wait(setup_tasks, loop=hass.loop)
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_clear_cache_handle(service):
|
||||
"""Handle clear cache service call."""
|
||||
yield from tts.async_clear_cache()
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_CLEAR_CACHE, async_clear_cache_handle,
|
||||
descriptions.get(SERVICE_CLEAR_CACHE), schema=SERVICE_CLEAR_CACHE)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class SpeechManager(object):
|
||||
"""Representation of a speech store."""
|
||||
|
||||
def __init__(self, hass):
|
||||
"""Initialize a speech store."""
|
||||
self.hass = hass
|
||||
self.providers = {}
|
||||
|
||||
self.use_cache = True
|
||||
self.cache_dir = None
|
||||
self.time_memory = None
|
||||
self.file_cache = {}
|
||||
self.mem_cache = {}
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_init_cache(self, use_cache, cache_dir, time_memory):
|
||||
"""Init config folder and load file cache."""
|
||||
self.use_cache = use_cache
|
||||
self.time_memory = time_memory
|
||||
|
||||
def init_tts_cache_dir(cache_dir):
|
||||
"""Init cache folder."""
|
||||
if not os.path.isabs(cache_dir):
|
||||
cache_dir = self.hass.config.path(cache_dir)
|
||||
if not os.path.isdir(cache_dir):
|
||||
_LOGGER.info("Create cache dir %s.", cache_dir)
|
||||
os.mkdir(cache_dir)
|
||||
return cache_dir
|
||||
|
||||
try:
|
||||
self.cache_dir = yield from self.hass.loop.run_in_executor(
|
||||
None, init_tts_cache_dir, cache_dir)
|
||||
except OSError as err:
|
||||
raise HomeAssistantError(
|
||||
"Can't init cache dir {}".format(err))
|
||||
|
||||
def get_cache_files():
|
||||
"""Return a dict of given engine files."""
|
||||
cache = {}
|
||||
|
||||
folder_data = os.listdir(self.cache_dir)
|
||||
for file_data in folder_data:
|
||||
record = _RE_VOICE_FILE.match(file_data)
|
||||
if record:
|
||||
key = "{}_{}".format(record.group(1), record.group(2))
|
||||
cache[key.lower()] = file_data.lower()
|
||||
return cache
|
||||
|
||||
try:
|
||||
cache_files = yield from self.hass.loop.run_in_executor(
|
||||
None, get_cache_files)
|
||||
except OSError as err:
|
||||
raise HomeAssistantError(
|
||||
"Can't read cache dir {}".format(err))
|
||||
|
||||
if cache_files:
|
||||
self.file_cache.update(cache_files)
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_clear_cache(self):
|
||||
"""Read file cache and delete files."""
|
||||
self.mem_cache = {}
|
||||
|
||||
def remove_files():
|
||||
"""Remove files from filesystem."""
|
||||
for _, filename in self.file_cache.items():
|
||||
try:
|
||||
os.remove(os.path.join(self.cache_dir), filename)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
yield from self.hass.loop.run_in_executor(None, remove_files)
|
||||
self.file_cache = {}
|
||||
|
||||
@callback
|
||||
def async_register_engine(self, engine, provider, config):
|
||||
"""Register a TTS provider."""
|
||||
provider.hass = self.hass
|
||||
provider.language = config.get(CONF_LANG)
|
||||
self.providers[engine] = provider
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_get_url(self, engine, message, cache=None):
|
||||
"""Get URL for play message.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
msg_hash = hashlib.sha1(bytes(message, 'utf-8')).hexdigest()
|
||||
key = ("{}_{}".format(msg_hash, engine)).lower()
|
||||
use_cache = cache if cache is not None else self.use_cache
|
||||
|
||||
# is speech allready in memory
|
||||
if key in self.mem_cache:
|
||||
filename = self.mem_cache[key][MEM_CACHE_FILENAME]
|
||||
# is file store in file cache
|
||||
elif use_cache and key in self.file_cache:
|
||||
filename = self.file_cache[key]
|
||||
self.hass.async_add_job(self.async_file_to_mem(engine, key))
|
||||
# load speech from provider into memory
|
||||
else:
|
||||
filename = yield from self.async_get_tts_audio(
|
||||
engine, key, message, use_cache)
|
||||
|
||||
return "{}/api/tts_proxy/{}".format(
|
||||
self.hass.config.api.base_url, filename)
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_get_tts_audio(self, engine, key, message, cache):
|
||||
"""Receive TTS and store for view in cache.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
provider = self.providers[engine]
|
||||
extension, data = yield from provider.async_get_tts_audio(message)
|
||||
|
||||
if data is None or extension is None:
|
||||
raise HomeAssistantError(
|
||||
"No TTS from {} for '{}'".format(engine, message))
|
||||
|
||||
# create file infos
|
||||
filename = ("{}.{}".format(key, extension)).lower()
|
||||
|
||||
# save to memory
|
||||
self._async_store_to_memcache(key, filename, data)
|
||||
|
||||
if cache:
|
||||
self.hass.async_add_job(
|
||||
self.async_save_tts_audio(key, filename, data))
|
||||
|
||||
return filename
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_save_tts_audio(self, key, filename, data):
|
||||
"""Store voice data to file and file_cache.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
voice_file = os.path.join(self.cache_dir, filename)
|
||||
|
||||
def save_speech():
|
||||
"""Store speech to filesystem."""
|
||||
with open(voice_file, 'wb') as speech:
|
||||
speech.write(data)
|
||||
|
||||
try:
|
||||
yield from self.hass.loop.run_in_executor(None, save_speech)
|
||||
self.file_cache[key] = filename
|
||||
except OSError:
|
||||
_LOGGER.error("Can't write %s", filename)
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_file_to_mem(self, engine, key):
|
||||
"""Load voice from file cache into memory.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
filename = self.file_cache.get(key)
|
||||
if not filename:
|
||||
raise HomeAssistantError("Key {} not in file cache!".format(key))
|
||||
|
||||
voice_file = os.path.join(self.cache_dir, filename)
|
||||
|
||||
def load_speech():
|
||||
"""Load a speech from filesystem."""
|
||||
with open(voice_file, 'rb') as speech:
|
||||
return speech.read()
|
||||
|
||||
try:
|
||||
data = yield from self.hass.loop.run_in_executor(None, load_speech)
|
||||
except OSError:
|
||||
raise HomeAssistantError("Can't read {}".format(voice_file))
|
||||
|
||||
self._async_store_to_memcache(key, filename, data)
|
||||
|
||||
@callback
|
||||
def _async_store_to_memcache(self, key, filename, data):
|
||||
"""Store data to memcache and set timer to remove it."""
|
||||
self.mem_cache[key] = {
|
||||
MEM_CACHE_FILENAME: filename,
|
||||
MEM_CACHE_VOICE: data,
|
||||
}
|
||||
|
||||
@callback
|
||||
def async_remove_from_mem():
|
||||
"""Cleanup memcache."""
|
||||
self.mem_cache.pop(key)
|
||||
|
||||
self.hass.loop.call_later(self.time_memory, async_remove_from_mem)
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_read_tts(self, filename):
|
||||
"""Read a voice file and return binary.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
record = _RE_VOICE_FILE.match(filename.lower())
|
||||
if not record:
|
||||
raise HomeAssistantError("Wrong tts file format!")
|
||||
|
||||
key = "{}_{}".format(record.group(1), record.group(2))
|
||||
|
||||
if key not in self.mem_cache:
|
||||
if key not in self.file_cache:
|
||||
raise HomeAssistantError("%s not in cache!", key)
|
||||
engine = record.group(2)
|
||||
yield from self.async_file_to_mem(engine, key)
|
||||
|
||||
content, _ = mimetypes.guess_type(filename)
|
||||
return (content, self.mem_cache[key][MEM_CACHE_VOICE])
|
||||
|
||||
|
||||
class Provider(object):
|
||||
"""Represent a single provider."""
|
||||
|
||||
hass = None
|
||||
language = DEFAULT_LANG
|
||||
|
||||
def get_tts_audio(self, message):
|
||||
"""Load tts audio file from provider."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_get_tts_audio(self, message):
|
||||
"""Load tts audio file from provider.
|
||||
|
||||
Return a tuple of file extension and data as bytes.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
extension, data = yield from self.hass.loop.run_in_executor(
|
||||
None, self.get_tts_audio, message)
|
||||
return (extension, data)
|
||||
|
||||
|
||||
class TextToSpeechView(HomeAssistantView):
|
||||
"""TTS view to serve an speech audio."""
|
||||
|
||||
requires_auth = False
|
||||
url = "/api/tts_proxy/{filename}"
|
||||
name = "api:tts:speech"
|
||||
|
||||
def __init__(self, tts):
|
||||
"""Initialize a tts view."""
|
||||
self.tts = tts
|
||||
|
||||
@asyncio.coroutine
|
||||
def get(self, request, filename):
|
||||
"""Start a get request."""
|
||||
try:
|
||||
content, data = yield from self.tts.async_read_tts(filename)
|
||||
except HomeAssistantError as err:
|
||||
_LOGGER.error("Error on load tts: %s", err)
|
||||
return web.Response(status=404)
|
||||
|
||||
return web.Response(body=data, content_type=content)
|
Binary file not shown.
|
@ -0,0 +1,29 @@
|
|||
"""
|
||||
Support for the demo speech service.
|
||||
|
||||
For more details about this component, please refer to the documentation at
|
||||
https://home-assistant.io/components/demo/
|
||||
"""
|
||||
import os
|
||||
|
||||
from homeassistant.components.tts import Provider
|
||||
|
||||
|
||||
def get_engine(hass, config):
|
||||
"""Setup Demo speech component."""
|
||||
return DemoProvider()
|
||||
|
||||
|
||||
class DemoProvider(Provider):
|
||||
"""Demo speech api provider."""
|
||||
|
||||
def get_tts_audio(self, message):
|
||||
"""Load TTS from demo."""
|
||||
filename = os.path.join(os.path.dirname(__file__), "demo.mp3")
|
||||
try:
|
||||
with open(filename, 'rb') as voice:
|
||||
data = voice.read()
|
||||
except OSError:
|
||||
return
|
||||
|
||||
return ("mp3", data)
|
|
@ -0,0 +1,117 @@
|
|||
"""
|
||||
Support for the google speech service.
|
||||
|
||||
For more details about this component, please refer to the documentation at
|
||||
https://home-assistant.io/components/tts/google/
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
|
||||
import aiohttp
|
||||
import async_timeout
|
||||
import yarl
|
||||
|
||||
from homeassistant.components.tts import Provider
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
|
||||
REQUIREMENTS = ["gTTS-token==1.1.1"]
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
GOOGLE_SPEECH_URL = "http://translate.google.com/translate_tts"
|
||||
MESSAGE_SIZE = 148
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_get_engine(hass, config):
|
||||
"""Setup Google speech component."""
|
||||
return GoogleProvider(hass)
|
||||
|
||||
|
||||
class GoogleProvider(Provider):
|
||||
"""Google speech api provider."""
|
||||
|
||||
def __init__(self, hass):
|
||||
"""Init Google TTS service."""
|
||||
self.hass = hass
|
||||
self.headers = {
|
||||
'Referer': "http://translate.google.com/",
|
||||
'User-Agent': ("Mozilla/5.0 (Windows NT 10.0; WOW64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/47.0.2526.106 Safari/537.36")
|
||||
}
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_get_tts_audio(self, message):
|
||||
"""Load TTS from google."""
|
||||
from gtts_token import gtts_token
|
||||
|
||||
token = gtts_token.Token()
|
||||
websession = async_get_clientsession(self.hass)
|
||||
message_parts = self._split_message_to_parts(message)
|
||||
|
||||
data = b''
|
||||
for idx, part in enumerate(message_parts):
|
||||
part_token = yield from self.hass.loop.run_in_executor(
|
||||
None, token.calculate_token, part)
|
||||
|
||||
url_param = {
|
||||
'ie': 'UTF-8',
|
||||
'tl': self.language,
|
||||
'q': yarl.quote(part),
|
||||
'tk': part_token,
|
||||
'total': len(message_parts),
|
||||
'idx': idx,
|
||||
'client': 'tw-ob',
|
||||
'textlen': len(part),
|
||||
}
|
||||
|
||||
request = None
|
||||
try:
|
||||
with async_timeout.timeout(10, loop=self.hass.loop):
|
||||
request = yield from websession.get(
|
||||
GOOGLE_SPEECH_URL, params=url_param,
|
||||
headers=self.headers
|
||||
)
|
||||
|
||||
if request.status != 200:
|
||||
_LOGGER.error("Error %d on load url %s", request.code,
|
||||
request.url)
|
||||
return (None, None)
|
||||
data += yield from request.read()
|
||||
|
||||
except (asyncio.TimeoutError, aiohttp.errors.ClientError):
|
||||
_LOGGER.error("Timeout for google speech.")
|
||||
return (None, None)
|
||||
|
||||
finally:
|
||||
if request is not None:
|
||||
yield from request.release()
|
||||
|
||||
return ("mp3", data)
|
||||
|
||||
@staticmethod
|
||||
def _split_message_to_parts(message):
|
||||
"""Split message into single parts."""
|
||||
if len(message) <= MESSAGE_SIZE:
|
||||
return [message]
|
||||
|
||||
punc = "!()[]?.,;:"
|
||||
punc_list = [re.escape(c) for c in punc]
|
||||
pattern = '|'.join(punc_list)
|
||||
parts = re.split(pattern, message)
|
||||
|
||||
def split_by_space(fullstring):
|
||||
"""Split a string by space."""
|
||||
if len(fullstring) > MESSAGE_SIZE:
|
||||
idx = fullstring.rfind(' ', 0, MESSAGE_SIZE)
|
||||
return [fullstring[:idx]] + split_by_space(fullstring[idx:])
|
||||
else:
|
||||
return [fullstring]
|
||||
|
||||
msg_parts = []
|
||||
for part in parts:
|
||||
msg_parts += split_by_space(part)
|
||||
|
||||
return [msg for msg in msg_parts if len(msg) > 0]
|
|
@ -0,0 +1,14 @@
|
|||
say:
|
||||
description: Say some things on a media player.
|
||||
|
||||
fields:
|
||||
entity_id:
|
||||
description: Name(s) of media player entities
|
||||
example: 'media_player.floor'
|
||||
|
||||
message:
|
||||
description: Text to speak on devices
|
||||
example: 'My name is hanna'
|
||||
|
||||
clear_cache:
|
||||
description: Remove cache files and RAM cache.
|
|
@ -130,6 +130,9 @@ freesms==0.1.1
|
|||
# homeassistant.components.conversation
|
||||
fuzzywuzzy==0.14.0
|
||||
|
||||
# homeassistant.components.tts.google
|
||||
gTTS-token==1.1.1
|
||||
|
||||
# homeassistant.components.device_tracker.bluetooth_le_tracker
|
||||
# gattlib==0.20150805
|
||||
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
"""The tests for tts platforms."""
|
|
@ -0,0 +1,199 @@
|
|||
"""The tests for the Google speech platform."""
|
||||
import asyncio
|
||||
import os
|
||||
import shutil
|
||||
from unittest.mock import patch
|
||||
|
||||
import homeassistant.components.tts as tts
|
||||
from homeassistant.components.media_player import (
|
||||
SERVICE_PLAY_MEDIA, ATTR_MEDIA_CONTENT_ID, DOMAIN as DOMAIN_MP)
|
||||
from homeassistant.bootstrap import setup_component
|
||||
|
||||
from tests.common import (
|
||||
get_test_home_assistant, assert_setup_component, mock_service)
|
||||
|
||||
|
||||
class TestTTSGooglePlatform(object):
|
||||
"""Test the Google speech component."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Setup things to be run when tests are started."""
|
||||
self.hass = get_test_home_assistant()
|
||||
|
||||
self.url = "http://translate.google.com/translate_tts"
|
||||
self.url_param = {
|
||||
'tl': 'en',
|
||||
'q': 'I%20person%20is%20on%20front%20of%20your%20door.',
|
||||
'tk': 5,
|
||||
'client': 'tw-ob',
|
||||
'textlen': 34,
|
||||
'total': 1,
|
||||
'idx': 0,
|
||||
'ie': 'UTF-8',
|
||||
}
|
||||
|
||||
def teardown_method(self):
|
||||
"""Stop everything that was started."""
|
||||
default_tts = self.hass.config.path(tts.DEFAULT_CACHE_DIR)
|
||||
if os.path.isdir(default_tts):
|
||||
shutil.rmtree(default_tts)
|
||||
|
||||
self.hass.stop()
|
||||
|
||||
def test_setup_component(self):
|
||||
"""Test setup component."""
|
||||
config = {
|
||||
tts.DOMAIN: {
|
||||
'platform': 'google',
|
||||
}
|
||||
}
|
||||
|
||||
with assert_setup_component(1, tts.DOMAIN):
|
||||
setup_component(self.hass, tts.DOMAIN, config)
|
||||
|
||||
@patch('gtts_token.gtts_token.Token.calculate_token', autospec=True,
|
||||
return_value=5)
|
||||
def test_service_say(self, mock_calculate, aioclient_mock):
|
||||
"""Test service call say."""
|
||||
calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
aioclient_mock.get(
|
||||
self.url, params=self.url_param, status=200, content=b'test')
|
||||
|
||||
config = {
|
||||
tts.DOMAIN: {
|
||||
'platform': 'google',
|
||||
}
|
||||
}
|
||||
|
||||
with assert_setup_component(1, tts.DOMAIN):
|
||||
setup_component(self.hass, tts.DOMAIN, config)
|
||||
|
||||
self.hass.services.call(tts.DOMAIN, 'google_say', {
|
||||
tts.ATTR_MESSAGE: "I person is on front of your door.",
|
||||
})
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
assert len(aioclient_mock.mock_calls) == 1
|
||||
assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find(".mp3") != -1
|
||||
|
||||
@patch('gtts_token.gtts_token.Token.calculate_token', autospec=True,
|
||||
return_value=5)
|
||||
def test_service_say_german(self, mock_calculate, aioclient_mock):
|
||||
"""Test service call say with german code."""
|
||||
calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
self.url_param['tl'] = 'de'
|
||||
aioclient_mock.get(
|
||||
self.url, params=self.url_param, status=200, content=b'test')
|
||||
|
||||
config = {
|
||||
tts.DOMAIN: {
|
||||
'platform': 'google',
|
||||
'language': 'de',
|
||||
}
|
||||
}
|
||||
|
||||
with assert_setup_component(1, tts.DOMAIN):
|
||||
setup_component(self.hass, tts.DOMAIN, config)
|
||||
|
||||
self.hass.services.call(tts.DOMAIN, 'google_say', {
|
||||
tts.ATTR_MESSAGE: "I person is on front of your door.",
|
||||
})
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
assert len(aioclient_mock.mock_calls) == 1
|
||||
|
||||
@patch('gtts_token.gtts_token.Token.calculate_token', autospec=True,
|
||||
return_value=5)
|
||||
def test_service_say_error(self, mock_calculate, aioclient_mock):
|
||||
"""Test service call say with http response 400."""
|
||||
calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
aioclient_mock.get(
|
||||
self.url, params=self.url_param, status=400, content=b'test')
|
||||
|
||||
config = {
|
||||
tts.DOMAIN: {
|
||||
'platform': 'google',
|
||||
}
|
||||
}
|
||||
|
||||
with assert_setup_component(1, tts.DOMAIN):
|
||||
setup_component(self.hass, tts.DOMAIN, config)
|
||||
|
||||
self.hass.services.call(tts.DOMAIN, 'google_say', {
|
||||
tts.ATTR_MESSAGE: "I person is on front of your door.",
|
||||
})
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(calls) == 0
|
||||
assert len(aioclient_mock.mock_calls) == 1
|
||||
|
||||
@patch('gtts_token.gtts_token.Token.calculate_token', autospec=True,
|
||||
return_value=5)
|
||||
def test_service_say_timeout(self, mock_calculate, aioclient_mock):
|
||||
"""Test service call say with http timeout."""
|
||||
calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
aioclient_mock.get(
|
||||
self.url, params=self.url_param, exc=asyncio.TimeoutError())
|
||||
|
||||
config = {
|
||||
tts.DOMAIN: {
|
||||
'platform': 'google',
|
||||
}
|
||||
}
|
||||
|
||||
with assert_setup_component(1, tts.DOMAIN):
|
||||
setup_component(self.hass, tts.DOMAIN, config)
|
||||
|
||||
self.hass.services.call(tts.DOMAIN, 'google_say', {
|
||||
tts.ATTR_MESSAGE: "I person is on front of your door.",
|
||||
})
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(calls) == 0
|
||||
assert len(aioclient_mock.mock_calls) == 1
|
||||
|
||||
@patch('gtts_token.gtts_token.Token.calculate_token', autospec=True,
|
||||
return_value=5)
|
||||
def test_service_say_long_size(self, mock_calculate, aioclient_mock):
|
||||
"""Test service call say with a lot of text."""
|
||||
calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
self.url_param['total'] = 9
|
||||
self.url_param['q'] = "I%20person%20is%20on%20front%20of%20your%20door"
|
||||
self.url_param['textlen'] = 33
|
||||
for idx in range(0, 9):
|
||||
self.url_param['idx'] = idx
|
||||
aioclient_mock.get(
|
||||
self.url, params=self.url_param, status=200, content=b'test')
|
||||
|
||||
config = {
|
||||
tts.DOMAIN: {
|
||||
'platform': 'google',
|
||||
}
|
||||
}
|
||||
|
||||
with assert_setup_component(1, tts.DOMAIN):
|
||||
setup_component(self.hass, tts.DOMAIN, config)
|
||||
|
||||
self.hass.services.call(tts.DOMAIN, 'google_say', {
|
||||
tts.ATTR_MESSAGE: ("I person is on front of your door."
|
||||
"I person is on front of your door."
|
||||
"I person is on front of your door."
|
||||
"I person is on front of your door."
|
||||
"I person is on front of your door."
|
||||
"I person is on front of your door."
|
||||
"I person is on front of your door."
|
||||
"I person is on front of your door."
|
||||
"I person is on front of your door."),
|
||||
})
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
assert len(aioclient_mock.mock_calls) == 9
|
||||
assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find(".mp3") != -1
|
|
@ -0,0 +1,320 @@
|
|||
"""The tests for the TTS component."""
|
||||
import os
|
||||
import shutil
|
||||
from unittest.mock import patch
|
||||
|
||||
import requests
|
||||
|
||||
import homeassistant.components.tts as tts
|
||||
from homeassistant.components.tts.demo import DemoProvider
|
||||
from homeassistant.components.media_player import (
|
||||
SERVICE_PLAY_MEDIA, MEDIA_TYPE_MUSIC, ATTR_MEDIA_CONTENT_ID,
|
||||
ATTR_MEDIA_CONTENT_TYPE, DOMAIN as DOMAIN_MP)
|
||||
from homeassistant.bootstrap import setup_component
|
||||
|
||||
from tests.common import (
|
||||
get_test_home_assistant, assert_setup_component, mock_service)
|
||||
|
||||
|
||||
class TestTTS(object):
|
||||
"""Test the Google speech component."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Setup things to be run when tests are started."""
|
||||
self.hass = get_test_home_assistant()
|
||||
self.demo_provider = DemoProvider()
|
||||
self.default_tts_cache = self.hass.config.path(tts.DEFAULT_CACHE_DIR)
|
||||
|
||||
def teardown_method(self):
|
||||
"""Stop everything that was started."""
|
||||
if os.path.isdir(self.default_tts_cache):
|
||||
shutil.rmtree(self.default_tts_cache)
|
||||
|
||||
self.hass.stop()
|
||||
|
||||
def test_setup_component_demo(self):
|
||||
"""Setup the demo platform with defaults."""
|
||||
config = {
|
||||
tts.DOMAIN: {
|
||||
'platform': 'demo',
|
||||
}
|
||||
}
|
||||
|
||||
with assert_setup_component(1, tts.DOMAIN):
|
||||
setup_component(self.hass, tts.DOMAIN, config)
|
||||
|
||||
assert self.hass.services.has_service(tts.DOMAIN, 'demo_say')
|
||||
assert self.hass.services.has_service(tts.DOMAIN, 'clear_cache')
|
||||
|
||||
@patch('os.mkdir', side_effect=OSError(2, "No access"))
|
||||
def test_setup_component_demo_no_access_cache_folder(self, mock_mkdir):
|
||||
"""Setup the demo platform with defaults."""
|
||||
config = {
|
||||
tts.DOMAIN: {
|
||||
'platform': 'demo',
|
||||
}
|
||||
}
|
||||
|
||||
assert not setup_component(self.hass, tts.DOMAIN, config)
|
||||
|
||||
assert not self.hass.services.has_service(tts.DOMAIN, 'demo_say')
|
||||
assert not self.hass.services.has_service(tts.DOMAIN, 'clear_cache')
|
||||
|
||||
def test_setup_component_and_test_service(self):
|
||||
"""Setup the demo platform and call service."""
|
||||
calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
config = {
|
||||
tts.DOMAIN: {
|
||||
'platform': 'demo',
|
||||
}
|
||||
}
|
||||
|
||||
with assert_setup_component(1, tts.DOMAIN):
|
||||
setup_component(self.hass, tts.DOMAIN, config)
|
||||
|
||||
self.hass.services.call(tts.DOMAIN, 'demo_say', {
|
||||
tts.ATTR_MESSAGE: "I person is on front of your door.",
|
||||
})
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MEDIA_TYPE_MUSIC
|
||||
assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find(
|
||||
"/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd"
|
||||
"_demo.mp3") \
|
||||
!= -1
|
||||
assert os.path.isfile(os.path.join(
|
||||
self.default_tts_cache,
|
||||
"265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3"))
|
||||
|
||||
def test_setup_component_and_test_service_clear_cache(self):
|
||||
"""Setup the demo platform and call service clear cache."""
|
||||
calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
config = {
|
||||
tts.DOMAIN: {
|
||||
'platform': 'demo',
|
||||
}
|
||||
}
|
||||
|
||||
with assert_setup_component(1, tts.DOMAIN):
|
||||
setup_component(self.hass, tts.DOMAIN, config)
|
||||
|
||||
self.hass.services.call(tts.DOMAIN, 'demo_say', {
|
||||
tts.ATTR_MESSAGE: "I person is on front of your door.",
|
||||
})
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
assert os.path.isfile(os.path.join(
|
||||
self.default_tts_cache,
|
||||
"265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3"))
|
||||
|
||||
self.hass.services.call(tts.DOMAIN, tts.SERVICE_CLEAR_CACHE, {})
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert not os.path.isfile(os.path.join(
|
||||
self.default_tts_cache,
|
||||
"265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3"))
|
||||
|
||||
def test_setup_component_and_test_service_with_receive_voice(self):
|
||||
"""Setup the demo platform and call service and receive voice."""
|
||||
calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
config = {
|
||||
tts.DOMAIN: {
|
||||
'platform': 'demo',
|
||||
}
|
||||
}
|
||||
|
||||
with assert_setup_component(1, tts.DOMAIN):
|
||||
setup_component(self.hass, tts.DOMAIN, config)
|
||||
|
||||
self.hass.start()
|
||||
|
||||
self.hass.services.call(tts.DOMAIN, 'demo_say', {
|
||||
tts.ATTR_MESSAGE: "I person is on front of your door.",
|
||||
})
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
req = requests.get(calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
_, demo_data = self.demo_provider.get_tts_audio("bla")
|
||||
assert req.status_code == 200
|
||||
assert req.content == demo_data
|
||||
|
||||
def test_setup_component_and_web_view_wrong_file(self):
|
||||
"""Setup the demo platform and receive wrong file from web."""
|
||||
config = {
|
||||
tts.DOMAIN: {
|
||||
'platform': 'demo',
|
||||
}
|
||||
}
|
||||
|
||||
with assert_setup_component(1, tts.DOMAIN):
|
||||
setup_component(self.hass, tts.DOMAIN, config)
|
||||
|
||||
self.hass.start()
|
||||
|
||||
url = ("{}/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd"
|
||||
"_demo.mp3").format(self.hass.config.api.base_url)
|
||||
|
||||
req = requests.get(url)
|
||||
assert req.status_code == 404
|
||||
|
||||
def test_setup_component_and_web_view_wrong_filename(self):
|
||||
"""Setup the demo platform and receive wrong filename from web."""
|
||||
config = {
|
||||
tts.DOMAIN: {
|
||||
'platform': 'demo',
|
||||
}
|
||||
}
|
||||
|
||||
with assert_setup_component(1, tts.DOMAIN):
|
||||
setup_component(self.hass, tts.DOMAIN, config)
|
||||
|
||||
self.hass.start()
|
||||
|
||||
url = ("{}/api/tts_proxy/265944dsk32c1b2a621be5930510bb2cd"
|
||||
"_demo.mp3").format(self.hass.config.api.base_url)
|
||||
|
||||
req = requests.get(url)
|
||||
assert req.status_code == 404
|
||||
|
||||
def test_setup_component_test_without_cache(self):
|
||||
"""Setup demo platform without cache."""
|
||||
calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
config = {
|
||||
tts.DOMAIN: {
|
||||
'platform': 'demo',
|
||||
'cache': False,
|
||||
}
|
||||
}
|
||||
|
||||
with assert_setup_component(1, tts.DOMAIN):
|
||||
setup_component(self.hass, tts.DOMAIN, config)
|
||||
|
||||
self.hass.services.call(tts.DOMAIN, 'demo_say', {
|
||||
tts.ATTR_MESSAGE: "I person is on front of your door.",
|
||||
})
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
assert not os.path.isfile(os.path.join(
|
||||
self.default_tts_cache,
|
||||
"265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3"))
|
||||
|
||||
def test_setup_component_test_with_cache_call_service_without_cache(self):
|
||||
"""Setup demo platform with cache and call service without cache."""
|
||||
calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
config = {
|
||||
tts.DOMAIN: {
|
||||
'platform': 'demo',
|
||||
'cache': True,
|
||||
}
|
||||
}
|
||||
|
||||
with assert_setup_component(1, tts.DOMAIN):
|
||||
setup_component(self.hass, tts.DOMAIN, config)
|
||||
|
||||
self.hass.services.call(tts.DOMAIN, 'demo_say', {
|
||||
tts.ATTR_MESSAGE: "I person is on front of your door.",
|
||||
tts.ATTR_CACHE: False,
|
||||
})
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
assert not os.path.isfile(os.path.join(
|
||||
self.default_tts_cache,
|
||||
"265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3"))
|
||||
|
||||
def test_setup_component_test_with_cache_dir(self):
|
||||
"""Setup demo platform with cache and call service without cache."""
|
||||
calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
_, demo_data = self.demo_provider.get_tts_audio("bla")
|
||||
cache_file = os.path.join(
|
||||
self.default_tts_cache,
|
||||
"265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3")
|
||||
|
||||
os.mkdir(self.default_tts_cache)
|
||||
with open(cache_file, "wb") as voice_file:
|
||||
voice_file.write(demo_data)
|
||||
|
||||
config = {
|
||||
tts.DOMAIN: {
|
||||
'platform': 'demo',
|
||||
'cache': True,
|
||||
}
|
||||
}
|
||||
|
||||
with assert_setup_component(1, tts.DOMAIN):
|
||||
setup_component(self.hass, tts.DOMAIN, config)
|
||||
|
||||
with patch('homeassistant.components.tts.demo.DemoProvider.'
|
||||
'get_tts_audio', return_value=None):
|
||||
self.hass.services.call(tts.DOMAIN, 'demo_say', {
|
||||
tts.ATTR_MESSAGE: "I person is on front of your door.",
|
||||
})
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find(
|
||||
"/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd"
|
||||
"_demo.mp3") \
|
||||
!= -1
|
||||
|
||||
@patch('homeassistant.components.tts.demo.DemoProvider.get_tts_audio',
|
||||
return_value=None)
|
||||
def test_setup_component_test_with_error_on_get_tts(self, tts_mock):
|
||||
"""Setup demo platform with wrong get_tts_audio."""
|
||||
calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
config = {
|
||||
tts.DOMAIN: {
|
||||
'platform': 'demo'
|
||||
}
|
||||
}
|
||||
|
||||
with assert_setup_component(1, tts.DOMAIN):
|
||||
setup_component(self.hass, tts.DOMAIN, config)
|
||||
|
||||
self.hass.services.call(tts.DOMAIN, 'demo_say', {
|
||||
tts.ATTR_MESSAGE: "I person is on front of your door.",
|
||||
})
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(calls) == 0
|
||||
|
||||
def test_setup_component_load_cache_retrieve_without_mem_cache(self):
|
||||
"""Setup component and load cache and get without mem cache."""
|
||||
_, demo_data = self.demo_provider.get_tts_audio("bla")
|
||||
cache_file = os.path.join(
|
||||
self.default_tts_cache,
|
||||
"265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3")
|
||||
|
||||
os.mkdir(self.default_tts_cache)
|
||||
with open(cache_file, "wb") as voice_file:
|
||||
voice_file.write(demo_data)
|
||||
|
||||
config = {
|
||||
tts.DOMAIN: {
|
||||
'platform': 'demo',
|
||||
'cache': True,
|
||||
}
|
||||
}
|
||||
|
||||
with assert_setup_component(1, tts.DOMAIN):
|
||||
setup_component(self.hass, tts.DOMAIN, config)
|
||||
|
||||
self.hass.start()
|
||||
|
||||
url = ("{}/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd"
|
||||
"_demo.mp3").format(self.hass.config.api.base_url)
|
||||
|
||||
req = requests.get(url)
|
||||
assert req.status_code == 200
|
||||
assert req.content == demo_data
|
|
@ -23,6 +23,7 @@ class AiohttpClientMocker:
|
|||
content=None,
|
||||
json=None,
|
||||
params=None,
|
||||
headers=None,
|
||||
exc=None):
|
||||
"""Mock a request."""
|
||||
if json:
|
||||
|
@ -65,8 +66,8 @@ class AiohttpClientMocker:
|
|||
return len(self.mock_calls)
|
||||
|
||||
@asyncio.coroutine
|
||||
def match_request(self, method, url, *, auth=None, params=None): \
|
||||
# pylint: disable=unused-variable
|
||||
def match_request(self, method, url, *, auth=None, params=None,
|
||||
headers=None): # pylint: disable=unused-variable
|
||||
"""Match a request against pre-registered requests."""
|
||||
for response in self._mocks:
|
||||
if response.match_request(method, url, params):
|
||||
|
@ -76,8 +77,8 @@ class AiohttpClientMocker:
|
|||
raise self.exc
|
||||
return response
|
||||
|
||||
assert False, "No mock registered for {} {}".format(method.upper(),
|
||||
url)
|
||||
assert False, "No mock registered for {} {} {}".format(method.upper(),
|
||||
url, params)
|
||||
|
||||
|
||||
class AiohttpClientMockResponse:
|
||||
|
|
Loading…
Reference in New Issue