diff --git a/homeassistant/components/media_player/__init__.py b/homeassistant/components/media_player/__init__.py index fa2ecee4337..3dea75df874 100644 --- a/homeassistant/components/media_player/__init__.py +++ b/homeassistant/components/media_player/__init__.py @@ -162,7 +162,7 @@ MEDIA_PLAYER_MEDIA_SEEK_SCHEMA = MEDIA_PLAYER_SCHEMA.extend({ MEDIA_PLAYER_PLAY_MEDIA_SCHEMA = MEDIA_PLAYER_SCHEMA.extend({ vol.Required(ATTR_MEDIA_CONTENT_TYPE): cv.string, vol.Required(ATTR_MEDIA_CONTENT_ID): cv.string, - ATTR_MEDIA_ENQUEUE: cv.boolean, + vol.Optional(ATTR_MEDIA_ENQUEUE): cv.boolean, }) MEDIA_PLAYER_SELECT_SOURCE_SCHEMA = MEDIA_PLAYER_SCHEMA.extend({ diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py new file mode 100644 index 00000000000..0e75de88cc5 --- /dev/null +++ b/homeassistant/components/tts/__init__.py @@ -0,0 +1,421 @@ +""" +Provide functionality to TTS. + +For more details about this component, please refer to the documentation at +https://home-assistant.io/components/tts/ +""" +import asyncio +import logging +import hashlib +import mimetypes +import os +import re + +from aiohttp import web +import voluptuous as vol + +from homeassistant.const import ATTR_ENTITY_ID +from homeassistant.bootstrap import async_prepare_setup_platform +from homeassistant.core import callback +from homeassistant.config import load_yaml_config_file +from homeassistant.components.http import HomeAssistantView +from homeassistant.components.media_player import ( + SERVICE_PLAY_MEDIA, MEDIA_TYPE_MUSIC, ATTR_MEDIA_CONTENT_ID, + ATTR_MEDIA_CONTENT_TYPE, DOMAIN as DOMAIN_MP) +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import config_per_platform +import homeassistant.helpers.config_validation as cv + +DOMAIN = 'tts' +DEPENDENCIES = ['http'] + +_LOGGER = logging.getLogger(__name__) + +MEM_CACHE_FILENAME = 'filename' +MEM_CACHE_VOICE = 'voice' + +CONF_LANG = 'language' +CONF_CACHE = 'cache' +CONF_CACHE_DIR = 'cache_dir' +CONF_TIME_MEMORY = 'time_memory' + +DEFAULT_CACHE = True +DEFAULT_CACHE_DIR = "tts" +DEFAULT_LANG = 'en' +DEFAULT_TIME_MEMORY = 300 + +SERVICE_SAY = 'say' +SERVICE_CLEAR_CACHE = 'clear_cache' + +ATTR_MESSAGE = 'message' +ATTR_CACHE = 'cache' + +_RE_VOICE_FILE = re.compile(r"([a-f0-9]{40})_([a-z]+)\.[a-z0-9]{3,4}") + +PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend({ + vol.Optional(CONF_LANG, default=DEFAULT_LANG): cv.string, + vol.Optional(CONF_CACHE, default=DEFAULT_CACHE): cv.boolean, + vol.Optional(CONF_CACHE_DIR, default=DEFAULT_CACHE_DIR): cv.string, + vol.Optional(CONF_TIME_MEMORY, default=DEFAULT_TIME_MEMORY): + vol.All(vol.Coerce(int), vol.Range(min=60, max=57600)), +}) + + +SCHEMA_SERVICE_SAY = vol.Schema({ + vol.Required(ATTR_MESSAGE): cv.string, + vol.Optional(ATTR_ENTITY_ID): cv.entity_ids, + vol.Optional(ATTR_CACHE): cv.boolean, +}) + +SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({}) + + +@asyncio.coroutine +def async_setup(hass, config): + """Setup TTS.""" + tts = SpeechManager(hass) + + try: + conf = config[DOMAIN][0] if len(config.get(DOMAIN, [])) > 0 else {} + use_cache = conf.get(CONF_CACHE, DEFAULT_CACHE) + cache_dir = conf.get(CONF_CACHE_DIR, DEFAULT_CACHE_DIR) + time_memory = conf.get(CONF_TIME_MEMORY, DEFAULT_LANG) + + yield from tts.async_init_cache(use_cache, cache_dir, time_memory) + except (HomeAssistantError, KeyError) as err: + _LOGGER.error("Error on cache init %s", err) + return False + + hass.http.register_view(TextToSpeechView(tts)) + + descriptions = yield from hass.loop.run_in_executor( + None, load_yaml_config_file, + os.path.join(os.path.dirname(__file__), 'services.yaml')) + + @asyncio.coroutine + def async_setup_platform(p_type, p_config, disc_info=None): + """Setup a tts platform.""" + platform = yield from async_prepare_setup_platform( + hass, config, DOMAIN, p_type) + if platform is None: + return + + try: + if hasattr(platform, 'async_get_engine'): + provider = yield from platform.async_get_engine( + hass, p_config) + else: + provider = yield from hass.loop.run_in_executor( + None, platform.get_engine, hass, p_config) + + if provider is None: + _LOGGER.error('Error setting up platform %s', p_type) + return + + tts.async_register_engine(p_type, provider, p_config) + except Exception: # pylint: disable=broad-except + _LOGGER.exception('Error setting up platform %s', p_type) + return + + @asyncio.coroutine + def async_say_handle(service): + """Service handle for say.""" + entity_ids = service.data.get(ATTR_ENTITY_ID) + message = service.data.get(ATTR_MESSAGE) + cache = service.data.get(ATTR_CACHE) + + try: + url = yield from tts.async_get_url( + p_type, message, cache=cache) + except HomeAssistantError as err: + _LOGGER.error("Error on init tts: %s", err) + return + + data = { + ATTR_MEDIA_CONTENT_ID: url, + ATTR_MEDIA_CONTENT_TYPE: MEDIA_TYPE_MUSIC, + } + + if entity_ids: + data[ATTR_ENTITY_ID] = entity_ids + + yield from hass.services.async_call( + DOMAIN_MP, SERVICE_PLAY_MEDIA, data, blocking=True) + + hass.services.async_register( + DOMAIN, "{}_{}".format(p_type, SERVICE_SAY), async_say_handle, + descriptions.get(SERVICE_SAY), schema=SCHEMA_SERVICE_SAY) + + setup_tasks = [async_setup_platform(p_type, p_config) for p_type, p_config + in config_per_platform(config, DOMAIN)] + + if setup_tasks: + yield from asyncio.wait(setup_tasks, loop=hass.loop) + + @asyncio.coroutine + def async_clear_cache_handle(service): + """Handle clear cache service call.""" + yield from tts.async_clear_cache() + + hass.services.async_register( + DOMAIN, SERVICE_CLEAR_CACHE, async_clear_cache_handle, + descriptions.get(SERVICE_CLEAR_CACHE), schema=SERVICE_CLEAR_CACHE) + + return True + + +class SpeechManager(object): + """Representation of a speech store.""" + + def __init__(self, hass): + """Initialize a speech store.""" + self.hass = hass + self.providers = {} + + self.use_cache = True + self.cache_dir = None + self.time_memory = None + self.file_cache = {} + self.mem_cache = {} + + @asyncio.coroutine + def async_init_cache(self, use_cache, cache_dir, time_memory): + """Init config folder and load file cache.""" + self.use_cache = use_cache + self.time_memory = time_memory + + def init_tts_cache_dir(cache_dir): + """Init cache folder.""" + if not os.path.isabs(cache_dir): + cache_dir = self.hass.config.path(cache_dir) + if not os.path.isdir(cache_dir): + _LOGGER.info("Create cache dir %s.", cache_dir) + os.mkdir(cache_dir) + return cache_dir + + try: + self.cache_dir = yield from self.hass.loop.run_in_executor( + None, init_tts_cache_dir, cache_dir) + except OSError as err: + raise HomeAssistantError( + "Can't init cache dir {}".format(err)) + + def get_cache_files(): + """Return a dict of given engine files.""" + cache = {} + + folder_data = os.listdir(self.cache_dir) + for file_data in folder_data: + record = _RE_VOICE_FILE.match(file_data) + if record: + key = "{}_{}".format(record.group(1), record.group(2)) + cache[key.lower()] = file_data.lower() + return cache + + try: + cache_files = yield from self.hass.loop.run_in_executor( + None, get_cache_files) + except OSError as err: + raise HomeAssistantError( + "Can't read cache dir {}".format(err)) + + if cache_files: + self.file_cache.update(cache_files) + + @asyncio.coroutine + def async_clear_cache(self): + """Read file cache and delete files.""" + self.mem_cache = {} + + def remove_files(): + """Remove files from filesystem.""" + for _, filename in self.file_cache.items(): + try: + os.remove(os.path.join(self.cache_dir), filename) + except OSError: + pass + + yield from self.hass.loop.run_in_executor(None, remove_files) + self.file_cache = {} + + @callback + def async_register_engine(self, engine, provider, config): + """Register a TTS provider.""" + provider.hass = self.hass + provider.language = config.get(CONF_LANG) + self.providers[engine] = provider + + @asyncio.coroutine + def async_get_url(self, engine, message, cache=None): + """Get URL for play message. + + This method is a coroutine. + """ + msg_hash = hashlib.sha1(bytes(message, 'utf-8')).hexdigest() + key = ("{}_{}".format(msg_hash, engine)).lower() + use_cache = cache if cache is not None else self.use_cache + + # is speech allready in memory + if key in self.mem_cache: + filename = self.mem_cache[key][MEM_CACHE_FILENAME] + # is file store in file cache + elif use_cache and key in self.file_cache: + filename = self.file_cache[key] + self.hass.async_add_job(self.async_file_to_mem(engine, key)) + # load speech from provider into memory + else: + filename = yield from self.async_get_tts_audio( + engine, key, message, use_cache) + + return "{}/api/tts_proxy/{}".format( + self.hass.config.api.base_url, filename) + + @asyncio.coroutine + def async_get_tts_audio(self, engine, key, message, cache): + """Receive TTS and store for view in cache. + + This method is a coroutine. + """ + provider = self.providers[engine] + extension, data = yield from provider.async_get_tts_audio(message) + + if data is None or extension is None: + raise HomeAssistantError( + "No TTS from {} for '{}'".format(engine, message)) + + # create file infos + filename = ("{}.{}".format(key, extension)).lower() + + # save to memory + self._async_store_to_memcache(key, filename, data) + + if cache: + self.hass.async_add_job( + self.async_save_tts_audio(key, filename, data)) + + return filename + + @asyncio.coroutine + def async_save_tts_audio(self, key, filename, data): + """Store voice data to file and file_cache. + + This method is a coroutine. + """ + voice_file = os.path.join(self.cache_dir, filename) + + def save_speech(): + """Store speech to filesystem.""" + with open(voice_file, 'wb') as speech: + speech.write(data) + + try: + yield from self.hass.loop.run_in_executor(None, save_speech) + self.file_cache[key] = filename + except OSError: + _LOGGER.error("Can't write %s", filename) + + @asyncio.coroutine + def async_file_to_mem(self, engine, key): + """Load voice from file cache into memory. + + This method is a coroutine. + """ + filename = self.file_cache.get(key) + if not filename: + raise HomeAssistantError("Key {} not in file cache!".format(key)) + + voice_file = os.path.join(self.cache_dir, filename) + + def load_speech(): + """Load a speech from filesystem.""" + with open(voice_file, 'rb') as speech: + return speech.read() + + try: + data = yield from self.hass.loop.run_in_executor(None, load_speech) + except OSError: + raise HomeAssistantError("Can't read {}".format(voice_file)) + + self._async_store_to_memcache(key, filename, data) + + @callback + def _async_store_to_memcache(self, key, filename, data): + """Store data to memcache and set timer to remove it.""" + self.mem_cache[key] = { + MEM_CACHE_FILENAME: filename, + MEM_CACHE_VOICE: data, + } + + @callback + def async_remove_from_mem(): + """Cleanup memcache.""" + self.mem_cache.pop(key) + + self.hass.loop.call_later(self.time_memory, async_remove_from_mem) + + @asyncio.coroutine + def async_read_tts(self, filename): + """Read a voice file and return binary. + + This method is a coroutine. + """ + record = _RE_VOICE_FILE.match(filename.lower()) + if not record: + raise HomeAssistantError("Wrong tts file format!") + + key = "{}_{}".format(record.group(1), record.group(2)) + + if key not in self.mem_cache: + if key not in self.file_cache: + raise HomeAssistantError("%s not in cache!", key) + engine = record.group(2) + yield from self.async_file_to_mem(engine, key) + + content, _ = mimetypes.guess_type(filename) + return (content, self.mem_cache[key][MEM_CACHE_VOICE]) + + +class Provider(object): + """Represent a single provider.""" + + hass = None + language = DEFAULT_LANG + + def get_tts_audio(self, message): + """Load tts audio file from provider.""" + raise NotImplementedError() + + @asyncio.coroutine + def async_get_tts_audio(self, message): + """Load tts audio file from provider. + + Return a tuple of file extension and data as bytes. + + This method is a coroutine. + """ + extension, data = yield from self.hass.loop.run_in_executor( + None, self.get_tts_audio, message) + return (extension, data) + + +class TextToSpeechView(HomeAssistantView): + """TTS view to serve an speech audio.""" + + requires_auth = False + url = "/api/tts_proxy/{filename}" + name = "api:tts:speech" + + def __init__(self, tts): + """Initialize a tts view.""" + self.tts = tts + + @asyncio.coroutine + def get(self, request, filename): + """Start a get request.""" + try: + content, data = yield from self.tts.async_read_tts(filename) + except HomeAssistantError as err: + _LOGGER.error("Error on load tts: %s", err) + return web.Response(status=404) + + return web.Response(body=data, content_type=content) diff --git a/homeassistant/components/tts/demo.mp3 b/homeassistant/components/tts/demo.mp3 new file mode 100644 index 00000000000..f34241c7698 Binary files /dev/null and b/homeassistant/components/tts/demo.mp3 differ diff --git a/homeassistant/components/tts/demo.py b/homeassistant/components/tts/demo.py new file mode 100644 index 00000000000..a63bd6373ea --- /dev/null +++ b/homeassistant/components/tts/demo.py @@ -0,0 +1,29 @@ +""" +Support for the demo speech service. + +For more details about this component, please refer to the documentation at +https://home-assistant.io/components/demo/ +""" +import os + +from homeassistant.components.tts import Provider + + +def get_engine(hass, config): + """Setup Demo speech component.""" + return DemoProvider() + + +class DemoProvider(Provider): + """Demo speech api provider.""" + + def get_tts_audio(self, message): + """Load TTS from demo.""" + filename = os.path.join(os.path.dirname(__file__), "demo.mp3") + try: + with open(filename, 'rb') as voice: + data = voice.read() + except OSError: + return + + return ("mp3", data) diff --git a/homeassistant/components/tts/google.py b/homeassistant/components/tts/google.py new file mode 100644 index 00000000000..b271b2468d1 --- /dev/null +++ b/homeassistant/components/tts/google.py @@ -0,0 +1,117 @@ +""" +Support for the google speech service. + +For more details about this component, please refer to the documentation at +https://home-assistant.io/components/tts/google/ +""" +import asyncio +import logging +import re + +import aiohttp +import async_timeout +import yarl + +from homeassistant.components.tts import Provider +from homeassistant.helpers.aiohttp_client import async_get_clientsession + +REQUIREMENTS = ["gTTS-token==1.1.1"] + +_LOGGER = logging.getLogger(__name__) + +GOOGLE_SPEECH_URL = "http://translate.google.com/translate_tts" +MESSAGE_SIZE = 148 + + +@asyncio.coroutine +def async_get_engine(hass, config): + """Setup Google speech component.""" + return GoogleProvider(hass) + + +class GoogleProvider(Provider): + """Google speech api provider.""" + + def __init__(self, hass): + """Init Google TTS service.""" + self.hass = hass + self.headers = { + 'Referer': "http://translate.google.com/", + 'User-Agent': ("Mozilla/5.0 (Windows NT 10.0; WOW64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/47.0.2526.106 Safari/537.36") + } + + @asyncio.coroutine + def async_get_tts_audio(self, message): + """Load TTS from google.""" + from gtts_token import gtts_token + + token = gtts_token.Token() + websession = async_get_clientsession(self.hass) + message_parts = self._split_message_to_parts(message) + + data = b'' + for idx, part in enumerate(message_parts): + part_token = yield from self.hass.loop.run_in_executor( + None, token.calculate_token, part) + + url_param = { + 'ie': 'UTF-8', + 'tl': self.language, + 'q': yarl.quote(part), + 'tk': part_token, + 'total': len(message_parts), + 'idx': idx, + 'client': 'tw-ob', + 'textlen': len(part), + } + + request = None + try: + with async_timeout.timeout(10, loop=self.hass.loop): + request = yield from websession.get( + GOOGLE_SPEECH_URL, params=url_param, + headers=self.headers + ) + + if request.status != 200: + _LOGGER.error("Error %d on load url %s", request.code, + request.url) + return (None, None) + data += yield from request.read() + + except (asyncio.TimeoutError, aiohttp.errors.ClientError): + _LOGGER.error("Timeout for google speech.") + return (None, None) + + finally: + if request is not None: + yield from request.release() + + return ("mp3", data) + + @staticmethod + def _split_message_to_parts(message): + """Split message into single parts.""" + if len(message) <= MESSAGE_SIZE: + return [message] + + punc = "!()[]?.,;:" + punc_list = [re.escape(c) for c in punc] + pattern = '|'.join(punc_list) + parts = re.split(pattern, message) + + def split_by_space(fullstring): + """Split a string by space.""" + if len(fullstring) > MESSAGE_SIZE: + idx = fullstring.rfind(' ', 0, MESSAGE_SIZE) + return [fullstring[:idx]] + split_by_space(fullstring[idx:]) + else: + return [fullstring] + + msg_parts = [] + for part in parts: + msg_parts += split_by_space(part) + + return [msg for msg in msg_parts if len(msg) > 0] diff --git a/homeassistant/components/tts/services.yaml b/homeassistant/components/tts/services.yaml new file mode 100644 index 00000000000..aba1334da87 --- /dev/null +++ b/homeassistant/components/tts/services.yaml @@ -0,0 +1,14 @@ +say: + description: Say some things on a media player. + + fields: + entity_id: + description: Name(s) of media player entities + example: 'media_player.floor' + + message: + description: Text to speak on devices + example: 'My name is hanna' + +clear_cache: + description: Remove cache files and RAM cache. diff --git a/requirements_all.txt b/requirements_all.txt index 799b2b29b11..c85e3285e7e 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -130,6 +130,9 @@ freesms==0.1.1 # homeassistant.components.conversation fuzzywuzzy==0.14.0 +# homeassistant.components.tts.google +gTTS-token==1.1.1 + # homeassistant.components.device_tracker.bluetooth_le_tracker # gattlib==0.20150805 diff --git a/tests/components/tts/__init__.py b/tests/components/tts/__init__.py new file mode 100644 index 00000000000..f5eb0731409 --- /dev/null +++ b/tests/components/tts/__init__.py @@ -0,0 +1 @@ +"""The tests for tts platforms.""" diff --git a/tests/components/tts/test_google.py b/tests/components/tts/test_google.py new file mode 100644 index 00000000000..623a96f1dfb --- /dev/null +++ b/tests/components/tts/test_google.py @@ -0,0 +1,199 @@ +"""The tests for the Google speech platform.""" +import asyncio +import os +import shutil +from unittest.mock import patch + +import homeassistant.components.tts as tts +from homeassistant.components.media_player import ( + SERVICE_PLAY_MEDIA, ATTR_MEDIA_CONTENT_ID, DOMAIN as DOMAIN_MP) +from homeassistant.bootstrap import setup_component + +from tests.common import ( + get_test_home_assistant, assert_setup_component, mock_service) + + +class TestTTSGooglePlatform(object): + """Test the Google speech component.""" + + def setup_method(self): + """Setup things to be run when tests are started.""" + self.hass = get_test_home_assistant() + + self.url = "http://translate.google.com/translate_tts" + self.url_param = { + 'tl': 'en', + 'q': 'I%20person%20is%20on%20front%20of%20your%20door.', + 'tk': 5, + 'client': 'tw-ob', + 'textlen': 34, + 'total': 1, + 'idx': 0, + 'ie': 'UTF-8', + } + + def teardown_method(self): + """Stop everything that was started.""" + default_tts = self.hass.config.path(tts.DEFAULT_CACHE_DIR) + if os.path.isdir(default_tts): + shutil.rmtree(default_tts) + + self.hass.stop() + + def test_setup_component(self): + """Test setup component.""" + config = { + tts.DOMAIN: { + 'platform': 'google', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + @patch('gtts_token.gtts_token.Token.calculate_token', autospec=True, + return_value=5) + def test_service_say(self, mock_calculate, aioclient_mock): + """Test service call say.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + aioclient_mock.get( + self.url, params=self.url_param, status=200, content=b'test') + + config = { + tts.DOMAIN: { + 'platform': 'google', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'google_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + }) + self.hass.block_till_done() + + assert len(calls) == 1 + assert len(aioclient_mock.mock_calls) == 1 + assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find(".mp3") != -1 + + @patch('gtts_token.gtts_token.Token.calculate_token', autospec=True, + return_value=5) + def test_service_say_german(self, mock_calculate, aioclient_mock): + """Test service call say with german code.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + self.url_param['tl'] = 'de' + aioclient_mock.get( + self.url, params=self.url_param, status=200, content=b'test') + + config = { + tts.DOMAIN: { + 'platform': 'google', + 'language': 'de', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'google_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + }) + self.hass.block_till_done() + + assert len(calls) == 1 + assert len(aioclient_mock.mock_calls) == 1 + + @patch('gtts_token.gtts_token.Token.calculate_token', autospec=True, + return_value=5) + def test_service_say_error(self, mock_calculate, aioclient_mock): + """Test service call say with http response 400.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + aioclient_mock.get( + self.url, params=self.url_param, status=400, content=b'test') + + config = { + tts.DOMAIN: { + 'platform': 'google', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'google_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + }) + self.hass.block_till_done() + + assert len(calls) == 0 + assert len(aioclient_mock.mock_calls) == 1 + + @patch('gtts_token.gtts_token.Token.calculate_token', autospec=True, + return_value=5) + def test_service_say_timeout(self, mock_calculate, aioclient_mock): + """Test service call say with http timeout.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + aioclient_mock.get( + self.url, params=self.url_param, exc=asyncio.TimeoutError()) + + config = { + tts.DOMAIN: { + 'platform': 'google', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'google_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + }) + self.hass.block_till_done() + + assert len(calls) == 0 + assert len(aioclient_mock.mock_calls) == 1 + + @patch('gtts_token.gtts_token.Token.calculate_token', autospec=True, + return_value=5) + def test_service_say_long_size(self, mock_calculate, aioclient_mock): + """Test service call say with a lot of text.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + self.url_param['total'] = 9 + self.url_param['q'] = "I%20person%20is%20on%20front%20of%20your%20door" + self.url_param['textlen'] = 33 + for idx in range(0, 9): + self.url_param['idx'] = idx + aioclient_mock.get( + self.url, params=self.url_param, status=200, content=b'test') + + config = { + tts.DOMAIN: { + 'platform': 'google', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'google_say', { + tts.ATTR_MESSAGE: ("I person is on front of your door." + "I person is on front of your door." + "I person is on front of your door." + "I person is on front of your door." + "I person is on front of your door." + "I person is on front of your door." + "I person is on front of your door." + "I person is on front of your door." + "I person is on front of your door."), + }) + self.hass.block_till_done() + + assert len(calls) == 1 + assert len(aioclient_mock.mock_calls) == 9 + assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find(".mp3") != -1 diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py new file mode 100644 index 00000000000..fbdbddb8db5 --- /dev/null +++ b/tests/components/tts/test_init.py @@ -0,0 +1,320 @@ +"""The tests for the TTS component.""" +import os +import shutil +from unittest.mock import patch + +import requests + +import homeassistant.components.tts as tts +from homeassistant.components.tts.demo import DemoProvider +from homeassistant.components.media_player import ( + SERVICE_PLAY_MEDIA, MEDIA_TYPE_MUSIC, ATTR_MEDIA_CONTENT_ID, + ATTR_MEDIA_CONTENT_TYPE, DOMAIN as DOMAIN_MP) +from homeassistant.bootstrap import setup_component + +from tests.common import ( + get_test_home_assistant, assert_setup_component, mock_service) + + +class TestTTS(object): + """Test the Google speech component.""" + + def setup_method(self): + """Setup things to be run when tests are started.""" + self.hass = get_test_home_assistant() + self.demo_provider = DemoProvider() + self.default_tts_cache = self.hass.config.path(tts.DEFAULT_CACHE_DIR) + + def teardown_method(self): + """Stop everything that was started.""" + if os.path.isdir(self.default_tts_cache): + shutil.rmtree(self.default_tts_cache) + + self.hass.stop() + + def test_setup_component_demo(self): + """Setup the demo platform with defaults.""" + config = { + tts.DOMAIN: { + 'platform': 'demo', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + assert self.hass.services.has_service(tts.DOMAIN, 'demo_say') + assert self.hass.services.has_service(tts.DOMAIN, 'clear_cache') + + @patch('os.mkdir', side_effect=OSError(2, "No access")) + def test_setup_component_demo_no_access_cache_folder(self, mock_mkdir): + """Setup the demo platform with defaults.""" + config = { + tts.DOMAIN: { + 'platform': 'demo', + } + } + + assert not setup_component(self.hass, tts.DOMAIN, config) + + assert not self.hass.services.has_service(tts.DOMAIN, 'demo_say') + assert not self.hass.services.has_service(tts.DOMAIN, 'clear_cache') + + def test_setup_component_and_test_service(self): + """Setup the demo platform and call service.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + config = { + tts.DOMAIN: { + 'platform': 'demo', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'demo_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + }) + self.hass.block_till_done() + + assert len(calls) == 1 + assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MEDIA_TYPE_MUSIC + assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find( + "/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" + "_demo.mp3") \ + != -1 + assert os.path.isfile(os.path.join( + self.default_tts_cache, + "265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3")) + + def test_setup_component_and_test_service_clear_cache(self): + """Setup the demo platform and call service clear cache.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + config = { + tts.DOMAIN: { + 'platform': 'demo', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'demo_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + }) + self.hass.block_till_done() + + assert len(calls) == 1 + assert os.path.isfile(os.path.join( + self.default_tts_cache, + "265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3")) + + self.hass.services.call(tts.DOMAIN, tts.SERVICE_CLEAR_CACHE, {}) + self.hass.block_till_done() + + assert not os.path.isfile(os.path.join( + self.default_tts_cache, + "265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3")) + + def test_setup_component_and_test_service_with_receive_voice(self): + """Setup the demo platform and call service and receive voice.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + config = { + tts.DOMAIN: { + 'platform': 'demo', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.start() + + self.hass.services.call(tts.DOMAIN, 'demo_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + }) + self.hass.block_till_done() + + assert len(calls) == 1 + req = requests.get(calls[0].data[ATTR_MEDIA_CONTENT_ID]) + _, demo_data = self.demo_provider.get_tts_audio("bla") + assert req.status_code == 200 + assert req.content == demo_data + + def test_setup_component_and_web_view_wrong_file(self): + """Setup the demo platform and receive wrong file from web.""" + config = { + tts.DOMAIN: { + 'platform': 'demo', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.start() + + url = ("{}/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" + "_demo.mp3").format(self.hass.config.api.base_url) + + req = requests.get(url) + assert req.status_code == 404 + + def test_setup_component_and_web_view_wrong_filename(self): + """Setup the demo platform and receive wrong filename from web.""" + config = { + tts.DOMAIN: { + 'platform': 'demo', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.start() + + url = ("{}/api/tts_proxy/265944dsk32c1b2a621be5930510bb2cd" + "_demo.mp3").format(self.hass.config.api.base_url) + + req = requests.get(url) + assert req.status_code == 404 + + def test_setup_component_test_without_cache(self): + """Setup demo platform without cache.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + config = { + tts.DOMAIN: { + 'platform': 'demo', + 'cache': False, + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'demo_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + }) + self.hass.block_till_done() + + assert len(calls) == 1 + assert not os.path.isfile(os.path.join( + self.default_tts_cache, + "265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3")) + + def test_setup_component_test_with_cache_call_service_without_cache(self): + """Setup demo platform with cache and call service without cache.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + config = { + tts.DOMAIN: { + 'platform': 'demo', + 'cache': True, + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'demo_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + tts.ATTR_CACHE: False, + }) + self.hass.block_till_done() + + assert len(calls) == 1 + assert not os.path.isfile(os.path.join( + self.default_tts_cache, + "265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3")) + + def test_setup_component_test_with_cache_dir(self): + """Setup demo platform with cache and call service without cache.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + _, demo_data = self.demo_provider.get_tts_audio("bla") + cache_file = os.path.join( + self.default_tts_cache, + "265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3") + + os.mkdir(self.default_tts_cache) + with open(cache_file, "wb") as voice_file: + voice_file.write(demo_data) + + config = { + tts.DOMAIN: { + 'platform': 'demo', + 'cache': True, + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + with patch('homeassistant.components.tts.demo.DemoProvider.' + 'get_tts_audio', return_value=None): + self.hass.services.call(tts.DOMAIN, 'demo_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + }) + self.hass.block_till_done() + + assert len(calls) == 1 + assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find( + "/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" + "_demo.mp3") \ + != -1 + + @patch('homeassistant.components.tts.demo.DemoProvider.get_tts_audio', + return_value=None) + def test_setup_component_test_with_error_on_get_tts(self, tts_mock): + """Setup demo platform with wrong get_tts_audio.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + config = { + tts.DOMAIN: { + 'platform': 'demo' + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'demo_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + }) + self.hass.block_till_done() + + assert len(calls) == 0 + + def test_setup_component_load_cache_retrieve_without_mem_cache(self): + """Setup component and load cache and get without mem cache.""" + _, demo_data = self.demo_provider.get_tts_audio("bla") + cache_file = os.path.join( + self.default_tts_cache, + "265944c108cbb00b2a621be5930513e03a0bb2cd_demo.mp3") + + os.mkdir(self.default_tts_cache) + with open(cache_file, "wb") as voice_file: + voice_file.write(demo_data) + + config = { + tts.DOMAIN: { + 'platform': 'demo', + 'cache': True, + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.start() + + url = ("{}/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" + "_demo.mp3").format(self.hass.config.api.base_url) + + req = requests.get(url) + assert req.status_code == 200 + assert req.content == demo_data diff --git a/tests/test_util/aiohttp.py b/tests/test_util/aiohttp.py index d6f0c80b435..4abf43a6e42 100644 --- a/tests/test_util/aiohttp.py +++ b/tests/test_util/aiohttp.py @@ -23,6 +23,7 @@ class AiohttpClientMocker: content=None, json=None, params=None, + headers=None, exc=None): """Mock a request.""" if json: @@ -65,8 +66,8 @@ class AiohttpClientMocker: return len(self.mock_calls) @asyncio.coroutine - def match_request(self, method, url, *, auth=None, params=None): \ - # pylint: disable=unused-variable + def match_request(self, method, url, *, auth=None, params=None, + headers=None): # pylint: disable=unused-variable """Match a request against pre-registered requests.""" for response in self._mocks: if response.match_request(method, url, params): @@ -76,8 +77,8 @@ class AiohttpClientMocker: raise self.exc return response - assert False, "No mock registered for {} {}".format(method.upper(), - url) + assert False, "No mock registered for {} {} {}".format(method.upper(), + url, params) class AiohttpClientMockResponse: