Refactor entity service calls to reduce complexity (#99783)
* Refactor entity service calls to reduce complexity gets rid of the noqa C901 * Refactor entity service calls to reduce complexity gets rid of the noqa C901 * shortpull/99955/head
parent
16f7bc7bf8
commit
3d403c9b60
|
@ -732,8 +732,59 @@ def async_set_service_schema(
|
||||||
descriptions_cache[(domain, service)] = description
|
descriptions_cache[(domain, service)] = description
|
||||||
|
|
||||||
|
|
||||||
|
def _get_permissible_entity_candidates(
|
||||||
|
call: ServiceCall,
|
||||||
|
platforms: Iterable[EntityPlatform],
|
||||||
|
entity_perms: None | (Callable[[str, str], bool]),
|
||||||
|
target_all_entities: bool,
|
||||||
|
all_referenced: set[str] | None,
|
||||||
|
) -> list[Entity]:
|
||||||
|
"""Get entity candidates that the user is allowed to access."""
|
||||||
|
if entity_perms is not None:
|
||||||
|
# Check the permissions since entity_perms is set
|
||||||
|
if target_all_entities:
|
||||||
|
# If we target all entities, we will select all entities the user
|
||||||
|
# is allowed to control.
|
||||||
|
return [
|
||||||
|
entity
|
||||||
|
for platform in platforms
|
||||||
|
for entity in platform.entities.values()
|
||||||
|
if entity_perms(entity.entity_id, POLICY_CONTROL)
|
||||||
|
]
|
||||||
|
|
||||||
|
assert all_referenced is not None
|
||||||
|
# If they reference specific entities, we will check if they are all
|
||||||
|
# allowed to be controlled.
|
||||||
|
for entity_id in all_referenced:
|
||||||
|
if not entity_perms(entity_id, POLICY_CONTROL):
|
||||||
|
raise Unauthorized(
|
||||||
|
context=call.context,
|
||||||
|
entity_id=entity_id,
|
||||||
|
permission=POLICY_CONTROL,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif target_all_entities:
|
||||||
|
return [
|
||||||
|
entity for platform in platforms for entity in platform.entities.values()
|
||||||
|
]
|
||||||
|
|
||||||
|
# We have already validated they have permissions to control all_referenced
|
||||||
|
# entities so we do not need to check again.
|
||||||
|
assert all_referenced is not None
|
||||||
|
if single_entity := len(all_referenced) == 1 and list(all_referenced)[0]:
|
||||||
|
for platform in platforms:
|
||||||
|
if (entity := platform.entities.get(single_entity)) is not None:
|
||||||
|
return [entity]
|
||||||
|
|
||||||
|
return [
|
||||||
|
platform.entities[entity_id]
|
||||||
|
for platform in platforms
|
||||||
|
for entity_id in all_referenced.intersection(platform.entities)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
async def entity_service_call( # noqa: C901
|
async def entity_service_call(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
platforms: Iterable[EntityPlatform],
|
platforms: Iterable[EntityPlatform],
|
||||||
func: str | Callable[..., Coroutine[Any, Any, ServiceResponse]],
|
func: str | Callable[..., Coroutine[Any, Any, ServiceResponse]],
|
||||||
|
@ -771,69 +822,24 @@ async def entity_service_call( # noqa: C901
|
||||||
else:
|
else:
|
||||||
data = call
|
data = call
|
||||||
|
|
||||||
# Check the permissions
|
|
||||||
|
|
||||||
# A list with entities to call the service on.
|
# A list with entities to call the service on.
|
||||||
entity_candidates: list[Entity] = []
|
entity_candidates = _get_permissible_entity_candidates(
|
||||||
|
call,
|
||||||
if entity_perms is None:
|
platforms,
|
||||||
for platform in platforms:
|
entity_perms,
|
||||||
platform_entities = platform.entities
|
target_all_entities,
|
||||||
if target_all_entities:
|
all_referenced,
|
||||||
entity_candidates.extend(platform_entities.values())
|
)
|
||||||
else:
|
|
||||||
assert all_referenced is not None
|
|
||||||
entity_candidates.extend(
|
|
||||||
[
|
|
||||||
platform_entities[entity_id]
|
|
||||||
for entity_id in all_referenced.intersection(platform_entities)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
elif target_all_entities:
|
|
||||||
# If we target all entities, we will select all entities the user
|
|
||||||
# is allowed to control.
|
|
||||||
for platform in platforms:
|
|
||||||
entity_candidates.extend(
|
|
||||||
[
|
|
||||||
entity
|
|
||||||
for entity in platform.entities.values()
|
|
||||||
if entity_perms(entity.entity_id, POLICY_CONTROL)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
assert all_referenced is not None
|
|
||||||
|
|
||||||
for platform in platforms:
|
|
||||||
platform_entities = platform.entities
|
|
||||||
platform_entity_candidates = []
|
|
||||||
entity_id_matches = all_referenced.intersection(platform_entities)
|
|
||||||
for entity_id in entity_id_matches:
|
|
||||||
if not entity_perms(entity_id, POLICY_CONTROL):
|
|
||||||
raise Unauthorized(
|
|
||||||
context=call.context,
|
|
||||||
entity_id=entity_id,
|
|
||||||
permission=POLICY_CONTROL,
|
|
||||||
)
|
|
||||||
|
|
||||||
platform_entity_candidates.append(platform_entities[entity_id])
|
|
||||||
|
|
||||||
entity_candidates.extend(platform_entity_candidates)
|
|
||||||
|
|
||||||
if not target_all_entities:
|
if not target_all_entities:
|
||||||
assert referenced is not None
|
assert referenced is not None
|
||||||
|
|
||||||
# Only report on explicit referenced entities
|
# Only report on explicit referenced entities
|
||||||
missing = set(referenced.referenced)
|
missing = referenced.referenced.copy()
|
||||||
|
|
||||||
for entity in entity_candidates:
|
for entity in entity_candidates:
|
||||||
missing.discard(entity.entity_id)
|
missing.discard(entity.entity_id)
|
||||||
|
|
||||||
referenced.log_missing(missing)
|
referenced.log_missing(missing)
|
||||||
|
|
||||||
entities: list[Entity] = []
|
entities: list[Entity] = []
|
||||||
|
|
||||||
for entity in entity_candidates:
|
for entity in entity_candidates:
|
||||||
if not entity.available:
|
if not entity.available:
|
||||||
continue
|
continue
|
||||||
|
|
Loading…
Reference in New Issue