From 3d403c9b6020afcce6d625a945fd5f36486c820d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 8 Sep 2023 12:04:53 -0500 Subject: [PATCH] Refactor entity service calls to reduce complexity (#99783) * Refactor entity service calls to reduce complexity gets rid of the noqa C901 * Refactor entity service calls to reduce complexity gets rid of the noqa C901 * short --- homeassistant/helpers/service.py | 114 ++++++++++++++++--------------- 1 file changed, 60 insertions(+), 54 deletions(-) diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 3eb537f9649..a0fe24cb656 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -732,8 +732,59 @@ def async_set_service_schema( descriptions_cache[(domain, service)] = description +def _get_permissible_entity_candidates( + call: ServiceCall, + platforms: Iterable[EntityPlatform], + entity_perms: None | (Callable[[str, str], bool]), + target_all_entities: bool, + all_referenced: set[str] | None, +) -> list[Entity]: + """Get entity candidates that the user is allowed to access.""" + if entity_perms is not None: + # Check the permissions since entity_perms is set + if target_all_entities: + # If we target all entities, we will select all entities the user + # is allowed to control. + return [ + entity + for platform in platforms + for entity in platform.entities.values() + if entity_perms(entity.entity_id, POLICY_CONTROL) + ] + + assert all_referenced is not None + # If they reference specific entities, we will check if they are all + # allowed to be controlled. + for entity_id in all_referenced: + if not entity_perms(entity_id, POLICY_CONTROL): + raise Unauthorized( + context=call.context, + entity_id=entity_id, + permission=POLICY_CONTROL, + ) + + elif target_all_entities: + return [ + entity for platform in platforms for entity in platform.entities.values() + ] + + # We have already validated they have permissions to control all_referenced + # entities so we do not need to check again. + assert all_referenced is not None + if single_entity := len(all_referenced) == 1 and list(all_referenced)[0]: + for platform in platforms: + if (entity := platform.entities.get(single_entity)) is not None: + return [entity] + + return [ + platform.entities[entity_id] + for platform in platforms + for entity_id in all_referenced.intersection(platform.entities) + ] + + @bind_hass -async def entity_service_call( # noqa: C901 +async def entity_service_call( hass: HomeAssistant, platforms: Iterable[EntityPlatform], func: str | Callable[..., Coroutine[Any, Any, ServiceResponse]], @@ -771,69 +822,24 @@ async def entity_service_call( # noqa: C901 else: data = call - # Check the permissions - # A list with entities to call the service on. - entity_candidates: list[Entity] = [] - - if entity_perms is None: - for platform in platforms: - platform_entities = platform.entities - if target_all_entities: - entity_candidates.extend(platform_entities.values()) - else: - assert all_referenced is not None - entity_candidates.extend( - [ - platform_entities[entity_id] - for entity_id in all_referenced.intersection(platform_entities) - ] - ) - - elif target_all_entities: - # If we target all entities, we will select all entities the user - # is allowed to control. - for platform in platforms: - entity_candidates.extend( - [ - entity - for entity in platform.entities.values() - if entity_perms(entity.entity_id, POLICY_CONTROL) - ] - ) - - else: - assert all_referenced is not None - - for platform in platforms: - platform_entities = platform.entities - platform_entity_candidates = [] - entity_id_matches = all_referenced.intersection(platform_entities) - for entity_id in entity_id_matches: - if not entity_perms(entity_id, POLICY_CONTROL): - raise Unauthorized( - context=call.context, - entity_id=entity_id, - permission=POLICY_CONTROL, - ) - - platform_entity_candidates.append(platform_entities[entity_id]) - - entity_candidates.extend(platform_entity_candidates) + entity_candidates = _get_permissible_entity_candidates( + call, + platforms, + entity_perms, + target_all_entities, + all_referenced, + ) if not target_all_entities: assert referenced is not None - # Only report on explicit referenced entities - missing = set(referenced.referenced) - + missing = referenced.referenced.copy() for entity in entity_candidates: missing.discard(entity.entity_id) - referenced.log_missing(missing) entities: list[Entity] = [] - for entity in entity_candidates: if not entity.available: continue