Refactor entity service calls to reduce complexity (#99783)

* Refactor entity service calls to reduce complexity

gets rid of the noqa C901

* Refactor entity service calls to reduce complexity

gets rid of the noqa C901

* short
pull/99955/head
J. Nick Koston 2023-09-08 12:04:53 -05:00 committed by GitHub
parent 16f7bc7bf8
commit 3d403c9b60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 60 additions and 54 deletions

View File

@ -732,8 +732,59 @@ def async_set_service_schema(
descriptions_cache[(domain, service)] = description descriptions_cache[(domain, service)] = description
def _get_permissible_entity_candidates(
call: ServiceCall,
platforms: Iterable[EntityPlatform],
entity_perms: None | (Callable[[str, str], bool]),
target_all_entities: bool,
all_referenced: set[str] | None,
) -> list[Entity]:
"""Get entity candidates that the user is allowed to access."""
if entity_perms is not None:
# Check the permissions since entity_perms is set
if target_all_entities:
# If we target all entities, we will select all entities the user
# is allowed to control.
return [
entity
for platform in platforms
for entity in platform.entities.values()
if entity_perms(entity.entity_id, POLICY_CONTROL)
]
assert all_referenced is not None
# If they reference specific entities, we will check if they are all
# allowed to be controlled.
for entity_id in all_referenced:
if not entity_perms(entity_id, POLICY_CONTROL):
raise Unauthorized(
context=call.context,
entity_id=entity_id,
permission=POLICY_CONTROL,
)
elif target_all_entities:
return [
entity for platform in platforms for entity in platform.entities.values()
]
# We have already validated they have permissions to control all_referenced
# entities so we do not need to check again.
assert all_referenced is not None
if single_entity := len(all_referenced) == 1 and list(all_referenced)[0]:
for platform in platforms:
if (entity := platform.entities.get(single_entity)) is not None:
return [entity]
return [
platform.entities[entity_id]
for platform in platforms
for entity_id in all_referenced.intersection(platform.entities)
]
@bind_hass @bind_hass
async def entity_service_call( # noqa: C901 async def entity_service_call(
hass: HomeAssistant, hass: HomeAssistant,
platforms: Iterable[EntityPlatform], platforms: Iterable[EntityPlatform],
func: str | Callable[..., Coroutine[Any, Any, ServiceResponse]], func: str | Callable[..., Coroutine[Any, Any, ServiceResponse]],
@ -771,69 +822,24 @@ async def entity_service_call( # noqa: C901
else: else:
data = call data = call
# Check the permissions
# A list with entities to call the service on. # A list with entities to call the service on.
entity_candidates: list[Entity] = [] entity_candidates = _get_permissible_entity_candidates(
call,
if entity_perms is None: platforms,
for platform in platforms: entity_perms,
platform_entities = platform.entities target_all_entities,
if target_all_entities: all_referenced,
entity_candidates.extend(platform_entities.values()) )
else:
assert all_referenced is not None
entity_candidates.extend(
[
platform_entities[entity_id]
for entity_id in all_referenced.intersection(platform_entities)
]
)
elif target_all_entities:
# If we target all entities, we will select all entities the user
# is allowed to control.
for platform in platforms:
entity_candidates.extend(
[
entity
for entity in platform.entities.values()
if entity_perms(entity.entity_id, POLICY_CONTROL)
]
)
else:
assert all_referenced is not None
for platform in platforms:
platform_entities = platform.entities
platform_entity_candidates = []
entity_id_matches = all_referenced.intersection(platform_entities)
for entity_id in entity_id_matches:
if not entity_perms(entity_id, POLICY_CONTROL):
raise Unauthorized(
context=call.context,
entity_id=entity_id,
permission=POLICY_CONTROL,
)
platform_entity_candidates.append(platform_entities[entity_id])
entity_candidates.extend(platform_entity_candidates)
if not target_all_entities: if not target_all_entities:
assert referenced is not None assert referenced is not None
# Only report on explicit referenced entities # Only report on explicit referenced entities
missing = set(referenced.referenced) missing = referenced.referenced.copy()
for entity in entity_candidates: for entity in entity_candidates:
missing.discard(entity.entity_id) missing.discard(entity.entity_id)
referenced.log_missing(missing) referenced.log_missing(missing)
entities: list[Entity] = [] entities: list[Entity] = []
for entity in entity_candidates: for entity in entity_candidates:
if not entity.available: if not entity.available:
continue continue