Add permissions check in service helper (#18596)
* Add permissions check in service helper * Lint * Fix tests * Lint * Typing * Fix unused impoertpull/18612/head
parent
8aa2cefd75
commit
36c31a6293
|
@ -118,6 +118,10 @@ class AuthManager:
|
||||||
"""Retrieve a user."""
|
"""Retrieve a user."""
|
||||||
return await self._store.async_get_user(user_id)
|
return await self._store.async_get_user(user_id)
|
||||||
|
|
||||||
|
async def async_get_group(self, group_id: str) -> Optional[models.Group]:
|
||||||
|
"""Retrieve all groups."""
|
||||||
|
return await self._store.async_get_group(group_id)
|
||||||
|
|
||||||
async def async_get_user_by_credentials(
|
async def async_get_user_by_credentials(
|
||||||
self, credentials: models.Credentials) -> Optional[models.User]:
|
self, credentials: models.Credentials) -> Optional[models.User]:
|
||||||
"""Get a user by credential, return None if not found."""
|
"""Get a user by credential, return None if not found."""
|
||||||
|
|
|
@ -45,6 +45,14 @@ class AuthStore:
|
||||||
|
|
||||||
return list(self._groups.values())
|
return list(self._groups.values())
|
||||||
|
|
||||||
|
async def async_get_group(self, group_id: str) -> Optional[models.Group]:
|
||||||
|
"""Retrieve all users."""
|
||||||
|
if self._groups is None:
|
||||||
|
await self._async_load()
|
||||||
|
assert self._groups is not None
|
||||||
|
|
||||||
|
return self._groups.get(group_id)
|
||||||
|
|
||||||
async def async_get_users(self) -> List[models.User]:
|
async def async_get_users(self) -> List[models.User]:
|
||||||
"""Retrieve all users."""
|
"""Retrieve all users."""
|
||||||
if self._users is None:
|
if self._users is None:
|
||||||
|
|
|
@ -1,24 +1,24 @@
|
||||||
"""The exceptions used by Home Assistant."""
|
"""The exceptions used by Home Assistant."""
|
||||||
|
from typing import Optional, Tuple, TYPE_CHECKING
|
||||||
import jinja2
|
import jinja2
|
||||||
|
|
||||||
|
# pylint: disable=using-constant-test
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
# pylint: disable=unused-import
|
||||||
|
from .core import Context # noqa
|
||||||
|
|
||||||
|
|
||||||
class HomeAssistantError(Exception):
|
class HomeAssistantError(Exception):
|
||||||
"""General Home Assistant exception occurred."""
|
"""General Home Assistant exception occurred."""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidEntityFormatError(HomeAssistantError):
|
class InvalidEntityFormatError(HomeAssistantError):
|
||||||
"""When an invalid formatted entity is encountered."""
|
"""When an invalid formatted entity is encountered."""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class NoEntitySpecifiedError(HomeAssistantError):
|
class NoEntitySpecifiedError(HomeAssistantError):
|
||||||
"""When no entity is specified."""
|
"""When no entity is specified."""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class TemplateError(HomeAssistantError):
|
class TemplateError(HomeAssistantError):
|
||||||
"""Error during template rendering."""
|
"""Error during template rendering."""
|
||||||
|
@ -32,16 +32,29 @@ class TemplateError(HomeAssistantError):
|
||||||
class PlatformNotReady(HomeAssistantError):
|
class PlatformNotReady(HomeAssistantError):
|
||||||
"""Error to indicate that platform is not ready."""
|
"""Error to indicate that platform is not ready."""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigEntryNotReady(HomeAssistantError):
|
class ConfigEntryNotReady(HomeAssistantError):
|
||||||
"""Error to indicate that config entry is not ready."""
|
"""Error to indicate that config entry is not ready."""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidStateError(HomeAssistantError):
|
class InvalidStateError(HomeAssistantError):
|
||||||
"""When an invalid state is encountered."""
|
"""When an invalid state is encountered."""
|
||||||
|
|
||||||
pass
|
|
||||||
|
class Unauthorized(HomeAssistantError):
|
||||||
|
"""When an action is unauthorized."""
|
||||||
|
|
||||||
|
def __init__(self, context: Optional['Context'] = None,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
entity_id: Optional[str] = None,
|
||||||
|
permission: Optional[Tuple[str]] = None) -> None:
|
||||||
|
"""Unauthorized error."""
|
||||||
|
super().__init__(self.__class__.__name__)
|
||||||
|
self.context = context
|
||||||
|
self.user_id = user_id
|
||||||
|
self.entity_id = entity_id
|
||||||
|
self.permission = permission
|
||||||
|
|
||||||
|
|
||||||
|
class UnknownUser(Unauthorized):
|
||||||
|
"""When call is made with user ID that doesn't exist."""
|
||||||
|
|
|
@ -5,9 +5,10 @@ from os import path
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.auth.permissions.const import POLICY_CONTROL
|
||||||
from homeassistant.const import ATTR_ENTITY_ID
|
from homeassistant.const import ATTR_ENTITY_ID
|
||||||
import homeassistant.core as ha
|
import homeassistant.core as ha
|
||||||
from homeassistant.exceptions import TemplateError
|
from homeassistant.exceptions import TemplateError, Unauthorized, UnknownUser
|
||||||
from homeassistant.helpers import template
|
from homeassistant.helpers import template
|
||||||
from homeassistant.loader import get_component, bind_hass
|
from homeassistant.loader import get_component, bind_hass
|
||||||
from homeassistant.util.yaml import load_yaml
|
from homeassistant.util.yaml import load_yaml
|
||||||
|
@ -187,23 +188,75 @@ async def entity_service_call(hass, platforms, func, call):
|
||||||
|
|
||||||
Calls all platforms simultaneously.
|
Calls all platforms simultaneously.
|
||||||
"""
|
"""
|
||||||
tasks = []
|
if call.context.user_id:
|
||||||
all_entities = ATTR_ENTITY_ID not in call.data
|
user = await hass.auth.async_get_user(call.context.user_id)
|
||||||
if not all_entities:
|
if user is None:
|
||||||
|
raise UnknownUser(context=call.context)
|
||||||
|
perms = user.permissions
|
||||||
|
else:
|
||||||
|
perms = None
|
||||||
|
|
||||||
|
# Are we trying to target all entities
|
||||||
|
target_all_entities = ATTR_ENTITY_ID not in call.data
|
||||||
|
|
||||||
|
if not target_all_entities:
|
||||||
|
# A set of entities we're trying to target.
|
||||||
entity_ids = set(
|
entity_ids = set(
|
||||||
extract_entity_ids(hass, call, True))
|
extract_entity_ids(hass, call, True))
|
||||||
|
|
||||||
|
# If the service function is a string, we'll pass it the service call data
|
||||||
if isinstance(func, str):
|
if isinstance(func, str):
|
||||||
data = {key: val for key, val in call.data.items()
|
data = {key: val for key, val in call.data.items()
|
||||||
if key != ATTR_ENTITY_ID}
|
if key != ATTR_ENTITY_ID}
|
||||||
|
# If the service function is not a string, we pass the service call
|
||||||
else:
|
else:
|
||||||
data = call
|
data = call
|
||||||
|
|
||||||
|
# Check the permissions
|
||||||
|
|
||||||
|
# A list with for each platform in platforms a list of entities to call
|
||||||
|
# the service on.
|
||||||
|
platforms_entities = []
|
||||||
|
|
||||||
|
if perms is None:
|
||||||
|
for platform in platforms:
|
||||||
|
if target_all_entities:
|
||||||
|
platforms_entities.append(list(platform.entities.values()))
|
||||||
|
else:
|
||||||
|
platforms_entities.append([
|
||||||
|
entity for entity in platform.entities.values()
|
||||||
|
if entity.entity_id in entity_ids
|
||||||
|
])
|
||||||
|
|
||||||
|
elif target_all_entities:
|
||||||
|
# If we target all entities, we will select all entities the user
|
||||||
|
# is allowed to control.
|
||||||
|
for platform in platforms:
|
||||||
|
platforms_entities.append([
|
||||||
|
entity for entity in platform.entities.values()
|
||||||
|
if perms.check_entity(entity.entity_id, POLICY_CONTROL)])
|
||||||
|
|
||||||
|
else:
|
||||||
|
for platform in platforms:
|
||||||
|
platform_entities = []
|
||||||
|
for entity in platform.entities.values():
|
||||||
|
if entity.entity_id not in entity_ids:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not perms.check_entity(entity.entity_id, POLICY_CONTROL):
|
||||||
|
raise Unauthorized(
|
||||||
|
context=call.context,
|
||||||
|
entity_id=entity.entity_id,
|
||||||
|
permission=POLICY_CONTROL
|
||||||
|
)
|
||||||
|
|
||||||
|
platform_entities.append(entity)
|
||||||
|
|
||||||
|
platforms_entities.append(platform_entities)
|
||||||
|
|
||||||
tasks = [
|
tasks = [
|
||||||
_handle_service_platform_call(func, data, [
|
_handle_service_platform_call(func, data, entities, call.context)
|
||||||
entity for entity in platform.entities.values()
|
for platform, entities in zip(platforms, platforms_entities)
|
||||||
if all_entities or entity.entity_id in entity_ids
|
|
||||||
], call.context) for platform in platforms
|
|
||||||
]
|
]
|
||||||
|
|
||||||
if tasks:
|
if tasks:
|
||||||
|
|
|
@ -3,6 +3,7 @@ from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.auth.const import GROUP_ID_ADMIN, GROUP_ID_READ_ONLY
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
from homeassistant.components.websocket_api.http import URL
|
from homeassistant.components.websocket_api.http import URL
|
||||||
from homeassistant.components.websocket_api.auth import (
|
from homeassistant.components.websocket_api.auth import (
|
||||||
|
@ -77,3 +78,19 @@ def hass_access_token(hass):
|
||||||
refresh_token = hass.loop.run_until_complete(
|
refresh_token = hass.loop.run_until_complete(
|
||||||
hass.auth.async_create_refresh_token(user, CLIENT_ID))
|
hass.auth.async_create_refresh_token(user, CLIENT_ID))
|
||||||
yield hass.auth.async_create_access_token(refresh_token)
|
yield hass.auth.async_create_access_token(refresh_token)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def hass_admin_user(hass):
|
||||||
|
"""Return a Home Assistant admin user."""
|
||||||
|
admin_group = hass.loop.run_until_complete(hass.auth.async_get_group(
|
||||||
|
GROUP_ID_ADMIN))
|
||||||
|
return MockUser(groups=[admin_group]).add_to_hass(hass)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def hass_read_only_user(hass):
|
||||||
|
"""Return a Home Assistant read only user."""
|
||||||
|
read_only_group = hass.loop.run_until_complete(hass.auth.async_get_group(
|
||||||
|
GROUP_ID_READ_ONLY))
|
||||||
|
return MockUser(groups=[read_only_group]).add_to_hass(hass)
|
||||||
|
|
|
@ -234,7 +234,7 @@ def test_no_initial_state_and_no_restore_state(hass):
|
||||||
assert int(state.state) == 0
|
assert int(state.state) == 0
|
||||||
|
|
||||||
|
|
||||||
async def test_counter_context(hass):
|
async def test_counter_context(hass, hass_admin_user):
|
||||||
"""Test that counter context works."""
|
"""Test that counter context works."""
|
||||||
assert await async_setup_component(hass, 'counter', {
|
assert await async_setup_component(hass, 'counter', {
|
||||||
'counter': {
|
'counter': {
|
||||||
|
@ -247,9 +247,9 @@ async def test_counter_context(hass):
|
||||||
|
|
||||||
await hass.services.async_call('counter', 'increment', {
|
await hass.services.async_call('counter', 'increment', {
|
||||||
'entity_id': state.entity_id,
|
'entity_id': state.entity_id,
|
||||||
}, True, Context(user_id='abcd'))
|
}, True, Context(user_id=hass_admin_user.id))
|
||||||
|
|
||||||
state2 = hass.states.get('counter.test')
|
state2 = hass.states.get('counter.test')
|
||||||
assert state2 is not None
|
assert state2 is not None
|
||||||
assert state.state != state2.state
|
assert state.state != state2.state
|
||||||
assert state2.context.user_id == 'abcd'
|
assert state2.context.user_id == hass_admin_user.id
|
||||||
|
|
|
@ -476,7 +476,7 @@ async def test_intent_set_color_and_brightness(hass):
|
||||||
assert call.data.get(light.ATTR_BRIGHTNESS_PCT) == 20
|
assert call.data.get(light.ATTR_BRIGHTNESS_PCT) == 20
|
||||||
|
|
||||||
|
|
||||||
async def test_light_context(hass):
|
async def test_light_context(hass, hass_admin_user):
|
||||||
"""Test that light context works."""
|
"""Test that light context works."""
|
||||||
assert await async_setup_component(hass, 'light', {
|
assert await async_setup_component(hass, 'light', {
|
||||||
'light': {
|
'light': {
|
||||||
|
@ -489,9 +489,9 @@ async def test_light_context(hass):
|
||||||
|
|
||||||
await hass.services.async_call('light', 'toggle', {
|
await hass.services.async_call('light', 'toggle', {
|
||||||
'entity_id': state.entity_id,
|
'entity_id': state.entity_id,
|
||||||
}, True, core.Context(user_id='abcd'))
|
}, True, core.Context(user_id=hass_admin_user.id))
|
||||||
|
|
||||||
state2 = hass.states.get('light.ceiling')
|
state2 = hass.states.get('light.ceiling')
|
||||||
assert state2 is not None
|
assert state2 is not None
|
||||||
assert state.state != state2.state
|
assert state.state != state2.state
|
||||||
assert state2.context.user_id == 'abcd'
|
assert state2.context.user_id == hass_admin_user.id
|
||||||
|
|
|
@ -91,7 +91,7 @@ class TestSwitch(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_switch_context(hass):
|
async def test_switch_context(hass, hass_admin_user):
|
||||||
"""Test that switch context works."""
|
"""Test that switch context works."""
|
||||||
assert await async_setup_component(hass, 'switch', {
|
assert await async_setup_component(hass, 'switch', {
|
||||||
'switch': {
|
'switch': {
|
||||||
|
@ -104,9 +104,9 @@ async def test_switch_context(hass):
|
||||||
|
|
||||||
await hass.services.async_call('switch', 'toggle', {
|
await hass.services.async_call('switch', 'toggle', {
|
||||||
'entity_id': state.entity_id,
|
'entity_id': state.entity_id,
|
||||||
}, True, core.Context(user_id='abcd'))
|
}, True, core.Context(user_id=hass_admin_user.id))
|
||||||
|
|
||||||
state2 = hass.states.get('switch.ac')
|
state2 = hass.states.get('switch.ac')
|
||||||
assert state2 is not None
|
assert state2 is not None
|
||||||
assert state.state != state2.state
|
assert state.state != state2.state
|
||||||
assert state2.context.user_id == 'abcd'
|
assert state2.context.user_id == hass_admin_user.id
|
||||||
|
|
|
@ -147,7 +147,7 @@ def test_initial_state_overrules_restore_state(hass):
|
||||||
assert state.state == 'on'
|
assert state.state == 'on'
|
||||||
|
|
||||||
|
|
||||||
async def test_input_boolean_context(hass):
|
async def test_input_boolean_context(hass, hass_admin_user):
|
||||||
"""Test that input_boolean context works."""
|
"""Test that input_boolean context works."""
|
||||||
assert await async_setup_component(hass, 'input_boolean', {
|
assert await async_setup_component(hass, 'input_boolean', {
|
||||||
'input_boolean': {
|
'input_boolean': {
|
||||||
|
@ -160,9 +160,9 @@ async def test_input_boolean_context(hass):
|
||||||
|
|
||||||
await hass.services.async_call('input_boolean', 'turn_off', {
|
await hass.services.async_call('input_boolean', 'turn_off', {
|
||||||
'entity_id': state.entity_id,
|
'entity_id': state.entity_id,
|
||||||
}, True, Context(user_id='abcd'))
|
}, True, Context(user_id=hass_admin_user.id))
|
||||||
|
|
||||||
state2 = hass.states.get('input_boolean.ac')
|
state2 = hass.states.get('input_boolean.ac')
|
||||||
assert state2 is not None
|
assert state2 is not None
|
||||||
assert state.state != state2.state
|
assert state.state != state2.state
|
||||||
assert state2.context.user_id == 'abcd'
|
assert state2.context.user_id == hass_admin_user.id
|
||||||
|
|
|
@ -195,7 +195,7 @@ def test_restore_state(hass):
|
||||||
assert state_bogus.state == str(initial)
|
assert state_bogus.state == str(initial)
|
||||||
|
|
||||||
|
|
||||||
async def test_input_datetime_context(hass):
|
async def test_input_datetime_context(hass, hass_admin_user):
|
||||||
"""Test that input_datetime context works."""
|
"""Test that input_datetime context works."""
|
||||||
assert await async_setup_component(hass, 'input_datetime', {
|
assert await async_setup_component(hass, 'input_datetime', {
|
||||||
'input_datetime': {
|
'input_datetime': {
|
||||||
|
@ -211,9 +211,9 @@ async def test_input_datetime_context(hass):
|
||||||
await hass.services.async_call('input_datetime', 'set_datetime', {
|
await hass.services.async_call('input_datetime', 'set_datetime', {
|
||||||
'entity_id': state.entity_id,
|
'entity_id': state.entity_id,
|
||||||
'date': '2018-01-02'
|
'date': '2018-01-02'
|
||||||
}, True, Context(user_id='abcd'))
|
}, True, Context(user_id=hass_admin_user.id))
|
||||||
|
|
||||||
state2 = hass.states.get('input_datetime.only_date')
|
state2 = hass.states.get('input_datetime.only_date')
|
||||||
assert state2 is not None
|
assert state2 is not None
|
||||||
assert state.state != state2.state
|
assert state.state != state2.state
|
||||||
assert state2.context.user_id == 'abcd'
|
assert state2.context.user_id == hass_admin_user.id
|
||||||
|
|
|
@ -266,7 +266,7 @@ def test_no_initial_state_and_no_restore_state(hass):
|
||||||
assert float(state.state) == 0
|
assert float(state.state) == 0
|
||||||
|
|
||||||
|
|
||||||
async def test_input_number_context(hass):
|
async def test_input_number_context(hass, hass_admin_user):
|
||||||
"""Test that input_number context works."""
|
"""Test that input_number context works."""
|
||||||
assert await async_setup_component(hass, 'input_number', {
|
assert await async_setup_component(hass, 'input_number', {
|
||||||
'input_number': {
|
'input_number': {
|
||||||
|
@ -282,9 +282,9 @@ async def test_input_number_context(hass):
|
||||||
|
|
||||||
await hass.services.async_call('input_number', 'increment', {
|
await hass.services.async_call('input_number', 'increment', {
|
||||||
'entity_id': state.entity_id,
|
'entity_id': state.entity_id,
|
||||||
}, True, Context(user_id='abcd'))
|
}, True, Context(user_id=hass_admin_user.id))
|
||||||
|
|
||||||
state2 = hass.states.get('input_number.b1')
|
state2 = hass.states.get('input_number.b1')
|
||||||
assert state2 is not None
|
assert state2 is not None
|
||||||
assert state.state != state2.state
|
assert state.state != state2.state
|
||||||
assert state2.context.user_id == 'abcd'
|
assert state2.context.user_id == hass_admin_user.id
|
||||||
|
|
|
@ -302,7 +302,7 @@ def test_initial_state_overrules_restore_state(hass):
|
||||||
assert state.state == 'middle option'
|
assert state.state == 'middle option'
|
||||||
|
|
||||||
|
|
||||||
async def test_input_select_context(hass):
|
async def test_input_select_context(hass, hass_admin_user):
|
||||||
"""Test that input_select context works."""
|
"""Test that input_select context works."""
|
||||||
assert await async_setup_component(hass, 'input_select', {
|
assert await async_setup_component(hass, 'input_select', {
|
||||||
'input_select': {
|
'input_select': {
|
||||||
|
@ -321,9 +321,9 @@ async def test_input_select_context(hass):
|
||||||
|
|
||||||
await hass.services.async_call('input_select', 'select_next', {
|
await hass.services.async_call('input_select', 'select_next', {
|
||||||
'entity_id': state.entity_id,
|
'entity_id': state.entity_id,
|
||||||
}, True, Context(user_id='abcd'))
|
}, True, Context(user_id=hass_admin_user.id))
|
||||||
|
|
||||||
state2 = hass.states.get('input_select.s1')
|
state2 = hass.states.get('input_select.s1')
|
||||||
assert state2 is not None
|
assert state2 is not None
|
||||||
assert state.state != state2.state
|
assert state.state != state2.state
|
||||||
assert state2.context.user_id == 'abcd'
|
assert state2.context.user_id == hass_admin_user.id
|
||||||
|
|
|
@ -184,7 +184,7 @@ def test_no_initial_state_and_no_restore_state(hass):
|
||||||
assert str(state.state) == 'unknown'
|
assert str(state.state) == 'unknown'
|
||||||
|
|
||||||
|
|
||||||
async def test_input_text_context(hass):
|
async def test_input_text_context(hass, hass_admin_user):
|
||||||
"""Test that input_text context works."""
|
"""Test that input_text context works."""
|
||||||
assert await async_setup_component(hass, 'input_text', {
|
assert await async_setup_component(hass, 'input_text', {
|
||||||
'input_text': {
|
'input_text': {
|
||||||
|
@ -200,9 +200,9 @@ async def test_input_text_context(hass):
|
||||||
await hass.services.async_call('input_text', 'set_value', {
|
await hass.services.async_call('input_text', 'set_value', {
|
||||||
'entity_id': state.entity_id,
|
'entity_id': state.entity_id,
|
||||||
'value': 'new_value',
|
'value': 'new_value',
|
||||||
}, True, Context(user_id='abcd'))
|
}, True, Context(user_id=hass_admin_user.id))
|
||||||
|
|
||||||
state2 = hass.states.get('input_text.t1')
|
state2 = hass.states.get('input_text.t1')
|
||||||
assert state2 is not None
|
assert state2 is not None
|
||||||
assert state.state != state2.state
|
assert state.state != state2.state
|
||||||
assert state2.context.user_id == 'abcd'
|
assert state2.context.user_id == hass_admin_user.id
|
||||||
|
|
|
@ -1,18 +1,49 @@
|
||||||
"""Test service helpers."""
|
"""Test service helpers."""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections import OrderedDict
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
# To prevent circular import when running just this file
|
# To prevent circular import when running just this file
|
||||||
import homeassistant.components # noqa
|
import homeassistant.components # noqa
|
||||||
from homeassistant import core as ha, loader
|
from homeassistant import core as ha, loader, exceptions
|
||||||
from homeassistant.const import STATE_ON, STATE_OFF, ATTR_ENTITY_ID
|
from homeassistant.const import STATE_ON, STATE_OFF, ATTR_ENTITY_ID
|
||||||
from homeassistant.helpers import service, template
|
from homeassistant.helpers import service, template
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
|
from homeassistant.auth.permissions import PolicyPermissions
|
||||||
|
|
||||||
from tests.common import get_test_home_assistant, mock_service
|
from tests.common import get_test_home_assistant, mock_service, mock_coro
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_service_platform_call():
|
||||||
|
"""Mock service platform call."""
|
||||||
|
with patch('homeassistant.helpers.service._handle_service_platform_call',
|
||||||
|
side_effect=lambda *args: mock_coro()) as mock_call:
|
||||||
|
yield mock_call
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_entities():
|
||||||
|
"""Return mock entities in an ordered dict."""
|
||||||
|
kitchen = Mock(
|
||||||
|
entity_id='light.kitchen',
|
||||||
|
available=True,
|
||||||
|
should_poll=False,
|
||||||
|
)
|
||||||
|
living_room = Mock(
|
||||||
|
entity_id='light.living_room',
|
||||||
|
available=True,
|
||||||
|
should_poll=False,
|
||||||
|
)
|
||||||
|
entities = OrderedDict()
|
||||||
|
entities[kitchen.entity_id] = kitchen
|
||||||
|
entities[living_room.entity_id] = living_room
|
||||||
|
return entities
|
||||||
|
|
||||||
|
|
||||||
class TestServiceHelpers(unittest.TestCase):
|
class TestServiceHelpers(unittest.TestCase):
|
||||||
|
@ -179,3 +210,99 @@ def test_async_get_all_descriptions(hass):
|
||||||
|
|
||||||
assert 'description' in descriptions[logger.DOMAIN]['set_level']
|
assert 'description' in descriptions[logger.DOMAIN]['set_level']
|
||||||
assert 'fields' in descriptions[logger.DOMAIN]['set_level']
|
assert 'fields' in descriptions[logger.DOMAIN]['set_level']
|
||||||
|
|
||||||
|
|
||||||
|
async def test_call_context_user_not_exist(hass):
|
||||||
|
"""Check we don't allow deleted users to do things."""
|
||||||
|
with pytest.raises(exceptions.UnknownUser) as err:
|
||||||
|
await service.entity_service_call(hass, [], Mock(), ha.ServiceCall(
|
||||||
|
'test_domain', 'test_service', context=ha.Context(
|
||||||
|
user_id='non-existing')))
|
||||||
|
|
||||||
|
assert err.value.context.user_id == 'non-existing'
|
||||||
|
|
||||||
|
|
||||||
|
async def test_call_context_target_all(hass, mock_service_platform_call,
|
||||||
|
mock_entities):
|
||||||
|
"""Check we only target allowed entities if targetting all."""
|
||||||
|
with patch('homeassistant.auth.AuthManager.async_get_user',
|
||||||
|
return_value=mock_coro(Mock(permissions=PolicyPermissions({
|
||||||
|
'entities': {
|
||||||
|
'entity_ids': {
|
||||||
|
'light.kitchen': True
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})))):
|
||||||
|
await service.entity_service_call(hass, [
|
||||||
|
Mock(entities=mock_entities)
|
||||||
|
], Mock(), ha.ServiceCall('test_domain', 'test_service',
|
||||||
|
context=ha.Context(user_id='mock-id')))
|
||||||
|
|
||||||
|
assert len(mock_service_platform_call.mock_calls) == 1
|
||||||
|
entities = mock_service_platform_call.mock_calls[0][1][2]
|
||||||
|
assert entities == [mock_entities['light.kitchen']]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_call_context_target_specific(hass, mock_service_platform_call,
|
||||||
|
mock_entities):
|
||||||
|
"""Check targeting specific entities."""
|
||||||
|
with patch('homeassistant.auth.AuthManager.async_get_user',
|
||||||
|
return_value=mock_coro(Mock(permissions=PolicyPermissions({
|
||||||
|
'entities': {
|
||||||
|
'entity_ids': {
|
||||||
|
'light.kitchen': True
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})))):
|
||||||
|
await service.entity_service_call(hass, [
|
||||||
|
Mock(entities=mock_entities)
|
||||||
|
], Mock(), ha.ServiceCall('test_domain', 'test_service', {
|
||||||
|
'entity_id': 'light.kitchen'
|
||||||
|
}, context=ha.Context(user_id='mock-id')))
|
||||||
|
|
||||||
|
assert len(mock_service_platform_call.mock_calls) == 1
|
||||||
|
entities = mock_service_platform_call.mock_calls[0][1][2]
|
||||||
|
assert entities == [mock_entities['light.kitchen']]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_call_context_target_specific_no_auth(
|
||||||
|
hass, mock_service_platform_call, mock_entities):
|
||||||
|
"""Check targeting specific entities without auth."""
|
||||||
|
with pytest.raises(exceptions.Unauthorized) as err:
|
||||||
|
with patch('homeassistant.auth.AuthManager.async_get_user',
|
||||||
|
return_value=mock_coro(Mock(
|
||||||
|
permissions=PolicyPermissions({})))):
|
||||||
|
await service.entity_service_call(hass, [
|
||||||
|
Mock(entities=mock_entities)
|
||||||
|
], Mock(), ha.ServiceCall('test_domain', 'test_service', {
|
||||||
|
'entity_id': 'light.kitchen'
|
||||||
|
}, context=ha.Context(user_id='mock-id')))
|
||||||
|
|
||||||
|
assert err.value.context.user_id == 'mock-id'
|
||||||
|
assert err.value.entity_id == 'light.kitchen'
|
||||||
|
|
||||||
|
|
||||||
|
async def test_call_no_context_target_all(hass, mock_service_platform_call,
|
||||||
|
mock_entities):
|
||||||
|
"""Check we target all if no user context given."""
|
||||||
|
await service.entity_service_call(hass, [
|
||||||
|
Mock(entities=mock_entities)
|
||||||
|
], Mock(), ha.ServiceCall('test_domain', 'test_service'))
|
||||||
|
|
||||||
|
assert len(mock_service_platform_call.mock_calls) == 1
|
||||||
|
entities = mock_service_platform_call.mock_calls[0][1][2]
|
||||||
|
assert entities == list(mock_entities.values())
|
||||||
|
|
||||||
|
|
||||||
|
async def test_call_no_context_target_specific(
|
||||||
|
hass, mock_service_platform_call, mock_entities):
|
||||||
|
"""Check we can target specified entities."""
|
||||||
|
await service.entity_service_call(hass, [
|
||||||
|
Mock(entities=mock_entities)
|
||||||
|
], Mock(), ha.ServiceCall('test_domain', 'test_service', {
|
||||||
|
'entity_id': ['light.kitchen', 'light.non-existing']
|
||||||
|
}))
|
||||||
|
|
||||||
|
assert len(mock_service_platform_call.mock_calls) == 1
|
||||||
|
entities = mock_service_platform_call.mock_calls[0][1][2]
|
||||||
|
assert entities == [mock_entities['light.kitchen']]
|
||||||
|
|
Loading…
Reference in New Issue