From 0c3ffbe282fe86ada44ea09b87b90a1f258a562a Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sun, 19 Jan 2020 17:55:18 -0800 Subject: [PATCH] Add foundation for integration services (#30813) * Add foundation for integration services * Fix tests * Remove async_get_platform * Migrate Sonos partially to EntityPlatform.async_register_entity_service * Tweaks * Move other Sonos services to media player domain * Move other Sonos services to media player domain * Address comments * Remove lock * Fix typos * Use make_entity_service_schema * Add area extraction to async_extract_entities Co-authored-by: Anders Melchiorsen --- homeassistant/components/sonos/__init__.py | 126 +--------- .../components/sonos/media_player.py | 224 +++++++++++------- homeassistant/helpers/entity_component.py | 40 +--- homeassistant/helpers/entity_platform.py | 34 ++- homeassistant/helpers/service.py | 26 +- tests/helpers/test_service.py | 178 ++++++++++---- 6 files changed, 339 insertions(+), 289 deletions(-) diff --git a/homeassistant/components/sonos/__init__.py b/homeassistant/components/sonos/__init__.py index d2c6210f01c..c3a977e32e1 100644 --- a/homeassistant/components/sonos/__init__.py +++ b/homeassistant/components/sonos/__init__.py @@ -1,13 +1,10 @@ """Support to embed Sonos.""" -import asyncio - import voluptuous as vol from homeassistant import config_entries from homeassistant.components.media_player import DOMAIN as MP_DOMAIN -from homeassistant.const import ATTR_ENTITY_ID, ATTR_TIME, CONF_HOSTS +from homeassistant.const import CONF_HOSTS from homeassistant.helpers import config_validation as cv -from homeassistant.helpers.dispatcher import async_dispatcher_send from .const import DOMAIN @@ -33,91 +30,12 @@ CONFIG_SCHEMA = vol.Schema( extra=vol.ALLOW_EXTRA, ) -SERVICE_JOIN = "join" -SERVICE_UNJOIN = "unjoin" -SERVICE_SNAPSHOT = "snapshot" -SERVICE_RESTORE = "restore" -SERVICE_SET_TIMER = "set_sleep_timer" -SERVICE_CLEAR_TIMER = "clear_sleep_timer" -SERVICE_UPDATE_ALARM = "update_alarm" -SERVICE_SET_OPTION = "set_option" -SERVICE_PLAY_QUEUE = "play_queue" - -ATTR_SLEEP_TIME = "sleep_time" -ATTR_ALARM_ID = "alarm_id" -ATTR_VOLUME = "volume" -ATTR_ENABLED = "enabled" -ATTR_INCLUDE_LINKED_ZONES = "include_linked_zones" -ATTR_MASTER = "master" -ATTR_WITH_GROUP = "with_group" -ATTR_NIGHT_SOUND = "night_sound" -ATTR_SPEECH_ENHANCE = "speech_enhance" -ATTR_QUEUE_POSITION = "queue_position" - -SONOS_JOIN_SCHEMA = vol.Schema( - { - vol.Required(ATTR_MASTER): cv.entity_id, - vol.Optional(ATTR_ENTITY_ID): cv.comp_entity_ids, - } -) - -SONOS_UNJOIN_SCHEMA = vol.Schema({vol.Optional(ATTR_ENTITY_ID): cv.comp_entity_ids}) - -SONOS_STATES_SCHEMA = vol.Schema( - { - vol.Optional(ATTR_ENTITY_ID): cv.comp_entity_ids, - vol.Optional(ATTR_WITH_GROUP, default=True): cv.boolean, - } -) - -SONOS_SET_TIMER_SCHEMA = vol.Schema( - { - vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids, - vol.Required(ATTR_SLEEP_TIME): vol.All( - vol.Coerce(int), vol.Range(min=0, max=86399) - ), - } -) - -SONOS_CLEAR_TIMER_SCHEMA = vol.Schema( - {vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids} -) - -SONOS_UPDATE_ALARM_SCHEMA = vol.Schema( - { - vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids, - vol.Required(ATTR_ALARM_ID): cv.positive_int, - vol.Optional(ATTR_TIME): cv.time, - vol.Optional(ATTR_VOLUME): cv.small_float, - vol.Optional(ATTR_ENABLED): cv.boolean, - vol.Optional(ATTR_INCLUDE_LINKED_ZONES): cv.boolean, - } -) - -SONOS_SET_OPTION_SCHEMA = vol.Schema( - { - vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids, - vol.Optional(ATTR_NIGHT_SOUND): cv.boolean, - vol.Optional(ATTR_SPEECH_ENHANCE): cv.boolean, - } -) - -SONOS_PLAY_QUEUE_SCHEMA = vol.Schema( - { - vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids, - vol.Optional(ATTR_QUEUE_POSITION, default=0): cv.positive_int, - } -) - -DATA_SERVICE_EVENT = "sonos_service_idle" - async def async_setup(hass, config): """Set up the Sonos component.""" conf = config.get(DOMAIN) hass.data[DOMAIN] = conf or {} - hass.data[DATA_SERVICE_EVENT] = asyncio.Event() if conf is not None: hass.async_create_task( @@ -126,48 +44,6 @@ async def async_setup(hass, config): ) ) - async def service_handle(service): - """Dispatch a service call.""" - hass.data[DATA_SERVICE_EVENT].clear() - async_dispatcher_send(hass, DOMAIN, service.service, service.data) - await hass.data[DATA_SERVICE_EVENT].wait() - - hass.services.async_register( - DOMAIN, SERVICE_JOIN, service_handle, schema=SONOS_JOIN_SCHEMA - ) - - hass.services.async_register( - DOMAIN, SERVICE_UNJOIN, service_handle, schema=SONOS_UNJOIN_SCHEMA - ) - - hass.services.async_register( - DOMAIN, SERVICE_SNAPSHOT, service_handle, schema=SONOS_STATES_SCHEMA - ) - - hass.services.async_register( - DOMAIN, SERVICE_RESTORE, service_handle, schema=SONOS_STATES_SCHEMA - ) - - hass.services.async_register( - DOMAIN, SERVICE_SET_TIMER, service_handle, schema=SONOS_SET_TIMER_SCHEMA - ) - - hass.services.async_register( - DOMAIN, SERVICE_CLEAR_TIMER, service_handle, schema=SONOS_CLEAR_TIMER_SCHEMA - ) - - hass.services.async_register( - DOMAIN, SERVICE_UPDATE_ALARM, service_handle, schema=SONOS_UPDATE_ALARM_SCHEMA - ) - - hass.services.async_register( - DOMAIN, SERVICE_SET_OPTION, service_handle, schema=SONOS_SET_OPTION_SCHEMA - ) - - hass.services.async_register( - DOMAIN, SERVICE_PLAY_QUEUE, service_handle, schema=SONOS_PLAY_QUEUE_SCHEMA - ) - return True diff --git a/homeassistant/components/sonos/media_player.py b/homeassistant/components/sonos/media_player.py index 9ce72d87dfe..bcdb74ad438 100644 --- a/homeassistant/components/sonos/media_player.py +++ b/homeassistant/components/sonos/media_player.py @@ -11,6 +11,7 @@ import pysonos from pysonos import alarms from pysonos.exceptions import SoCoException, SoCoUPnPException import pysonos.snapshot +import voluptuous as vol from homeassistant.components.media_player import MediaPlayerDevice from homeassistant.components.media_player.const import ( @@ -30,42 +31,16 @@ from homeassistant.components.media_player.const import ( SUPPORT_VOLUME_MUTE, SUPPORT_VOLUME_SET, ) -from homeassistant.const import ( - ENTITY_MATCH_ALL, - STATE_IDLE, - STATE_PAUSED, - STATE_PLAYING, -) -from homeassistant.core import callback -from homeassistant.helpers.dispatcher import async_dispatcher_connect +from homeassistant.const import ATTR_TIME, STATE_IDLE, STATE_PAUSED, STATE_PLAYING +from homeassistant.core import ServiceCall, callback +from homeassistant.helpers import config_validation as cv, entity_platform, service from homeassistant.util.dt import utcnow from . import ( - ATTR_ALARM_ID, - ATTR_ENABLED, - ATTR_INCLUDE_LINKED_ZONES, - ATTR_MASTER, - ATTR_NIGHT_SOUND, - ATTR_QUEUE_POSITION, - ATTR_SLEEP_TIME, - ATTR_SPEECH_ENHANCE, - ATTR_TIME, - ATTR_VOLUME, - ATTR_WITH_GROUP, CONF_ADVERTISE_ADDR, CONF_HOSTS, CONF_INTERFACE_ADDR, - DATA_SERVICE_EVENT, DOMAIN as SONOS_DOMAIN, - SERVICE_CLEAR_TIMER, - SERVICE_JOIN, - SERVICE_PLAY_QUEUE, - SERVICE_RESTORE, - SERVICE_SET_OPTION, - SERVICE_SET_TIMER, - SERVICE_SNAPSHOT, - SERVICE_UNJOIN, - SERVICE_UPDATE_ALARM, ) _LOGGER = logging.getLogger(__name__) @@ -97,6 +72,27 @@ ATTR_SONOS_GROUP = "sonos_group" UPNP_ERRORS_TO_IGNORE = ["701", "711", "712"] +SERVICE_JOIN = "join" +SERVICE_UNJOIN = "unjoin" +SERVICE_SNAPSHOT = "snapshot" +SERVICE_RESTORE = "restore" +SERVICE_SET_TIMER = "set_sleep_timer" +SERVICE_CLEAR_TIMER = "clear_sleep_timer" +SERVICE_UPDATE_ALARM = "update_alarm" +SERVICE_SET_OPTION = "set_option" +SERVICE_PLAY_QUEUE = "play_queue" + +ATTR_SLEEP_TIME = "sleep_time" +ATTR_ALARM_ID = "alarm_id" +ATTR_VOLUME = "volume" +ATTR_ENABLED = "enabled" +ATTR_INCLUDE_LINKED_ZONES = "include_linked_zones" +ATTR_MASTER = "master" +ATTR_WITH_GROUP = "with_group" +ATTR_NIGHT_SOUND = "night_sound" +ATTR_SPEECH_ENHANCE = "speech_enhance" +ATTR_QUEUE_POSITION = "queue_position" + class SonosData: """Storage class for platform global data.""" @@ -176,46 +172,101 @@ async def async_setup_entry(hass, config_entry, async_add_entities): _LOGGER.debug("Adding discovery job") hass.async_add_executor_job(_discovery) - async def async_service_handle(service, data): + platform = entity_platform.current_platform.get() + + async def async_service_handle(service_call: ServiceCall): """Handle dispatched services.""" - entity_ids = data.get("entity_id") - entities = hass.data[DATA_SONOS].entities - if entity_ids and entity_ids != ENTITY_MATCH_ALL: - entities = [e for e in entities if e.entity_id in entity_ids] + entities = await platform.async_extract_from_service(service_call) - if service == SERVICE_JOIN: - master = [ - e - for e in hass.data[DATA_SONOS].entities - if e.entity_id == data[ATTR_MASTER] - ] + if not entities: + return + + if service_call.service == SERVICE_JOIN: + master = platform.entities.get(service_call.data[ATTR_MASTER]) if master: - await SonosEntity.join_multi(hass, master[0], entities) - elif service == SERVICE_UNJOIN: + await SonosEntity.join_multi(hass, master, entities) + else: + _LOGGER.error( + "Invalid master specified for join service: %s", + service_call.data[ATTR_MASTER], + ) + elif service_call.service == SERVICE_UNJOIN: await SonosEntity.unjoin_multi(hass, entities) - elif service == SERVICE_SNAPSHOT: - await SonosEntity.snapshot_multi(hass, entities, data[ATTR_WITH_GROUP]) - elif service == SERVICE_RESTORE: - await SonosEntity.restore_multi(hass, entities, data[ATTR_WITH_GROUP]) - else: - for entity in entities: - if service == SERVICE_SET_TIMER: - call = entity.set_sleep_timer - elif service == SERVICE_CLEAR_TIMER: - call = entity.clear_sleep_timer - elif service == SERVICE_UPDATE_ALARM: - call = entity.set_alarm - elif service == SERVICE_SET_OPTION: - call = entity.set_option - elif service == SERVICE_PLAY_QUEUE: - call = entity.play_queue + elif service_call.service == SERVICE_SNAPSHOT: + await SonosEntity.snapshot_multi( + hass, entities, service_call.data[ATTR_WITH_GROUP] + ) + elif service_call.service == SERVICE_RESTORE: + await SonosEntity.restore_multi( + hass, entities, service_call.data[ATTR_WITH_GROUP] + ) - hass.async_add_executor_job(call, data) + service.async_register_admin_service( + hass, + SONOS_DOMAIN, + SERVICE_JOIN, + async_service_handle, + cv.make_entity_service_schema({vol.Required(ATTR_MASTER): cv.entity_id}), + ) - # We are ready for the next service call - hass.data[DATA_SERVICE_EVENT].set() + service.async_register_admin_service( + hass, + SONOS_DOMAIN, + SERVICE_UNJOIN, + async_service_handle, + cv.make_entity_service_schema({}), + ) - async_dispatcher_connect(hass, SONOS_DOMAIN, async_service_handle) + join_unjoin_schema = cv.make_entity_service_schema( + {vol.Optional(ATTR_WITH_GROUP, default=True): cv.boolean} + ) + + service.async_register_admin_service( + hass, SONOS_DOMAIN, SERVICE_SNAPSHOT, async_service_handle, join_unjoin_schema + ) + + service.async_register_admin_service( + hass, SONOS_DOMAIN, SERVICE_RESTORE, async_service_handle, join_unjoin_schema + ) + + platform.async_register_entity_service( + SERVICE_SET_TIMER, + { + vol.Required(ATTR_SLEEP_TIME): vol.All( + vol.Coerce(int), vol.Range(min=0, max=86399) + ) + }, + "set_sleep_timer", + ) + + platform.async_register_entity_service(SERVICE_CLEAR_TIMER, {}, "clear_sleep_timer") + + platform.async_register_entity_service( + SERVICE_UPDATE_ALARM, + { + vol.Required(ATTR_ALARM_ID): cv.positive_int, + vol.Optional(ATTR_TIME): cv.time, + vol.Optional(ATTR_VOLUME): cv.small_float, + vol.Optional(ATTR_ENABLED): cv.boolean, + vol.Optional(ATTR_INCLUDE_LINKED_ZONES): cv.boolean, + }, + "set_alarm", + ) + + platform.async_register_entity_service( + SERVICE_SET_OPTION, + { + vol.Optional(ATTR_NIGHT_SOUND): cv.boolean, + vol.Optional(ATTR_SPEECH_ENHANCE): cv.boolean, + }, + "set_option", + ) + + platform.async_register_entity_service( + SERVICE_PLAY_QUEUE, + {vol.Optional(ATTR_QUEUE_POSITION): cv.positive_int}, + "play_queue", + ) class _ProcessSonosEventQueue: @@ -480,10 +531,10 @@ class SonosEntity(MediaPlayerDevice): player = self.soco - def subscribe(service, action): + def subscribe(sonos_service, action): """Add a subscription to a pysonos service.""" queue = _ProcessSonosEventQueue(action) - sub = service.subscribe(auto_renew=True, event_queue=queue) + sub = sonos_service.subscribe(auto_renew=True, event_queue=queue) self._subscriptions.append(sub) subscribe(player.avTransport, self.update_media) @@ -1147,52 +1198,53 @@ class SonosEntity(MediaPlayerDevice): @soco_error() @soco_coordinator - def set_sleep_timer(self, data): + def set_sleep_timer(self, sleep_time): """Set the timer on the player.""" - self.soco.set_sleep_timer(data[ATTR_SLEEP_TIME]) + self.soco.set_sleep_timer(sleep_time) @soco_error() @soco_coordinator - def clear_sleep_timer(self, data): + def clear_sleep_timer(self): """Clear the timer on the player.""" self.soco.set_sleep_timer(None) @soco_error() @soco_coordinator - def set_alarm(self, data): + def set_alarm( + self, alarm_id, time=None, volume=None, enabled=None, include_linked_zones=None + ): """Set the alarm clock on the player.""" - alarm = None for one_alarm in alarms.get_alarms(self.soco): # pylint: disable=protected-access - if one_alarm._alarm_id == str(data[ATTR_ALARM_ID]): + if one_alarm._alarm_id == str(alarm_id): alarm = one_alarm if alarm is None: - _LOGGER.warning("did not find alarm with id %s", data[ATTR_ALARM_ID]) + _LOGGER.warning("did not find alarm with id %s", alarm_id) return - if ATTR_TIME in data: - alarm.start_time = data[ATTR_TIME] - if ATTR_VOLUME in data: - alarm.volume = int(data[ATTR_VOLUME] * 100) - if ATTR_ENABLED in data: - alarm.enabled = data[ATTR_ENABLED] - if ATTR_INCLUDE_LINKED_ZONES in data: - alarm.include_linked_zones = data[ATTR_INCLUDE_LINKED_ZONES] + if time is not None: + alarm.start_time = time + if volume is not None: + alarm.volume = int(volume * 100) + if enabled is not None: + alarm.enabled = enabled + if include_linked_zones is not None: + alarm.include_linked_zones = include_linked_zones alarm.save() @soco_error() - def set_option(self, data): + def set_option(self, night_sound=None, speech_enhance=None): """Modify playback options.""" - if ATTR_NIGHT_SOUND in data and self._night_sound is not None: - self.soco.night_mode = data[ATTR_NIGHT_SOUND] + if night_sound is not None and self._night_sound is not None: + self.soco.night_mode = night_sound - if ATTR_SPEECH_ENHANCE in data and self._speech_enhance is not None: - self.soco.dialog_mode = data[ATTR_SPEECH_ENHANCE] + if speech_enhance is not None and self._speech_enhance is not None: + self.soco.dialog_mode = speech_enhance @soco_error() - def play_queue(self, data): + def play_queue(self, queue_position=0): """Start playing the queue.""" - self.soco.play_from_queue(data[ATTR_QUEUE_POSITION]) + self.soco.play_from_queue(queue_position) @property def device_state_attributes(self): diff --git a/homeassistant/helpers/entity_component.py b/homeassistant/helpers/entity_component.py index 404fd4ed46d..733cb22b3b2 100644 --- a/homeassistant/helpers/entity_component.py +++ b/homeassistant/helpers/entity_component.py @@ -6,17 +6,15 @@ import logging from homeassistant import config as conf_util from homeassistant.config_entries import ConfigEntry -from homeassistant.const import ( - ATTR_ENTITY_ID, - CONF_ENTITY_NAMESPACE, - CONF_SCAN_INTERVAL, - ENTITY_MATCH_ALL, -) +from homeassistant.const import CONF_ENTITY_NAMESPACE, CONF_SCAN_INTERVAL from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import config_per_platform, discovery -from homeassistant.helpers.config_validation import make_entity_service_schema -from homeassistant.helpers.service import async_extract_entity_ids +from homeassistant.helpers import ( + config_per_platform, + config_validation as cv, + discovery, + service, +) from homeassistant.loader import async_get_integration, bind_hass from homeassistant.setup import async_prepare_setup_platform @@ -166,39 +164,27 @@ class EntityComponent: await platform.async_reset() return True - async def async_extract_from_service(self, service, expand_group=True): + async def async_extract_from_service(self, service_call, expand_group=True): """Extract all known and available entities from a service call. Will return an empty list if entities specified but unknown. This method must be run in the event loop. """ - data_ent_id = service.data.get(ATTR_ENTITY_ID) - - if data_ent_id is None: - return [] - - if data_ent_id == ENTITY_MATCH_ALL: - return [entity for entity in self.entities if entity.available] - - entity_ids = await async_extract_entity_ids(self.hass, service, expand_group) - return [ - entity - for entity in self.entities - if entity.available and entity.entity_id in entity_ids - ] + return await service.async_extract_entities( + self.hass, self.entities, service_call, expand_group + ) @callback def async_register_entity_service(self, name, schema, func, required_features=None): """Register an entity service.""" if isinstance(schema, dict): - schema = make_entity_service_schema(schema) + schema = cv.make_entity_service_schema(schema) async def handle_service(call): """Handle the service.""" - service_name = f"{self.domain}.{name}" await self.hass.helpers.service.entity_service_call( - self._platforms.values(), func, call, service_name, required_features + self._platforms.values(), func, call, required_features ) self.hass.services.async_register(self.domain, name, handle_service, schema) diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index 0e4d80ac080..8fedc198fe2 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -7,6 +7,7 @@ from typing import Optional from homeassistant.const import DEVICE_DEFAULT_NAME from homeassistant.core import callback, split_entity_id, valid_entity_id from homeassistant.exceptions import HomeAssistantError, PlatformNotReady +from homeassistant.helpers import config_validation as cv, service from homeassistant.util.async_ import run_callback_threadsafe from .entity_registry import DISABLED_INTEGRATION @@ -194,7 +195,11 @@ class EntityPlatform: ) return False except Exception: # pylint: disable=broad-except - logger.exception("Error while setting up platform %s", self.platform_name) + logger.exception( + "Error while setting up %s platform for %s", + self.platform_name, + self.domain, + ) return False finally: warn_task.cancel() @@ -449,6 +454,33 @@ class EntityPlatform: self._async_unsub_polling() self._async_unsub_polling = None + async def async_extract_from_service(self, service_call, expand_group=True): + """Extract all known and available entities from a service call. + + Will return an empty list if entities specified but unknown. + + This method must be run in the event loop. + """ + return await service.async_extract_entities( + self.hass, self.entities.values(), service_call, expand_group + ) + + @callback + def async_register_entity_service(self, name, schema, func, required_features=None): + """Register an entity service.""" + if isinstance(schema, dict): + schema = cv.make_entity_service_schema(schema) + + async def handle_service(call): + """Handle the service.""" + await service.entity_service_call( + self.hass, [self], func, call, required_features + ) + + self.hass.services.async_register( + self.platform_name, name, handle_service, schema + ) + async def _update_entity_states(self, now: datetime) -> None: """Update the states of all the polling entities. diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 16fabe251af..d621d4e6242 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -108,13 +108,31 @@ def extract_entity_ids(hass, service_call, expand_group=True): ).result() +@bind_hass +async def async_extract_entities(hass, entities, service_call, expand_group=True): + """Extract a list of entity objects from a service call. + + Will convert group entity ids to the entity ids it represents. + """ + data_ent_id = service_call.data.get(ATTR_ENTITY_ID) + + if data_ent_id == ENTITY_MATCH_ALL: + return [entity for entity in entities if entity.available] + + entity_ids = await async_extract_entity_ids(hass, service_call, expand_group) + + return [ + entity + for entity in entities + if entity.available and entity.entity_id in entity_ids + ] + + @bind_hass async def async_extract_entity_ids(hass, service_call, expand_group=True): """Extract a list of entity ids from a service call. Will convert group entity ids to the entity ids it represents. - - Async friendly. """ entity_ids = service_call.data.get(ATTR_ENTITY_ID) area_ids = service_call.data.get(ATTR_AREA_ID) @@ -244,9 +262,7 @@ def async_set_service_schema(hass, domain, service, schema): @bind_hass -async def entity_service_call( - hass, platforms, func, call, service_name="", required_features=None -): +async def entity_service_call(hass, platforms, func, call, required_features=None): """Handle an entity service call. Calls all platforms simultaneously. diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index b42b30a836a..c80b6eac193 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -23,6 +23,7 @@ import homeassistant.helpers.config_validation as cv from homeassistant.setup import async_setup_component from tests.common import ( + MockEntity, get_test_home_assistant, mock_coro, mock_device_registry, @@ -64,6 +65,54 @@ def mock_entities(): return entities +@pytest.fixture +def area_mock(hass): + """Mock including area info.""" + hass.states.async_set("light.Bowl", STATE_ON) + hass.states.async_set("light.Ceiling", STATE_OFF) + hass.states.async_set("light.Kitchen", STATE_OFF) + + device_in_area = dev_reg.DeviceEntry(area_id="test-area") + device_no_area = dev_reg.DeviceEntry() + device_diff_area = dev_reg.DeviceEntry(area_id="diff-area") + + mock_device_registry( + hass, + { + device_in_area.id: device_in_area, + device_no_area.id: device_no_area, + device_diff_area.id: device_diff_area, + }, + ) + + entity_in_area = ent_reg.RegistryEntry( + entity_id="light.in_area", + unique_id="in-area-id", + platform="test", + device_id=device_in_area.id, + ) + entity_no_area = ent_reg.RegistryEntry( + entity_id="light.no_area", + unique_id="no-area-id", + platform="test", + device_id=device_no_area.id, + ) + entity_diff_area = ent_reg.RegistryEntry( + entity_id="light.diff_area", + unique_id="diff-area-id", + platform="test", + device_id=device_diff_area.id, + ) + mock_registry( + hass, + { + entity_in_area.entity_id: entity_in_area, + entity_no_area.entity_id: entity_no_area, + entity_diff_area.entity_id: entity_diff_area, + }, + ) + + class TestServiceHelpers(unittest.TestCase): """Test the Home Assistant service helpers.""" @@ -204,52 +253,8 @@ async def test_extract_entity_ids(hass): ) -async def test_extract_entity_ids_from_area(hass): +async def test_extract_entity_ids_from_area(hass, area_mock): """Test extract_entity_ids method with areas.""" - hass.states.async_set("light.Bowl", STATE_ON) - hass.states.async_set("light.Ceiling", STATE_OFF) - hass.states.async_set("light.Kitchen", STATE_OFF) - - device_in_area = dev_reg.DeviceEntry(area_id="test-area") - device_no_area = dev_reg.DeviceEntry() - device_diff_area = dev_reg.DeviceEntry(area_id="diff-area") - - mock_device_registry( - hass, - { - device_in_area.id: device_in_area, - device_no_area.id: device_no_area, - device_diff_area.id: device_diff_area, - }, - ) - - entity_in_area = ent_reg.RegistryEntry( - entity_id="light.in_area", - unique_id="in-area-id", - platform="test", - device_id=device_in_area.id, - ) - entity_no_area = ent_reg.RegistryEntry( - entity_id="light.no_area", - unique_id="no-area-id", - platform="test", - device_id=device_no_area.id, - ) - entity_diff_area = ent_reg.RegistryEntry( - entity_id="light.diff_area", - unique_id="diff-area-id", - platform="test", - device_id=device_diff_area.id, - ) - mock_registry( - hass, - { - entity_in_area.entity_id: entity_in_area, - entity_no_area.entity_id: entity_no_area, - entity_diff_area.entity_id: entity_diff_area, - }, - ) - call = ha.ServiceCall("light", "turn_on", {"area_id": "test-area"}) assert {"light.in_area"} == await service.async_extract_entity_ids(hass, call) @@ -678,3 +683,86 @@ async def test_domain_control_no_user(hass, mock_entities): ) assert len(calls) == 1 + + +async def test_extract_from_service_available_device(hass): + """Test the extraction of entity from service and device is available.""" + entities = [ + MockEntity(name="test_1", entity_id="test_domain.test_1"), + MockEntity(name="test_2", entity_id="test_domain.test_2", available=False), + MockEntity(name="test_3", entity_id="test_domain.test_3"), + MockEntity(name="test_4", entity_id="test_domain.test_4", available=False), + ] + + call_1 = ha.ServiceCall("test", "service", data={"entity_id": ENTITY_MATCH_ALL}) + + assert ["test_domain.test_1", "test_domain.test_3"] == [ + ent.entity_id + for ent in (await service.async_extract_entities(hass, entities, call_1)) + ] + + call_2 = ha.ServiceCall( + "test", + "service", + data={"entity_id": ["test_domain.test_3", "test_domain.test_4"]}, + ) + + assert ["test_domain.test_3"] == [ + ent.entity_id + for ent in (await service.async_extract_entities(hass, entities, call_2)) + ] + + +async def test_extract_from_service_empty_if_no_entity_id(hass): + """Test the extraction from service without specifying entity.""" + entities = [ + MockEntity(name="test_1", entity_id="test_domain.test_1"), + MockEntity(name="test_2", entity_id="test_domain.test_2"), + ] + call = ha.ServiceCall("test", "service") + + assert [] == [ + ent.entity_id + for ent in (await service.async_extract_entities(hass, entities, call)) + ] + + +async def test_extract_from_service_filter_out_non_existing_entities(hass): + """Test the extraction of non existing entities from service.""" + entities = [ + MockEntity(name="test_1", entity_id="test_domain.test_1"), + MockEntity(name="test_2", entity_id="test_domain.test_2"), + ] + + call = ha.ServiceCall( + "test", + "service", + {"entity_id": ["test_domain.test_2", "test_domain.non_exist"]}, + ) + + assert ["test_domain.test_2"] == [ + ent.entity_id + for ent in (await service.async_extract_entities(hass, entities, call)) + ] + + +async def test_extract_from_service_area_id(hass, area_mock): + """Test the extraction using area ID as reference.""" + entities = [ + MockEntity(name="in_area", entity_id="light.in_area"), + MockEntity(name="no_area", entity_id="light.no_area"), + MockEntity(name="diff_area", entity_id="light.diff_area"), + ] + + call = ha.ServiceCall("light", "turn_on", {"area_id": "test-area"}) + extracted = await service.async_extract_entities(hass, entities, call) + assert len(extracted) == 1 + assert extracted[0].entity_id == "light.in_area" + + call = ha.ServiceCall("light", "turn_on", {"area_id": ["test-area", "diff-area"]}) + extracted = await service.async_extract_entities(hass, entities, call) + assert len(extracted) == 2 + assert sorted(ent.entity_id for ent in extracted) == [ + "light.diff_area", + "light.in_area", + ]