Add permissions check in service helper (#18596)

* Add permissions check in service helper

* Lint

* Fix tests

* Lint

* Typing

* Fix unused impoert
pull/18612/head
Paulus Schoutsen 2018-11-21 12:26:08 +01:00 committed by GitHub
parent 8aa2cefd75
commit 36c31a6293
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 268 additions and 46 deletions

View File

@ -118,6 +118,10 @@ class AuthManager:
"""Retrieve a user.""" """Retrieve a user."""
return await self._store.async_get_user(user_id) return await self._store.async_get_user(user_id)
async def async_get_group(self, group_id: str) -> Optional[models.Group]:
"""Retrieve all groups."""
return await self._store.async_get_group(group_id)
async def async_get_user_by_credentials( async def async_get_user_by_credentials(
self, credentials: models.Credentials) -> Optional[models.User]: self, credentials: models.Credentials) -> Optional[models.User]:
"""Get a user by credential, return None if not found.""" """Get a user by credential, return None if not found."""

View File

@ -45,6 +45,14 @@ class AuthStore:
return list(self._groups.values()) return list(self._groups.values())
async def async_get_group(self, group_id: str) -> Optional[models.Group]:
"""Retrieve all users."""
if self._groups is None:
await self._async_load()
assert self._groups is not None
return self._groups.get(group_id)
async def async_get_users(self) -> List[models.User]: async def async_get_users(self) -> List[models.User]:
"""Retrieve all users.""" """Retrieve all users."""
if self._users is None: if self._users is None:

View File

@ -1,24 +1,24 @@
"""The exceptions used by Home Assistant.""" """The exceptions used by Home Assistant."""
from typing import Optional, Tuple, TYPE_CHECKING
import jinja2 import jinja2
# pylint: disable=using-constant-test
if TYPE_CHECKING:
# pylint: disable=unused-import
from .core import Context # noqa
class HomeAssistantError(Exception): class HomeAssistantError(Exception):
"""General Home Assistant exception occurred.""" """General Home Assistant exception occurred."""
pass
class InvalidEntityFormatError(HomeAssistantError): class InvalidEntityFormatError(HomeAssistantError):
"""When an invalid formatted entity is encountered.""" """When an invalid formatted entity is encountered."""
pass
class NoEntitySpecifiedError(HomeAssistantError): class NoEntitySpecifiedError(HomeAssistantError):
"""When no entity is specified.""" """When no entity is specified."""
pass
class TemplateError(HomeAssistantError): class TemplateError(HomeAssistantError):
"""Error during template rendering.""" """Error during template rendering."""
@ -32,16 +32,29 @@ class TemplateError(HomeAssistantError):
class PlatformNotReady(HomeAssistantError): class PlatformNotReady(HomeAssistantError):
"""Error to indicate that platform is not ready.""" """Error to indicate that platform is not ready."""
pass
class ConfigEntryNotReady(HomeAssistantError): class ConfigEntryNotReady(HomeAssistantError):
"""Error to indicate that config entry is not ready.""" """Error to indicate that config entry is not ready."""
pass
class InvalidStateError(HomeAssistantError): class InvalidStateError(HomeAssistantError):
"""When an invalid state is encountered.""" """When an invalid state is encountered."""
pass
class Unauthorized(HomeAssistantError):
"""When an action is unauthorized."""
def __init__(self, context: Optional['Context'] = None,
user_id: Optional[str] = None,
entity_id: Optional[str] = None,
permission: Optional[Tuple[str]] = None) -> None:
"""Unauthorized error."""
super().__init__(self.__class__.__name__)
self.context = context
self.user_id = user_id
self.entity_id = entity_id
self.permission = permission
class UnknownUser(Unauthorized):
"""When call is made with user ID that doesn't exist."""

View File

@ -5,9 +5,10 @@ from os import path
import voluptuous as vol import voluptuous as vol
from homeassistant.auth.permissions.const import POLICY_CONTROL
from homeassistant.const import ATTR_ENTITY_ID from homeassistant.const import ATTR_ENTITY_ID
import homeassistant.core as ha import homeassistant.core as ha
from homeassistant.exceptions import TemplateError from homeassistant.exceptions import TemplateError, Unauthorized, UnknownUser
from homeassistant.helpers import template from homeassistant.helpers import template
from homeassistant.loader import get_component, bind_hass from homeassistant.loader import get_component, bind_hass
from homeassistant.util.yaml import load_yaml from homeassistant.util.yaml import load_yaml
@ -187,23 +188,75 @@ async def entity_service_call(hass, platforms, func, call):
Calls all platforms simultaneously. Calls all platforms simultaneously.
""" """
tasks = [] if call.context.user_id:
all_entities = ATTR_ENTITY_ID not in call.data user = await hass.auth.async_get_user(call.context.user_id)
if not all_entities: if user is None:
raise UnknownUser(context=call.context)
perms = user.permissions
else:
perms = None
# Are we trying to target all entities
target_all_entities = ATTR_ENTITY_ID not in call.data
if not target_all_entities:
# A set of entities we're trying to target.
entity_ids = set( entity_ids = set(
extract_entity_ids(hass, call, True)) extract_entity_ids(hass, call, True))
# If the service function is a string, we'll pass it the service call data
if isinstance(func, str): if isinstance(func, str):
data = {key: val for key, val in call.data.items() data = {key: val for key, val in call.data.items()
if key != ATTR_ENTITY_ID} if key != ATTR_ENTITY_ID}
# If the service function is not a string, we pass the service call
else: else:
data = call data = call
# Check the permissions
# A list with for each platform in platforms a list of entities to call
# the service on.
platforms_entities = []
if perms is None:
for platform in platforms:
if target_all_entities:
platforms_entities.append(list(platform.entities.values()))
else:
platforms_entities.append([
entity for entity in platform.entities.values()
if entity.entity_id in entity_ids
])
elif target_all_entities:
# If we target all entities, we will select all entities the user
# is allowed to control.
for platform in platforms:
platforms_entities.append([
entity for entity in platform.entities.values()
if perms.check_entity(entity.entity_id, POLICY_CONTROL)])
else:
for platform in platforms:
platform_entities = []
for entity in platform.entities.values():
if entity.entity_id not in entity_ids:
continue
if not perms.check_entity(entity.entity_id, POLICY_CONTROL):
raise Unauthorized(
context=call.context,
entity_id=entity.entity_id,
permission=POLICY_CONTROL
)
platform_entities.append(entity)
platforms_entities.append(platform_entities)
tasks = [ tasks = [
_handle_service_platform_call(func, data, [ _handle_service_platform_call(func, data, entities, call.context)
entity for entity in platform.entities.values() for platform, entities in zip(platforms, platforms_entities)
if all_entities or entity.entity_id in entity_ids
], call.context) for platform in platforms
] ]
if tasks: if tasks:

View File

@ -3,6 +3,7 @@ from unittest.mock import patch
import pytest import pytest
from homeassistant.auth.const import GROUP_ID_ADMIN, GROUP_ID_READ_ONLY
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.components.websocket_api.http import URL from homeassistant.components.websocket_api.http import URL
from homeassistant.components.websocket_api.auth import ( from homeassistant.components.websocket_api.auth import (
@ -77,3 +78,19 @@ def hass_access_token(hass):
refresh_token = hass.loop.run_until_complete( refresh_token = hass.loop.run_until_complete(
hass.auth.async_create_refresh_token(user, CLIENT_ID)) hass.auth.async_create_refresh_token(user, CLIENT_ID))
yield hass.auth.async_create_access_token(refresh_token) yield hass.auth.async_create_access_token(refresh_token)
@pytest.fixture
def hass_admin_user(hass):
"""Return a Home Assistant admin user."""
admin_group = hass.loop.run_until_complete(hass.auth.async_get_group(
GROUP_ID_ADMIN))
return MockUser(groups=[admin_group]).add_to_hass(hass)
@pytest.fixture
def hass_read_only_user(hass):
"""Return a Home Assistant read only user."""
read_only_group = hass.loop.run_until_complete(hass.auth.async_get_group(
GROUP_ID_READ_ONLY))
return MockUser(groups=[read_only_group]).add_to_hass(hass)

View File

@ -234,7 +234,7 @@ def test_no_initial_state_and_no_restore_state(hass):
assert int(state.state) == 0 assert int(state.state) == 0
async def test_counter_context(hass): async def test_counter_context(hass, hass_admin_user):
"""Test that counter context works.""" """Test that counter context works."""
assert await async_setup_component(hass, 'counter', { assert await async_setup_component(hass, 'counter', {
'counter': { 'counter': {
@ -247,9 +247,9 @@ async def test_counter_context(hass):
await hass.services.async_call('counter', 'increment', { await hass.services.async_call('counter', 'increment', {
'entity_id': state.entity_id, 'entity_id': state.entity_id,
}, True, Context(user_id='abcd')) }, True, Context(user_id=hass_admin_user.id))
state2 = hass.states.get('counter.test') state2 = hass.states.get('counter.test')
assert state2 is not None assert state2 is not None
assert state.state != state2.state assert state.state != state2.state
assert state2.context.user_id == 'abcd' assert state2.context.user_id == hass_admin_user.id

View File

@ -476,7 +476,7 @@ async def test_intent_set_color_and_brightness(hass):
assert call.data.get(light.ATTR_BRIGHTNESS_PCT) == 20 assert call.data.get(light.ATTR_BRIGHTNESS_PCT) == 20
async def test_light_context(hass): async def test_light_context(hass, hass_admin_user):
"""Test that light context works.""" """Test that light context works."""
assert await async_setup_component(hass, 'light', { assert await async_setup_component(hass, 'light', {
'light': { 'light': {
@ -489,9 +489,9 @@ async def test_light_context(hass):
await hass.services.async_call('light', 'toggle', { await hass.services.async_call('light', 'toggle', {
'entity_id': state.entity_id, 'entity_id': state.entity_id,
}, True, core.Context(user_id='abcd')) }, True, core.Context(user_id=hass_admin_user.id))
state2 = hass.states.get('light.ceiling') state2 = hass.states.get('light.ceiling')
assert state2 is not None assert state2 is not None
assert state.state != state2.state assert state.state != state2.state
assert state2.context.user_id == 'abcd' assert state2.context.user_id == hass_admin_user.id

View File

@ -91,7 +91,7 @@ class TestSwitch(unittest.TestCase):
) )
async def test_switch_context(hass): async def test_switch_context(hass, hass_admin_user):
"""Test that switch context works.""" """Test that switch context works."""
assert await async_setup_component(hass, 'switch', { assert await async_setup_component(hass, 'switch', {
'switch': { 'switch': {
@ -104,9 +104,9 @@ async def test_switch_context(hass):
await hass.services.async_call('switch', 'toggle', { await hass.services.async_call('switch', 'toggle', {
'entity_id': state.entity_id, 'entity_id': state.entity_id,
}, True, core.Context(user_id='abcd')) }, True, core.Context(user_id=hass_admin_user.id))
state2 = hass.states.get('switch.ac') state2 = hass.states.get('switch.ac')
assert state2 is not None assert state2 is not None
assert state.state != state2.state assert state.state != state2.state
assert state2.context.user_id == 'abcd' assert state2.context.user_id == hass_admin_user.id

View File

@ -147,7 +147,7 @@ def test_initial_state_overrules_restore_state(hass):
assert state.state == 'on' assert state.state == 'on'
async def test_input_boolean_context(hass): async def test_input_boolean_context(hass, hass_admin_user):
"""Test that input_boolean context works.""" """Test that input_boolean context works."""
assert await async_setup_component(hass, 'input_boolean', { assert await async_setup_component(hass, 'input_boolean', {
'input_boolean': { 'input_boolean': {
@ -160,9 +160,9 @@ async def test_input_boolean_context(hass):
await hass.services.async_call('input_boolean', 'turn_off', { await hass.services.async_call('input_boolean', 'turn_off', {
'entity_id': state.entity_id, 'entity_id': state.entity_id,
}, True, Context(user_id='abcd')) }, True, Context(user_id=hass_admin_user.id))
state2 = hass.states.get('input_boolean.ac') state2 = hass.states.get('input_boolean.ac')
assert state2 is not None assert state2 is not None
assert state.state != state2.state assert state.state != state2.state
assert state2.context.user_id == 'abcd' assert state2.context.user_id == hass_admin_user.id

View File

@ -195,7 +195,7 @@ def test_restore_state(hass):
assert state_bogus.state == str(initial) assert state_bogus.state == str(initial)
async def test_input_datetime_context(hass): async def test_input_datetime_context(hass, hass_admin_user):
"""Test that input_datetime context works.""" """Test that input_datetime context works."""
assert await async_setup_component(hass, 'input_datetime', { assert await async_setup_component(hass, 'input_datetime', {
'input_datetime': { 'input_datetime': {
@ -211,9 +211,9 @@ async def test_input_datetime_context(hass):
await hass.services.async_call('input_datetime', 'set_datetime', { await hass.services.async_call('input_datetime', 'set_datetime', {
'entity_id': state.entity_id, 'entity_id': state.entity_id,
'date': '2018-01-02' 'date': '2018-01-02'
}, True, Context(user_id='abcd')) }, True, Context(user_id=hass_admin_user.id))
state2 = hass.states.get('input_datetime.only_date') state2 = hass.states.get('input_datetime.only_date')
assert state2 is not None assert state2 is not None
assert state.state != state2.state assert state.state != state2.state
assert state2.context.user_id == 'abcd' assert state2.context.user_id == hass_admin_user.id

View File

@ -266,7 +266,7 @@ def test_no_initial_state_and_no_restore_state(hass):
assert float(state.state) == 0 assert float(state.state) == 0
async def test_input_number_context(hass): async def test_input_number_context(hass, hass_admin_user):
"""Test that input_number context works.""" """Test that input_number context works."""
assert await async_setup_component(hass, 'input_number', { assert await async_setup_component(hass, 'input_number', {
'input_number': { 'input_number': {
@ -282,9 +282,9 @@ async def test_input_number_context(hass):
await hass.services.async_call('input_number', 'increment', { await hass.services.async_call('input_number', 'increment', {
'entity_id': state.entity_id, 'entity_id': state.entity_id,
}, True, Context(user_id='abcd')) }, True, Context(user_id=hass_admin_user.id))
state2 = hass.states.get('input_number.b1') state2 = hass.states.get('input_number.b1')
assert state2 is not None assert state2 is not None
assert state.state != state2.state assert state.state != state2.state
assert state2.context.user_id == 'abcd' assert state2.context.user_id == hass_admin_user.id

View File

@ -302,7 +302,7 @@ def test_initial_state_overrules_restore_state(hass):
assert state.state == 'middle option' assert state.state == 'middle option'
async def test_input_select_context(hass): async def test_input_select_context(hass, hass_admin_user):
"""Test that input_select context works.""" """Test that input_select context works."""
assert await async_setup_component(hass, 'input_select', { assert await async_setup_component(hass, 'input_select', {
'input_select': { 'input_select': {
@ -321,9 +321,9 @@ async def test_input_select_context(hass):
await hass.services.async_call('input_select', 'select_next', { await hass.services.async_call('input_select', 'select_next', {
'entity_id': state.entity_id, 'entity_id': state.entity_id,
}, True, Context(user_id='abcd')) }, True, Context(user_id=hass_admin_user.id))
state2 = hass.states.get('input_select.s1') state2 = hass.states.get('input_select.s1')
assert state2 is not None assert state2 is not None
assert state.state != state2.state assert state.state != state2.state
assert state2.context.user_id == 'abcd' assert state2.context.user_id == hass_admin_user.id

View File

@ -184,7 +184,7 @@ def test_no_initial_state_and_no_restore_state(hass):
assert str(state.state) == 'unknown' assert str(state.state) == 'unknown'
async def test_input_text_context(hass): async def test_input_text_context(hass, hass_admin_user):
"""Test that input_text context works.""" """Test that input_text context works."""
assert await async_setup_component(hass, 'input_text', { assert await async_setup_component(hass, 'input_text', {
'input_text': { 'input_text': {
@ -200,9 +200,9 @@ async def test_input_text_context(hass):
await hass.services.async_call('input_text', 'set_value', { await hass.services.async_call('input_text', 'set_value', {
'entity_id': state.entity_id, 'entity_id': state.entity_id,
'value': 'new_value', 'value': 'new_value',
}, True, Context(user_id='abcd')) }, True, Context(user_id=hass_admin_user.id))
state2 = hass.states.get('input_text.t1') state2 = hass.states.get('input_text.t1')
assert state2 is not None assert state2 is not None
assert state.state != state2.state assert state.state != state2.state
assert state2.context.user_id == 'abcd' assert state2.context.user_id == hass_admin_user.id

View File

@ -1,18 +1,49 @@
"""Test service helpers.""" """Test service helpers."""
import asyncio import asyncio
from collections import OrderedDict
from copy import deepcopy from copy import deepcopy
import unittest import unittest
from unittest.mock import patch from unittest.mock import Mock, patch
import pytest
# To prevent circular import when running just this file # To prevent circular import when running just this file
import homeassistant.components # noqa import homeassistant.components # noqa
from homeassistant import core as ha, loader from homeassistant import core as ha, loader, exceptions
from homeassistant.const import STATE_ON, STATE_OFF, ATTR_ENTITY_ID from homeassistant.const import STATE_ON, STATE_OFF, ATTR_ENTITY_ID
from homeassistant.helpers import service, template from homeassistant.helpers import service, template
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.auth.permissions import PolicyPermissions
from tests.common import get_test_home_assistant, mock_service from tests.common import get_test_home_assistant, mock_service, mock_coro
@pytest.fixture
def mock_service_platform_call():
"""Mock service platform call."""
with patch('homeassistant.helpers.service._handle_service_platform_call',
side_effect=lambda *args: mock_coro()) as mock_call:
yield mock_call
@pytest.fixture
def mock_entities():
"""Return mock entities in an ordered dict."""
kitchen = Mock(
entity_id='light.kitchen',
available=True,
should_poll=False,
)
living_room = Mock(
entity_id='light.living_room',
available=True,
should_poll=False,
)
entities = OrderedDict()
entities[kitchen.entity_id] = kitchen
entities[living_room.entity_id] = living_room
return entities
class TestServiceHelpers(unittest.TestCase): class TestServiceHelpers(unittest.TestCase):
@ -179,3 +210,99 @@ def test_async_get_all_descriptions(hass):
assert 'description' in descriptions[logger.DOMAIN]['set_level'] assert 'description' in descriptions[logger.DOMAIN]['set_level']
assert 'fields' in descriptions[logger.DOMAIN]['set_level'] assert 'fields' in descriptions[logger.DOMAIN]['set_level']
async def test_call_context_user_not_exist(hass):
"""Check we don't allow deleted users to do things."""
with pytest.raises(exceptions.UnknownUser) as err:
await service.entity_service_call(hass, [], Mock(), ha.ServiceCall(
'test_domain', 'test_service', context=ha.Context(
user_id='non-existing')))
assert err.value.context.user_id == 'non-existing'
async def test_call_context_target_all(hass, mock_service_platform_call,
mock_entities):
"""Check we only target allowed entities if targetting all."""
with patch('homeassistant.auth.AuthManager.async_get_user',
return_value=mock_coro(Mock(permissions=PolicyPermissions({
'entities': {
'entity_ids': {
'light.kitchen': True
}
}
})))):
await service.entity_service_call(hass, [
Mock(entities=mock_entities)
], Mock(), ha.ServiceCall('test_domain', 'test_service',
context=ha.Context(user_id='mock-id')))
assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][2]
assert entities == [mock_entities['light.kitchen']]
async def test_call_context_target_specific(hass, mock_service_platform_call,
mock_entities):
"""Check targeting specific entities."""
with patch('homeassistant.auth.AuthManager.async_get_user',
return_value=mock_coro(Mock(permissions=PolicyPermissions({
'entities': {
'entity_ids': {
'light.kitchen': True
}
}
})))):
await service.entity_service_call(hass, [
Mock(entities=mock_entities)
], Mock(), ha.ServiceCall('test_domain', 'test_service', {
'entity_id': 'light.kitchen'
}, context=ha.Context(user_id='mock-id')))
assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][2]
assert entities == [mock_entities['light.kitchen']]
async def test_call_context_target_specific_no_auth(
hass, mock_service_platform_call, mock_entities):
"""Check targeting specific entities without auth."""
with pytest.raises(exceptions.Unauthorized) as err:
with patch('homeassistant.auth.AuthManager.async_get_user',
return_value=mock_coro(Mock(
permissions=PolicyPermissions({})))):
await service.entity_service_call(hass, [
Mock(entities=mock_entities)
], Mock(), ha.ServiceCall('test_domain', 'test_service', {
'entity_id': 'light.kitchen'
}, context=ha.Context(user_id='mock-id')))
assert err.value.context.user_id == 'mock-id'
assert err.value.entity_id == 'light.kitchen'
async def test_call_no_context_target_all(hass, mock_service_platform_call,
mock_entities):
"""Check we target all if no user context given."""
await service.entity_service_call(hass, [
Mock(entities=mock_entities)
], Mock(), ha.ServiceCall('test_domain', 'test_service'))
assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][2]
assert entities == list(mock_entities.values())
async def test_call_no_context_target_specific(
hass, mock_service_platform_call, mock_entities):
"""Check we can target specified entities."""
await service.entity_service_call(hass, [
Mock(entities=mock_entities)
], Mock(), ha.ServiceCall('test_domain', 'test_service', {
'entity_id': ['light.kitchen', 'light.non-existing']
}))
assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][2]
assert entities == [mock_entities['light.kitchen']]