Clean up Google Assistant (#11375)

* Clean up Google Assistant

* Fix tests
pull/11405/head
Paulus Schoutsen 2017-12-31 15:04:49 -08:00 committed by GitHub
parent fcbf7abdaa
commit fc8b25a71f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 165 additions and 183 deletions

View File

@ -15,7 +15,6 @@ import voluptuous as vol
# Typing imports
# pylint: disable=using-constant-test,unused-import,ungrouped-imports
# if False:
from homeassistant.core import HomeAssistant # NOQA
from typing import Dict, Any # NOQA
@ -26,12 +25,12 @@ from homeassistant.loader import bind_hass
from .const import (
DOMAIN, CONF_PROJECT_ID, CONF_CLIENT_ID, CONF_ACCESS_TOKEN,
CONF_EXPOSE_BY_DEFAULT, CONF_EXPOSED_DOMAINS,
CONF_AGENT_USER_ID, CONF_API_KEY,
CONF_EXPOSE_BY_DEFAULT, DEFAULT_EXPOSE_BY_DEFAULT, CONF_EXPOSED_DOMAINS,
DEFAULT_EXPOSED_DOMAINS, CONF_AGENT_USER_ID, CONF_API_KEY,
SERVICE_REQUEST_SYNC, REQUEST_SYNC_BASE_URL
)
from .auth import GoogleAssistantAuthView
from .http import GoogleAssistantView
from .http import async_register_http
_LOGGER = logging.getLogger(__name__)
@ -45,8 +44,10 @@ CONFIG_SCHEMA = vol.Schema(
vol.Required(CONF_PROJECT_ID): cv.string,
vol.Required(CONF_CLIENT_ID): cv.string,
vol.Required(CONF_ACCESS_TOKEN): cv.string,
vol.Optional(CONF_EXPOSE_BY_DEFAULT): cv.boolean,
vol.Optional(CONF_EXPOSED_DOMAINS): cv.ensure_list,
vol.Optional(CONF_EXPOSE_BY_DEFAULT,
default=DEFAULT_EXPOSE_BY_DEFAULT): cv.boolean,
vol.Optional(CONF_EXPOSED_DOMAINS,
default=DEFAULT_EXPOSED_DOMAINS): cv.ensure_list,
vol.Optional(CONF_AGENT_USER_ID,
default=DEFAULT_AGENT_USER_ID): cv.string,
vol.Optional(CONF_API_KEY): cv.string
@ -73,7 +74,7 @@ def async_setup(hass: HomeAssistant, yaml_config: Dict[str, Any]):
os.path.dirname(__file__), 'services.yaml')
)
hass.http.register_view(GoogleAssistantAuthView(hass, config))
hass.http.register_view(GoogleAssistantView(hass, config))
async_register_http(hass, config)
@asyncio.coroutine
def request_sync_service_handler(call):
@ -94,7 +95,7 @@ def async_setup(hass: HomeAssistant, yaml_config: Dict[str, Any]):
except (asyncio.TimeoutError, aiohttp.ClientError):
_LOGGER.error("Could not contact Google for request_sync")
# Register service only if api key is provided
# Register service only if api key is provided
if api_key is not None:
hass.services.async_register(
DOMAIN, SERVICE_REQUEST_SYNC, request_sync_service_handler,

View File

@ -7,53 +7,39 @@ https://home-assistant.io/components/google_assistant/
import asyncio
import logging
from typing import Any, Dict # NOQA
from aiohttp.hdrs import AUTHORIZATION
from aiohttp.web import Request, Response # NOQA
from homeassistant.const import HTTP_UNAUTHORIZED
# Typing imports
# pylint: disable=using-constant-test,unused-import,ungrouped-imports
# if False:
from homeassistant.components.http import HomeAssistantView
from homeassistant.const import HTTP_BAD_REQUEST, HTTP_UNAUTHORIZED
from homeassistant.core import HomeAssistant # NOQA
from homeassistant.core import HomeAssistant, callback # NOQA
from homeassistant.helpers.entity import Entity # NOQA
from .const import (
GOOGLE_ASSISTANT_API_ENDPOINT,
CONF_ACCESS_TOKEN,
DEFAULT_EXPOSE_BY_DEFAULT,
DEFAULT_EXPOSED_DOMAINS,
CONF_EXPOSE_BY_DEFAULT,
CONF_EXPOSED_DOMAINS,
ATTR_GOOGLE_ASSISTANT,
CONF_AGENT_USER_ID
)
from .smart_home import entity_to_device, query_device, determine_service
from .smart_home import async_handle_message, Config
_LOGGER = logging.getLogger(__name__)
class GoogleAssistantView(HomeAssistantView):
"""Handle Google Assistant requests."""
@callback
def async_register_http(hass, cfg):
"""Register HTTP views for Google Assistant."""
access_token = cfg.get(CONF_ACCESS_TOKEN)
expose_by_default = cfg.get(CONF_EXPOSE_BY_DEFAULT)
exposed_domains = cfg.get(CONF_EXPOSED_DOMAINS)
agent_user_id = cfg.get(CONF_AGENT_USER_ID)
url = GOOGLE_ASSISTANT_API_ENDPOINT
name = 'api:google_assistant'
requires_auth = False # Uses access token from oauth flow
def __init__(self, hass: HomeAssistant, cfg: Dict[str, Any]) -> None:
"""Initialize Google Assistant view."""
super().__init__()
self.access_token = cfg.get(CONF_ACCESS_TOKEN)
self.expose_by_default = cfg.get(CONF_EXPOSE_BY_DEFAULT,
DEFAULT_EXPOSE_BY_DEFAULT)
self.exposed_domains = cfg.get(CONF_EXPOSED_DOMAINS,
DEFAULT_EXPOSED_DOMAINS)
self.agent_user_id = cfg.get(CONF_AGENT_USER_ID)
def is_entity_exposed(self, entity) -> bool:
def is_exposed(entity) -> bool:
"""Determine if an entity should be exposed to Google Assistant."""
if entity.attributes.get('view') is not None:
# Ignore entities that are views
@ -63,7 +49,7 @@ class GoogleAssistantView(HomeAssistantView):
explicit_expose = entity.attributes.get(ATTR_GOOGLE_ASSISTANT, None)
domain_exposed_by_default = \
self.expose_by_default and domain in self.exposed_domains
expose_by_default and domain in exposed_domains
# Expose an entity if the entity's domain is exposed by default and
# the configuration doesn't explicitly exclude it from being
@ -73,79 +59,22 @@ class GoogleAssistantView(HomeAssistantView):
return is_default_exposed or explicit_expose
@asyncio.coroutine
def handle_sync(self, hass: HomeAssistant, request_id: str):
"""Handle SYNC action."""
devices = []
for entity in hass.states.async_all():
if not self.is_entity_exposed(entity):
continue
gass_config = Config(is_exposed, agent_user_id)
hass.http.register_view(
GoogleAssistantView(access_token, gass_config))
device = entity_to_device(entity, hass.config.units)
if device is None:
_LOGGER.warning("No mapping for %s domain", entity.domain)
continue
devices.append(device)
class GoogleAssistantView(HomeAssistantView):
"""Handle Google Assistant requests."""
return self.json(
_make_actions_response(request_id,
{'agentUserId': self.agent_user_id,
'devices': devices}))
url = GOOGLE_ASSISTANT_API_ENDPOINT
name = 'api:google_assistant'
requires_auth = False # Uses access token from oauth flow
@asyncio.coroutine
def handle_query(self,
hass: HomeAssistant,
request_id: str,
requested_devices: list):
"""Handle the QUERY action."""
devices = {}
for device in requested_devices:
devid = device.get('id')
# In theory this should never happpen
if not devid:
_LOGGER.error('Device missing ID: %s', device)
continue
state = hass.states.get(devid)
if not state:
# If we can't find a state, the device is offline
devices[devid] = {'online': False}
devices[devid] = query_device(state, hass.config.units)
return self.json(
_make_actions_response(request_id, {'devices': devices}))
@asyncio.coroutine
def handle_execute(self,
hass: HomeAssistant,
request_id: str,
requested_commands: list):
"""Handle the EXECUTE action."""
commands = []
for command in requested_commands:
ent_ids = [ent.get('id') for ent in command.get('devices', [])]
for execution in command.get('execution'):
for eid in ent_ids:
success = False
domain = eid.split('.')[0]
(service, service_data) = determine_service(
eid, execution.get('command'), execution.get('params'),
hass.config.units)
if domain == "group":
domain = "homeassistant"
success = yield from hass.services.async_call(
domain, service, service_data, blocking=True)
result = {"ids": [eid], "states": {}}
if success:
result['status'] = 'SUCCESS'
else:
result['status'] = 'ERROR'
commands.append(result)
return self.json(
_make_actions_response(request_id, {'commands': commands}))
def __init__(self, access_token, gass_config):
"""Initialize the Google Assistant request handler."""
self.access_token = access_token
self.gass_config = gass_config
@asyncio.coroutine
def post(self, request: Request) -> Response:
@ -155,35 +84,7 @@ class GoogleAssistantView(HomeAssistantView):
return self.json_message(
"missing authorization", status_code=HTTP_UNAUTHORIZED)
data = yield from request.json() # type: dict
inputs = data.get('inputs') # type: list
if len(inputs) != 1:
_LOGGER.error('Too many inputs in request %d', len(inputs))
return self.json_message(
"too many inputs", status_code=HTTP_BAD_REQUEST)
request_id = data.get('requestId') # type: str
intent = inputs[0].get('intent')
payload = inputs[0].get('payload')
hass = request.app['hass'] # type: HomeAssistant
res = None
if intent == 'action.devices.SYNC':
res = yield from self.handle_sync(hass, request_id)
elif intent == 'action.devices.QUERY':
res = yield from self.handle_query(hass, request_id,
payload.get('devices', []))
elif intent == 'action.devices.EXECUTE':
res = yield from self.handle_execute(hass, request_id,
payload.get('commands', []))
if res:
return res
return self.json_message(
"invalid intent", status_code=HTTP_BAD_REQUEST)
def _make_actions_response(request_id: str, payload: dict) -> dict:
return {'requestId': request_id, 'payload': payload}
message = yield from request.json() # type: dict
result = yield from async_handle_message(
request.app['hass'], self.gass_config, message)
return self.json(result)

View File

@ -1,4 +1,6 @@
"""Support for Google Assistant Smart Home API."""
import asyncio
from collections import namedtuple
import logging
# Typing imports
@ -10,6 +12,7 @@ from homeassistant.helpers.entity import Entity # NOQA
from homeassistant.core import HomeAssistant # NOQA
from homeassistant.util import color
from homeassistant.util.unit_system import UnitSystem # NOQA
from homeassistant.util.decorator import Registry
from homeassistant.const import (
ATTR_SUPPORTED_FEATURES, ATTR_ENTITY_ID,
@ -34,6 +37,7 @@ from .const import (
CONF_ALIASES, CLIMATE_SUPPORTED_MODES
)
HANDLERS = Registry()
_LOGGER = logging.getLogger(__name__)
# Mapping is [actions schema, primary trait, optional features]
@ -65,9 +69,7 @@ MAPPING_COMPONENT = {
} # type: Dict[str, list]
def make_actions_response(request_id: str, payload: dict) -> dict:
"""Make response message."""
return {'requestId': request_id, 'payload': payload}
Config = namedtuple('GoogleAssistantConfig', 'should_expose,agent_user_id')
def entity_to_device(entity: Entity, units: UnitSystem):
@ -286,3 +288,98 @@ def determine_service(
return (SERVICE_TURN_OFF, service_data)
return (None, service_data)
@asyncio.coroutine
def async_handle_message(hass, config, message):
"""Handle incoming API messages."""
request_id = message.get('requestId') # type: str
inputs = message.get('inputs') # type: list
if len(inputs) > 1:
_LOGGER.warning('Got unexpected more than 1 input. %s', message)
# Only use first input
intent = inputs[0].get('intent')
payload = inputs[0].get('payload')
handler = HANDLERS.get(intent)
if handler:
result = yield from handler(hass, config, payload)
else:
result = {'errorCode': 'protocolError'}
return {'requestId': request_id, 'payload': result}
@HANDLERS.register('action.devices.SYNC')
@asyncio.coroutine
def async_devices_sync(hass, config, payload):
"""Handle action.devices.SYNC request."""
devices = []
for entity in hass.states.async_all():
if not config.should_expose(entity):
continue
device = entity_to_device(entity, hass.config.units)
if device is None:
_LOGGER.warning("No mapping for %s domain", entity.domain)
continue
devices.append(device)
return {
'agentUserId': config.agent_user_id,
'devices': devices,
}
@HANDLERS.register('action.devices.QUERY')
@asyncio.coroutine
def async_devices_query(hass, config, payload):
"""Handle action.devices.QUERY request."""
devices = {}
for device in payload.get('devices', []):
devid = device.get('id')
# In theory this should never happpen
if not devid:
_LOGGER.error('Device missing ID: %s', device)
continue
state = hass.states.get(devid)
if not state:
# If we can't find a state, the device is offline
devices[devid] = {'online': False}
devices[devid] = query_device(state, hass.config.units)
return {'devices': devices}
@HANDLERS.register('action.devices.EXECUTE')
@asyncio.coroutine
def handle_devices_execute(hass, config, payload):
"""Handle action.devices.EXECUTE request."""
commands = []
for command in payload.get('commands', []):
ent_ids = [ent.get('id') for ent in command.get('devices', [])]
for execution in command.get('execution'):
for eid in ent_ids:
success = False
domain = eid.split('.')[0]
(service, service_data) = determine_service(
eid, execution.get('command'), execution.get('params'),
hass.config.units)
if domain == "group":
domain = "homeassistant"
success = yield from hass.services.async_call(
domain, service, service_data, blocking=True)
result = {"ids": [eid], "states": {}}
if success:
result['status'] = 'SUCCESS'
else:
result['status'] = 'ERROR'
commands.append(result)
return {'commands': commands}

View File

@ -5,43 +5,41 @@ import json
from aiohttp.hdrs import CONTENT_TYPE, AUTHORIZATION
import pytest
from tests.common import get_test_instance_port
from homeassistant import core, const, setup
from homeassistant.components import (
fan, http, cover, light, switch, climate, async_setup, media_player)
fan, cover, light, switch, climate, async_setup, media_player)
from homeassistant.components import google_assistant as ga
from homeassistant.util.unit_system import IMPERIAL_SYSTEM
from . import DEMO_DEVICES
API_PASSWORD = "test1234"
SERVER_PORT = get_test_instance_port()
BASE_API_URL = "http://127.0.0.1:{}".format(SERVER_PORT)
HA_HEADERS = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
CONTENT_TYPE: const.CONTENT_TYPE_JSON,
}
AUTHCFG = {
'project_id': 'hasstest-1234',
'client_id': 'helloworld',
'access_token': 'superdoublesecret'
}
AUTH_HEADER = {AUTHORIZATION: 'Bearer {}'.format(AUTHCFG['access_token'])}
PROJECT_ID = 'hasstest-1234'
CLIENT_ID = 'helloworld'
ACCESS_TOKEN = 'superdoublesecret'
AUTH_HEADER = {AUTHORIZATION: 'Bearer {}'.format(ACCESS_TOKEN)}
@pytest.fixture
def assistant_client(loop, hass_fixture, test_client):
def assistant_client(loop, hass, test_client):
"""Create web client for the Google Assistant API."""
hass = hass_fixture
web_app = hass.http.app
loop.run_until_complete(
setup.async_setup_component(hass, 'google_assistant', {
'google_assistant': {
'project_id': PROJECT_ID,
'client_id': CLIENT_ID,
'access_token': ACCESS_TOKEN,
}
}))
ga.http.GoogleAssistantView(hass, AUTHCFG).register(web_app.router)
ga.auth.GoogleAssistantAuthView(hass, AUTHCFG).register(web_app.router)
return loop.run_until_complete(test_client(web_app))
return loop.run_until_complete(test_client(hass.http.app))
@pytest.fixture
@ -50,13 +48,6 @@ def hass_fixture(loop, hass):
# We need to do this to get access to homeassistant/turn_(on,off)
loop.run_until_complete(async_setup(hass, {core.DOMAIN: {}}))
loop.run_until_complete(
setup.async_setup_component(hass, http.DOMAIN, {
http.DOMAIN: {
http.CONF_SERVER_PORT: SERVER_PORT
}
}))
loop.run_until_complete(
setup.async_setup_component(hass, light.DOMAIN, {
'light': [{
@ -121,20 +112,20 @@ def hass_fixture(loop, hass):
@asyncio.coroutine
def test_auth(hass_fixture, assistant_client):
def test_auth(assistant_client):
"""Test the auth process."""
result = yield from assistant_client.get(
ga.const.GOOGLE_ASSISTANT_API_ENDPOINT + '/auth',
params={
'redirect_uri':
'http://testurl/r/{}'.format(AUTHCFG['project_id']),
'client_id': AUTHCFG['client_id'],
'http://testurl/r/{}'.format(PROJECT_ID),
'client_id': CLIENT_ID,
'state': 'random1234',
},
allow_redirects=False)
assert result.status == 301
loc = result.headers.get('Location')
assert AUTHCFG['access_token'] in loc
assert ACCESS_TOKEN in loc
@asyncio.coroutine
@ -167,9 +158,6 @@ def test_sync_request(hass_fixture, assistant_client):
@asyncio.coroutine
def test_query_request(hass_fixture, assistant_client):
"""Test a query request."""
# hass.states.set("light.bedroom", "on")
# hass.states.set("switch.outside", "off")
# res = _sync_req()
reqid = '5711642932632160984'
data = {
'requestId':
@ -301,9 +289,6 @@ def test_query_climate_request_f(hass_fixture, assistant_client):
@asyncio.coroutine
def test_execute_request(hass_fixture, assistant_client):
"""Test a execute request."""
# hass.states.set("light.bedroom", "on")
# hass.states.set("switch.outside", "off")
# res = _sync_req()
reqid = '5711642932632160985'
data = {
'requestId':

View File

@ -179,16 +179,6 @@ DETERMINE_SERVICE_TESTS = [{ # Test light brightness
}]
@asyncio.coroutine
def test_make_actions_response():
"""Test make response helper."""
reqid = 1234
payload = 'hello'
result = ga.smart_home.make_actions_response(reqid, payload)
assert result['requestId'] == reqid
assert result['payload'] == payload
@asyncio.coroutine
def test_determine_service():
"""Test all branches of determine service."""

View File

@ -7,6 +7,8 @@ from unittest import mock
from urllib.parse import urlparse, parse_qs
import yarl
from aiohttp.client_exceptions import ClientResponseError
class AiohttpClientMocker:
"""Mock Aiohttp client requests."""
@ -189,6 +191,12 @@ class AiohttpClientMockResponse:
"""Mock release."""
pass
def raise_for_status(self):
"""Raise error if status is 400 or higher."""
if self.status >= 400:
raise ClientResponseError(
None, None, code=self.status, headers=self.headers)
def close(self):
"""Mock close."""
pass