Add foundation for integration services (#30813)

* Add foundation for integration services

* Fix tests

* Remove async_get_platform

* Migrate Sonos partially to EntityPlatform.async_register_entity_service

* Tweaks

* Move other Sonos services to media player domain

* Move other Sonos services to media player domain

* Address comments

* Remove lock

* Fix typos

* Use make_entity_service_schema

* Add area extraction to async_extract_entities

Co-authored-by: Anders Melchiorsen <amelchio@nogoto.net>
pull/30993/head
Paulus Schoutsen 2020-01-19 17:55:18 -08:00 committed by GitHub
parent f20b3515f2
commit 0c3ffbe282
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 339 additions and 289 deletions

View File

@ -1,13 +1,10 @@
"""Support to embed Sonos."""
import asyncio
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.components.media_player import DOMAIN as MP_DOMAIN
from homeassistant.const import ATTR_ENTITY_ID, ATTR_TIME, CONF_HOSTS
from homeassistant.const import CONF_HOSTS
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.dispatcher import async_dispatcher_send
from .const import DOMAIN
@ -33,91 +30,12 @@ CONFIG_SCHEMA = vol.Schema(
extra=vol.ALLOW_EXTRA,
)
SERVICE_JOIN = "join"
SERVICE_UNJOIN = "unjoin"
SERVICE_SNAPSHOT = "snapshot"
SERVICE_RESTORE = "restore"
SERVICE_SET_TIMER = "set_sleep_timer"
SERVICE_CLEAR_TIMER = "clear_sleep_timer"
SERVICE_UPDATE_ALARM = "update_alarm"
SERVICE_SET_OPTION = "set_option"
SERVICE_PLAY_QUEUE = "play_queue"
ATTR_SLEEP_TIME = "sleep_time"
ATTR_ALARM_ID = "alarm_id"
ATTR_VOLUME = "volume"
ATTR_ENABLED = "enabled"
ATTR_INCLUDE_LINKED_ZONES = "include_linked_zones"
ATTR_MASTER = "master"
ATTR_WITH_GROUP = "with_group"
ATTR_NIGHT_SOUND = "night_sound"
ATTR_SPEECH_ENHANCE = "speech_enhance"
ATTR_QUEUE_POSITION = "queue_position"
SONOS_JOIN_SCHEMA = vol.Schema(
{
vol.Required(ATTR_MASTER): cv.entity_id,
vol.Optional(ATTR_ENTITY_ID): cv.comp_entity_ids,
}
)
SONOS_UNJOIN_SCHEMA = vol.Schema({vol.Optional(ATTR_ENTITY_ID): cv.comp_entity_ids})
SONOS_STATES_SCHEMA = vol.Schema(
{
vol.Optional(ATTR_ENTITY_ID): cv.comp_entity_ids,
vol.Optional(ATTR_WITH_GROUP, default=True): cv.boolean,
}
)
SONOS_SET_TIMER_SCHEMA = vol.Schema(
{
vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids,
vol.Required(ATTR_SLEEP_TIME): vol.All(
vol.Coerce(int), vol.Range(min=0, max=86399)
),
}
)
SONOS_CLEAR_TIMER_SCHEMA = vol.Schema(
{vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids}
)
SONOS_UPDATE_ALARM_SCHEMA = vol.Schema(
{
vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids,
vol.Required(ATTR_ALARM_ID): cv.positive_int,
vol.Optional(ATTR_TIME): cv.time,
vol.Optional(ATTR_VOLUME): cv.small_float,
vol.Optional(ATTR_ENABLED): cv.boolean,
vol.Optional(ATTR_INCLUDE_LINKED_ZONES): cv.boolean,
}
)
SONOS_SET_OPTION_SCHEMA = vol.Schema(
{
vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids,
vol.Optional(ATTR_NIGHT_SOUND): cv.boolean,
vol.Optional(ATTR_SPEECH_ENHANCE): cv.boolean,
}
)
SONOS_PLAY_QUEUE_SCHEMA = vol.Schema(
{
vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids,
vol.Optional(ATTR_QUEUE_POSITION, default=0): cv.positive_int,
}
)
DATA_SERVICE_EVENT = "sonos_service_idle"
async def async_setup(hass, config):
"""Set up the Sonos component."""
conf = config.get(DOMAIN)
hass.data[DOMAIN] = conf or {}
hass.data[DATA_SERVICE_EVENT] = asyncio.Event()
if conf is not None:
hass.async_create_task(
@ -126,48 +44,6 @@ async def async_setup(hass, config):
)
)
async def service_handle(service):
"""Dispatch a service call."""
hass.data[DATA_SERVICE_EVENT].clear()
async_dispatcher_send(hass, DOMAIN, service.service, service.data)
await hass.data[DATA_SERVICE_EVENT].wait()
hass.services.async_register(
DOMAIN, SERVICE_JOIN, service_handle, schema=SONOS_JOIN_SCHEMA
)
hass.services.async_register(
DOMAIN, SERVICE_UNJOIN, service_handle, schema=SONOS_UNJOIN_SCHEMA
)
hass.services.async_register(
DOMAIN, SERVICE_SNAPSHOT, service_handle, schema=SONOS_STATES_SCHEMA
)
hass.services.async_register(
DOMAIN, SERVICE_RESTORE, service_handle, schema=SONOS_STATES_SCHEMA
)
hass.services.async_register(
DOMAIN, SERVICE_SET_TIMER, service_handle, schema=SONOS_SET_TIMER_SCHEMA
)
hass.services.async_register(
DOMAIN, SERVICE_CLEAR_TIMER, service_handle, schema=SONOS_CLEAR_TIMER_SCHEMA
)
hass.services.async_register(
DOMAIN, SERVICE_UPDATE_ALARM, service_handle, schema=SONOS_UPDATE_ALARM_SCHEMA
)
hass.services.async_register(
DOMAIN, SERVICE_SET_OPTION, service_handle, schema=SONOS_SET_OPTION_SCHEMA
)
hass.services.async_register(
DOMAIN, SERVICE_PLAY_QUEUE, service_handle, schema=SONOS_PLAY_QUEUE_SCHEMA
)
return True

View File

@ -11,6 +11,7 @@ import pysonos
from pysonos import alarms
from pysonos.exceptions import SoCoException, SoCoUPnPException
import pysonos.snapshot
import voluptuous as vol
from homeassistant.components.media_player import MediaPlayerDevice
from homeassistant.components.media_player.const import (
@ -30,42 +31,16 @@ from homeassistant.components.media_player.const import (
SUPPORT_VOLUME_MUTE,
SUPPORT_VOLUME_SET,
)
from homeassistant.const import (
ENTITY_MATCH_ALL,
STATE_IDLE,
STATE_PAUSED,
STATE_PLAYING,
)
from homeassistant.core import callback
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.const import ATTR_TIME, STATE_IDLE, STATE_PAUSED, STATE_PLAYING
from homeassistant.core import ServiceCall, callback
from homeassistant.helpers import config_validation as cv, entity_platform, service
from homeassistant.util.dt import utcnow
from . import (
ATTR_ALARM_ID,
ATTR_ENABLED,
ATTR_INCLUDE_LINKED_ZONES,
ATTR_MASTER,
ATTR_NIGHT_SOUND,
ATTR_QUEUE_POSITION,
ATTR_SLEEP_TIME,
ATTR_SPEECH_ENHANCE,
ATTR_TIME,
ATTR_VOLUME,
ATTR_WITH_GROUP,
CONF_ADVERTISE_ADDR,
CONF_HOSTS,
CONF_INTERFACE_ADDR,
DATA_SERVICE_EVENT,
DOMAIN as SONOS_DOMAIN,
SERVICE_CLEAR_TIMER,
SERVICE_JOIN,
SERVICE_PLAY_QUEUE,
SERVICE_RESTORE,
SERVICE_SET_OPTION,
SERVICE_SET_TIMER,
SERVICE_SNAPSHOT,
SERVICE_UNJOIN,
SERVICE_UPDATE_ALARM,
)
_LOGGER = logging.getLogger(__name__)
@ -97,6 +72,27 @@ ATTR_SONOS_GROUP = "sonos_group"
UPNP_ERRORS_TO_IGNORE = ["701", "711", "712"]
SERVICE_JOIN = "join"
SERVICE_UNJOIN = "unjoin"
SERVICE_SNAPSHOT = "snapshot"
SERVICE_RESTORE = "restore"
SERVICE_SET_TIMER = "set_sleep_timer"
SERVICE_CLEAR_TIMER = "clear_sleep_timer"
SERVICE_UPDATE_ALARM = "update_alarm"
SERVICE_SET_OPTION = "set_option"
SERVICE_PLAY_QUEUE = "play_queue"
ATTR_SLEEP_TIME = "sleep_time"
ATTR_ALARM_ID = "alarm_id"
ATTR_VOLUME = "volume"
ATTR_ENABLED = "enabled"
ATTR_INCLUDE_LINKED_ZONES = "include_linked_zones"
ATTR_MASTER = "master"
ATTR_WITH_GROUP = "with_group"
ATTR_NIGHT_SOUND = "night_sound"
ATTR_SPEECH_ENHANCE = "speech_enhance"
ATTR_QUEUE_POSITION = "queue_position"
class SonosData:
"""Storage class for platform global data."""
@ -176,46 +172,101 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
_LOGGER.debug("Adding discovery job")
hass.async_add_executor_job(_discovery)
async def async_service_handle(service, data):
platform = entity_platform.current_platform.get()
async def async_service_handle(service_call: ServiceCall):
"""Handle dispatched services."""
entity_ids = data.get("entity_id")
entities = hass.data[DATA_SONOS].entities
if entity_ids and entity_ids != ENTITY_MATCH_ALL:
entities = [e for e in entities if e.entity_id in entity_ids]
entities = await platform.async_extract_from_service(service_call)
if service == SERVICE_JOIN:
master = [
e
for e in hass.data[DATA_SONOS].entities
if e.entity_id == data[ATTR_MASTER]
]
if not entities:
return
if service_call.service == SERVICE_JOIN:
master = platform.entities.get(service_call.data[ATTR_MASTER])
if master:
await SonosEntity.join_multi(hass, master[0], entities)
elif service == SERVICE_UNJOIN:
await SonosEntity.join_multi(hass, master, entities)
else:
_LOGGER.error(
"Invalid master specified for join service: %s",
service_call.data[ATTR_MASTER],
)
elif service_call.service == SERVICE_UNJOIN:
await SonosEntity.unjoin_multi(hass, entities)
elif service == SERVICE_SNAPSHOT:
await SonosEntity.snapshot_multi(hass, entities, data[ATTR_WITH_GROUP])
elif service == SERVICE_RESTORE:
await SonosEntity.restore_multi(hass, entities, data[ATTR_WITH_GROUP])
else:
for entity in entities:
if service == SERVICE_SET_TIMER:
call = entity.set_sleep_timer
elif service == SERVICE_CLEAR_TIMER:
call = entity.clear_sleep_timer
elif service == SERVICE_UPDATE_ALARM:
call = entity.set_alarm
elif service == SERVICE_SET_OPTION:
call = entity.set_option
elif service == SERVICE_PLAY_QUEUE:
call = entity.play_queue
elif service_call.service == SERVICE_SNAPSHOT:
await SonosEntity.snapshot_multi(
hass, entities, service_call.data[ATTR_WITH_GROUP]
)
elif service_call.service == SERVICE_RESTORE:
await SonosEntity.restore_multi(
hass, entities, service_call.data[ATTR_WITH_GROUP]
)
hass.async_add_executor_job(call, data)
service.async_register_admin_service(
hass,
SONOS_DOMAIN,
SERVICE_JOIN,
async_service_handle,
cv.make_entity_service_schema({vol.Required(ATTR_MASTER): cv.entity_id}),
)
# We are ready for the next service call
hass.data[DATA_SERVICE_EVENT].set()
service.async_register_admin_service(
hass,
SONOS_DOMAIN,
SERVICE_UNJOIN,
async_service_handle,
cv.make_entity_service_schema({}),
)
async_dispatcher_connect(hass, SONOS_DOMAIN, async_service_handle)
join_unjoin_schema = cv.make_entity_service_schema(
{vol.Optional(ATTR_WITH_GROUP, default=True): cv.boolean}
)
service.async_register_admin_service(
hass, SONOS_DOMAIN, SERVICE_SNAPSHOT, async_service_handle, join_unjoin_schema
)
service.async_register_admin_service(
hass, SONOS_DOMAIN, SERVICE_RESTORE, async_service_handle, join_unjoin_schema
)
platform.async_register_entity_service(
SERVICE_SET_TIMER,
{
vol.Required(ATTR_SLEEP_TIME): vol.All(
vol.Coerce(int), vol.Range(min=0, max=86399)
)
},
"set_sleep_timer",
)
platform.async_register_entity_service(SERVICE_CLEAR_TIMER, {}, "clear_sleep_timer")
platform.async_register_entity_service(
SERVICE_UPDATE_ALARM,
{
vol.Required(ATTR_ALARM_ID): cv.positive_int,
vol.Optional(ATTR_TIME): cv.time,
vol.Optional(ATTR_VOLUME): cv.small_float,
vol.Optional(ATTR_ENABLED): cv.boolean,
vol.Optional(ATTR_INCLUDE_LINKED_ZONES): cv.boolean,
},
"set_alarm",
)
platform.async_register_entity_service(
SERVICE_SET_OPTION,
{
vol.Optional(ATTR_NIGHT_SOUND): cv.boolean,
vol.Optional(ATTR_SPEECH_ENHANCE): cv.boolean,
},
"set_option",
)
platform.async_register_entity_service(
SERVICE_PLAY_QUEUE,
{vol.Optional(ATTR_QUEUE_POSITION): cv.positive_int},
"play_queue",
)
class _ProcessSonosEventQueue:
@ -480,10 +531,10 @@ class SonosEntity(MediaPlayerDevice):
player = self.soco
def subscribe(service, action):
def subscribe(sonos_service, action):
"""Add a subscription to a pysonos service."""
queue = _ProcessSonosEventQueue(action)
sub = service.subscribe(auto_renew=True, event_queue=queue)
sub = sonos_service.subscribe(auto_renew=True, event_queue=queue)
self._subscriptions.append(sub)
subscribe(player.avTransport, self.update_media)
@ -1147,52 +1198,53 @@ class SonosEntity(MediaPlayerDevice):
@soco_error()
@soco_coordinator
def set_sleep_timer(self, data):
def set_sleep_timer(self, sleep_time):
"""Set the timer on the player."""
self.soco.set_sleep_timer(data[ATTR_SLEEP_TIME])
self.soco.set_sleep_timer(sleep_time)
@soco_error()
@soco_coordinator
def clear_sleep_timer(self, data):
def clear_sleep_timer(self):
"""Clear the timer on the player."""
self.soco.set_sleep_timer(None)
@soco_error()
@soco_coordinator
def set_alarm(self, data):
def set_alarm(
self, alarm_id, time=None, volume=None, enabled=None, include_linked_zones=None
):
"""Set the alarm clock on the player."""
alarm = None
for one_alarm in alarms.get_alarms(self.soco):
# pylint: disable=protected-access
if one_alarm._alarm_id == str(data[ATTR_ALARM_ID]):
if one_alarm._alarm_id == str(alarm_id):
alarm = one_alarm
if alarm is None:
_LOGGER.warning("did not find alarm with id %s", data[ATTR_ALARM_ID])
_LOGGER.warning("did not find alarm with id %s", alarm_id)
return
if ATTR_TIME in data:
alarm.start_time = data[ATTR_TIME]
if ATTR_VOLUME in data:
alarm.volume = int(data[ATTR_VOLUME] * 100)
if ATTR_ENABLED in data:
alarm.enabled = data[ATTR_ENABLED]
if ATTR_INCLUDE_LINKED_ZONES in data:
alarm.include_linked_zones = data[ATTR_INCLUDE_LINKED_ZONES]
if time is not None:
alarm.start_time = time
if volume is not None:
alarm.volume = int(volume * 100)
if enabled is not None:
alarm.enabled = enabled
if include_linked_zones is not None:
alarm.include_linked_zones = include_linked_zones
alarm.save()
@soco_error()
def set_option(self, data):
def set_option(self, night_sound=None, speech_enhance=None):
"""Modify playback options."""
if ATTR_NIGHT_SOUND in data and self._night_sound is not None:
self.soco.night_mode = data[ATTR_NIGHT_SOUND]
if night_sound is not None and self._night_sound is not None:
self.soco.night_mode = night_sound
if ATTR_SPEECH_ENHANCE in data and self._speech_enhance is not None:
self.soco.dialog_mode = data[ATTR_SPEECH_ENHANCE]
if speech_enhance is not None and self._speech_enhance is not None:
self.soco.dialog_mode = speech_enhance
@soco_error()
def play_queue(self, data):
def play_queue(self, queue_position=0):
"""Start playing the queue."""
self.soco.play_from_queue(data[ATTR_QUEUE_POSITION])
self.soco.play_from_queue(queue_position)
@property
def device_state_attributes(self):

View File

@ -6,17 +6,15 @@ import logging
from homeassistant import config as conf_util
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import (
ATTR_ENTITY_ID,
CONF_ENTITY_NAMESPACE,
CONF_SCAN_INTERVAL,
ENTITY_MATCH_ALL,
)
from homeassistant.const import CONF_ENTITY_NAMESPACE, CONF_SCAN_INTERVAL
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_per_platform, discovery
from homeassistant.helpers.config_validation import make_entity_service_schema
from homeassistant.helpers.service import async_extract_entity_ids
from homeassistant.helpers import (
config_per_platform,
config_validation as cv,
discovery,
service,
)
from homeassistant.loader import async_get_integration, bind_hass
from homeassistant.setup import async_prepare_setup_platform
@ -166,39 +164,27 @@ class EntityComponent:
await platform.async_reset()
return True
async def async_extract_from_service(self, service, expand_group=True):
async def async_extract_from_service(self, service_call, expand_group=True):
"""Extract all known and available entities from a service call.
Will return an empty list if entities specified but unknown.
This method must be run in the event loop.
"""
data_ent_id = service.data.get(ATTR_ENTITY_ID)
if data_ent_id is None:
return []
if data_ent_id == ENTITY_MATCH_ALL:
return [entity for entity in self.entities if entity.available]
entity_ids = await async_extract_entity_ids(self.hass, service, expand_group)
return [
entity
for entity in self.entities
if entity.available and entity.entity_id in entity_ids
]
return await service.async_extract_entities(
self.hass, self.entities, service_call, expand_group
)
@callback
def async_register_entity_service(self, name, schema, func, required_features=None):
"""Register an entity service."""
if isinstance(schema, dict):
schema = make_entity_service_schema(schema)
schema = cv.make_entity_service_schema(schema)
async def handle_service(call):
"""Handle the service."""
service_name = f"{self.domain}.{name}"
await self.hass.helpers.service.entity_service_call(
self._platforms.values(), func, call, service_name, required_features
self._platforms.values(), func, call, required_features
)
self.hass.services.async_register(self.domain, name, handle_service, schema)

View File

@ -7,6 +7,7 @@ from typing import Optional
from homeassistant.const import DEVICE_DEFAULT_NAME
from homeassistant.core import callback, split_entity_id, valid_entity_id
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
from homeassistant.helpers import config_validation as cv, service
from homeassistant.util.async_ import run_callback_threadsafe
from .entity_registry import DISABLED_INTEGRATION
@ -194,7 +195,11 @@ class EntityPlatform:
)
return False
except Exception: # pylint: disable=broad-except
logger.exception("Error while setting up platform %s", self.platform_name)
logger.exception(
"Error while setting up %s platform for %s",
self.platform_name,
self.domain,
)
return False
finally:
warn_task.cancel()
@ -449,6 +454,33 @@ class EntityPlatform:
self._async_unsub_polling()
self._async_unsub_polling = None
async def async_extract_from_service(self, service_call, expand_group=True):
"""Extract all known and available entities from a service call.
Will return an empty list if entities specified but unknown.
This method must be run in the event loop.
"""
return await service.async_extract_entities(
self.hass, self.entities.values(), service_call, expand_group
)
@callback
def async_register_entity_service(self, name, schema, func, required_features=None):
"""Register an entity service."""
if isinstance(schema, dict):
schema = cv.make_entity_service_schema(schema)
async def handle_service(call):
"""Handle the service."""
await service.entity_service_call(
self.hass, [self], func, call, required_features
)
self.hass.services.async_register(
self.platform_name, name, handle_service, schema
)
async def _update_entity_states(self, now: datetime) -> None:
"""Update the states of all the polling entities.

View File

@ -108,13 +108,31 @@ def extract_entity_ids(hass, service_call, expand_group=True):
).result()
@bind_hass
async def async_extract_entities(hass, entities, service_call, expand_group=True):
"""Extract a list of entity objects from a service call.
Will convert group entity ids to the entity ids it represents.
"""
data_ent_id = service_call.data.get(ATTR_ENTITY_ID)
if data_ent_id == ENTITY_MATCH_ALL:
return [entity for entity in entities if entity.available]
entity_ids = await async_extract_entity_ids(hass, service_call, expand_group)
return [
entity
for entity in entities
if entity.available and entity.entity_id in entity_ids
]
@bind_hass
async def async_extract_entity_ids(hass, service_call, expand_group=True):
"""Extract a list of entity ids from a service call.
Will convert group entity ids to the entity ids it represents.
Async friendly.
"""
entity_ids = service_call.data.get(ATTR_ENTITY_ID)
area_ids = service_call.data.get(ATTR_AREA_ID)
@ -244,9 +262,7 @@ def async_set_service_schema(hass, domain, service, schema):
@bind_hass
async def entity_service_call(
hass, platforms, func, call, service_name="", required_features=None
):
async def entity_service_call(hass, platforms, func, call, required_features=None):
"""Handle an entity service call.
Calls all platforms simultaneously.

View File

@ -23,6 +23,7 @@ import homeassistant.helpers.config_validation as cv
from homeassistant.setup import async_setup_component
from tests.common import (
MockEntity,
get_test_home_assistant,
mock_coro,
mock_device_registry,
@ -64,6 +65,54 @@ def mock_entities():
return entities
@pytest.fixture
def area_mock(hass):
"""Mock including area info."""
hass.states.async_set("light.Bowl", STATE_ON)
hass.states.async_set("light.Ceiling", STATE_OFF)
hass.states.async_set("light.Kitchen", STATE_OFF)
device_in_area = dev_reg.DeviceEntry(area_id="test-area")
device_no_area = dev_reg.DeviceEntry()
device_diff_area = dev_reg.DeviceEntry(area_id="diff-area")
mock_device_registry(
hass,
{
device_in_area.id: device_in_area,
device_no_area.id: device_no_area,
device_diff_area.id: device_diff_area,
},
)
entity_in_area = ent_reg.RegistryEntry(
entity_id="light.in_area",
unique_id="in-area-id",
platform="test",
device_id=device_in_area.id,
)
entity_no_area = ent_reg.RegistryEntry(
entity_id="light.no_area",
unique_id="no-area-id",
platform="test",
device_id=device_no_area.id,
)
entity_diff_area = ent_reg.RegistryEntry(
entity_id="light.diff_area",
unique_id="diff-area-id",
platform="test",
device_id=device_diff_area.id,
)
mock_registry(
hass,
{
entity_in_area.entity_id: entity_in_area,
entity_no_area.entity_id: entity_no_area,
entity_diff_area.entity_id: entity_diff_area,
},
)
class TestServiceHelpers(unittest.TestCase):
"""Test the Home Assistant service helpers."""
@ -204,52 +253,8 @@ async def test_extract_entity_ids(hass):
)
async def test_extract_entity_ids_from_area(hass):
async def test_extract_entity_ids_from_area(hass, area_mock):
"""Test extract_entity_ids method with areas."""
hass.states.async_set("light.Bowl", STATE_ON)
hass.states.async_set("light.Ceiling", STATE_OFF)
hass.states.async_set("light.Kitchen", STATE_OFF)
device_in_area = dev_reg.DeviceEntry(area_id="test-area")
device_no_area = dev_reg.DeviceEntry()
device_diff_area = dev_reg.DeviceEntry(area_id="diff-area")
mock_device_registry(
hass,
{
device_in_area.id: device_in_area,
device_no_area.id: device_no_area,
device_diff_area.id: device_diff_area,
},
)
entity_in_area = ent_reg.RegistryEntry(
entity_id="light.in_area",
unique_id="in-area-id",
platform="test",
device_id=device_in_area.id,
)
entity_no_area = ent_reg.RegistryEntry(
entity_id="light.no_area",
unique_id="no-area-id",
platform="test",
device_id=device_no_area.id,
)
entity_diff_area = ent_reg.RegistryEntry(
entity_id="light.diff_area",
unique_id="diff-area-id",
platform="test",
device_id=device_diff_area.id,
)
mock_registry(
hass,
{
entity_in_area.entity_id: entity_in_area,
entity_no_area.entity_id: entity_no_area,
entity_diff_area.entity_id: entity_diff_area,
},
)
call = ha.ServiceCall("light", "turn_on", {"area_id": "test-area"})
assert {"light.in_area"} == await service.async_extract_entity_ids(hass, call)
@ -678,3 +683,86 @@ async def test_domain_control_no_user(hass, mock_entities):
)
assert len(calls) == 1
async def test_extract_from_service_available_device(hass):
"""Test the extraction of entity from service and device is available."""
entities = [
MockEntity(name="test_1", entity_id="test_domain.test_1"),
MockEntity(name="test_2", entity_id="test_domain.test_2", available=False),
MockEntity(name="test_3", entity_id="test_domain.test_3"),
MockEntity(name="test_4", entity_id="test_domain.test_4", available=False),
]
call_1 = ha.ServiceCall("test", "service", data={"entity_id": ENTITY_MATCH_ALL})
assert ["test_domain.test_1", "test_domain.test_3"] == [
ent.entity_id
for ent in (await service.async_extract_entities(hass, entities, call_1))
]
call_2 = ha.ServiceCall(
"test",
"service",
data={"entity_id": ["test_domain.test_3", "test_domain.test_4"]},
)
assert ["test_domain.test_3"] == [
ent.entity_id
for ent in (await service.async_extract_entities(hass, entities, call_2))
]
async def test_extract_from_service_empty_if_no_entity_id(hass):
"""Test the extraction from service without specifying entity."""
entities = [
MockEntity(name="test_1", entity_id="test_domain.test_1"),
MockEntity(name="test_2", entity_id="test_domain.test_2"),
]
call = ha.ServiceCall("test", "service")
assert [] == [
ent.entity_id
for ent in (await service.async_extract_entities(hass, entities, call))
]
async def test_extract_from_service_filter_out_non_existing_entities(hass):
"""Test the extraction of non existing entities from service."""
entities = [
MockEntity(name="test_1", entity_id="test_domain.test_1"),
MockEntity(name="test_2", entity_id="test_domain.test_2"),
]
call = ha.ServiceCall(
"test",
"service",
{"entity_id": ["test_domain.test_2", "test_domain.non_exist"]},
)
assert ["test_domain.test_2"] == [
ent.entity_id
for ent in (await service.async_extract_entities(hass, entities, call))
]
async def test_extract_from_service_area_id(hass, area_mock):
"""Test the extraction using area ID as reference."""
entities = [
MockEntity(name="in_area", entity_id="light.in_area"),
MockEntity(name="no_area", entity_id="light.no_area"),
MockEntity(name="diff_area", entity_id="light.diff_area"),
]
call = ha.ServiceCall("light", "turn_on", {"area_id": "test-area"})
extracted = await service.async_extract_entities(hass, entities, call)
assert len(extracted) == 1
assert extracted[0].entity_id == "light.in_area"
call = ha.ServiceCall("light", "turn_on", {"area_id": ["test-area", "diff-area"]})
extracted = await service.async_extract_entities(hass, entities, call)
assert len(extracted) == 2
assert sorted(ent.entity_id for ent in extracted) == [
"light.diff_area",
"light.in_area",
]