From 8569ddc5f94803de631816226ea185d21891ec47 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 6 Feb 2024 12:41:57 -0600 Subject: [PATCH] Fix entity services targeting entities outside the platform when using areas/devices (#109810) --- homeassistant/helpers/entity_platform.py | 26 +++++++- tests/helpers/test_entity_platform.py | 82 ++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 3 deletions(-) diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index 7cf7ab62495..db2760d554c 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -57,6 +57,7 @@ SLOW_ADD_MIN_TIMEOUT = 500 PLATFORM_NOT_READY_RETRIES = 10 DATA_ENTITY_PLATFORM = "entity_platform" DATA_DOMAIN_ENTITIES = "domain_entities" +DATA_DOMAIN_PLATFORM_ENTITIES = "domain_platform_entities" PLATFORM_NOT_READY_BASE_WAIT_TIME = 30 # seconds _LOGGER = getLogger(__name__) @@ -124,6 +125,8 @@ class EntityPlatform: self.scan_interval = scan_interval self.entity_namespace = entity_namespace self.config_entry: config_entries.ConfigEntry | None = None + # Storage for entities for this specific platform only + # which are indexed by entity_id self.entities: dict[str, Entity] = {} self.component_translations: dict[str, Any] = {} self.platform_translations: dict[str, Any] = {} @@ -145,9 +148,24 @@ class EntityPlatform: # which powers entity_component.add_entities self.parallel_updates_created = platform is None - self.domain_entities: dict[str, Entity] = hass.data.setdefault( + # Storage for entities indexed by domain + # with the child dict indexed by entity_id + # + # This is usually media_player, light, switch, etc. + domain_entities: dict[str, dict[str, Entity]] = hass.data.setdefault( DATA_DOMAIN_ENTITIES, {} - ).setdefault(domain, {}) + ) + self.domain_entities = domain_entities.setdefault(domain, {}) + + # Storage for entities indexed by domain and platform + # with the child dict indexed by entity_id + # + # This is usually media_player.yamaha, light.hue, switch.tplink, etc. + domain_platform_entities: dict[ + tuple[str, str], dict[str, Entity] + ] = hass.data.setdefault(DATA_DOMAIN_PLATFORM_ENTITIES, {}) + key = (domain, platform_name) + self.domain_platform_entities = domain_platform_entities.setdefault(key, {}) def __repr__(self) -> str: """Represent an EntityPlatform.""" @@ -743,6 +761,7 @@ class EntityPlatform: entity_id = entity.entity_id self.entities[entity_id] = entity self.domain_entities[entity_id] = entity + self.domain_platform_entities[entity_id] = entity if not restored: # Reserve the state in the state machine @@ -756,6 +775,7 @@ class EntityPlatform: """Remove entity from entities dict.""" self.entities.pop(entity_id) self.domain_entities.pop(entity_id) + self.domain_platform_entities.pop(entity_id) entity.async_on_remove(remove_entity_cb) @@ -852,7 +872,7 @@ class EntityPlatform: partial( service.entity_service_call, self.hass, - self.domain_entities, + self.domain_platform_entities, service_func, required_features=required_features, ), diff --git a/tests/helpers/test_entity_platform.py b/tests/helpers/test_entity_platform.py index 01558c426c7..f16b5c16b5a 100644 --- a/tests/helpers/test_entity_platform.py +++ b/tests/helpers/test_entity_platform.py @@ -19,6 +19,7 @@ from homeassistant.core import ( ) from homeassistant.exceptions import HomeAssistantError, PlatformNotReady from homeassistant.helpers import ( + area_registry as ar, device_registry as dr, entity_platform, entity_registry as er, @@ -1628,6 +1629,87 @@ async def test_register_entity_service_response_data_multiple_matches_raises( ) +async def test_register_entity_service_limited_to_matching_platforms( + hass: HomeAssistant, + entity_registry: er.EntityRegistry, + area_registry: ar.AreaRegistry, +) -> None: + """Test an entity services only targets entities for the platform and domain.""" + + mock_area = area_registry.async_get_or_create("mock_area") + + entity1_entry = entity_registry.async_get_or_create( + "base_platform", "mock_platform", "1234", suggested_object_id="entity1" + ) + entity_registry.async_update_entity(entity1_entry.entity_id, area_id=mock_area.id) + entity2_entry = entity_registry.async_get_or_create( + "base_platform", "mock_platform", "5678", suggested_object_id="entity2" + ) + entity_registry.async_update_entity(entity2_entry.entity_id, area_id=mock_area.id) + entity3_entry = entity_registry.async_get_or_create( + "base_platform", "other_mock_platform", "7891", suggested_object_id="entity3" + ) + entity_registry.async_update_entity(entity3_entry.entity_id, area_id=mock_area.id) + entity4_entry = entity_registry.async_get_or_create( + "base_platform", "other_mock_platform", "1433", suggested_object_id="entity4" + ) + entity_registry.async_update_entity(entity4_entry.entity_id, area_id=mock_area.id) + + async def generate_response( + target: MockEntity, call: ServiceCall + ) -> ServiceResponse: + assert call.return_response + return {"response-key": f"response-value-{target.entity_id}"} + + entity_platform = MockEntityPlatform( + hass, domain="base_platform", platform_name="mock_platform", platform=None + ) + entity1 = MockEntity( + entity_id=entity1_entry.entity_id, unique_id=entity1_entry.unique_id + ) + entity2 = MockEntity( + entity_id=entity2_entry.entity_id, unique_id=entity2_entry.unique_id + ) + await entity_platform.async_add_entities([entity1, entity2]) + + other_entity_platform = MockEntityPlatform( + hass, domain="base_platform", platform_name="other_mock_platform", platform=None + ) + entity3 = MockEntity( + entity_id=entity3_entry.entity_id, unique_id=entity3_entry.unique_id + ) + entity4 = MockEntity( + entity_id=entity4_entry.entity_id, unique_id=entity4_entry.unique_id + ) + await other_entity_platform.async_add_entities([entity3, entity4]) + + entity_platform.async_register_entity_service( + "hello", + {"some": str}, + generate_response, + supports_response=SupportsResponse.ONLY, + ) + + response_data = await hass.services.async_call( + "mock_platform", + "hello", + service_data={"some": "data"}, + target={"area_id": [mock_area.id]}, + blocking=True, + return_response=True, + ) + # We should not target entity3 and entity4 even though they are in the area + # because they are only part of the domain and not the platform + assert response_data == { + "base_platform.entity1": { + "response-key": "response-value-base_platform.entity1" + }, + "base_platform.entity2": { + "response-key": "response-value-base_platform.entity2" + }, + } + + async def test_invalid_entity_id(hass: HomeAssistant) -> None: """Test specifying an invalid entity id.""" platform = MockEntityPlatform(hass)