Reduce overhead to call entity services (#106908)

pull/105955/head
J. Nick Koston 2024-01-07 22:30:52 -10:00 committed by GitHub
parent 9ad3c8dbc9
commit d260ed938a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 59 additions and 36 deletions

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable, Iterable
from datetime import timedelta
from functools import partial
from itertools import chain
import logging
from types import ModuleType
@ -20,8 +21,8 @@ from homeassistant.const import (
EVENT_HOMEASSISTANT_STOP,
)
from homeassistant.core import (
EntityServiceResponse,
Event,
HassJob,
HomeAssistant,
ServiceCall,
ServiceResponse,
@ -225,13 +226,16 @@ class EntityComponent(Generic[_EntityT]):
if isinstance(schema, dict):
schema = cv.make_entity_service_schema(schema)
service_func: str | HassJob[..., Any]
service_func = func if isinstance(func, str) else HassJob(func)
async def handle_service(
call: ServiceCall,
) -> ServiceResponse:
"""Handle the service."""
result = await service.entity_service_call(
self.hass, self._entities, func, call, required_features
self.hass, self._entities, service_func, call, required_features
)
if result:
@ -259,16 +263,21 @@ class EntityComponent(Generic[_EntityT]):
if isinstance(schema, dict):
schema = cv.make_entity_service_schema(schema)
async def handle_service(
call: ServiceCall,
) -> EntityServiceResponse | None:
"""Handle the service."""
return await service.entity_service_call(
self.hass, self._entities, func, call, required_features
)
service_func: str | HassJob[..., Any]
service_func = func if isinstance(func, str) else HassJob(func)
self.hass.services.async_register(
self.domain, name, handle_service, schema, supports_response
self.domain,
name,
partial(
service.entity_service_call,
self.hass,
self._entities,
service_func,
required_features=required_features,
),
schema,
supports_response,
)
async def async_setup_platform(

View File

@ -5,6 +5,7 @@ import asyncio
from collections.abc import Awaitable, Callable, Coroutine, Iterable
from contextvars import ContextVar
from datetime import datetime, timedelta
from functools import partial
from logging import Logger, getLogger
from typing import TYPE_CHECKING, Any, Protocol
@ -20,7 +21,7 @@ from homeassistant.core import (
CALLBACK_TYPE,
DOMAIN as HOMEASSISTANT_DOMAIN,
CoreState,
EntityServiceResponse,
HassJob,
HomeAssistant,
ServiceCall,
SupportsResponse,
@ -833,18 +834,21 @@ class EntityPlatform:
if isinstance(schema, dict):
schema = cv.make_entity_service_schema(schema)
async def handle_service(call: ServiceCall) -> EntityServiceResponse | None:
"""Handle the service."""
return await service.entity_service_call(
self.hass,
self.domain_entities,
func,
call,
required_features,
)
service_func: str | HassJob[..., Any]
service_func = func if isinstance(func, str) else HassJob(func)
self.hass.services.async_register(
self.platform_name, name, handle_service, schema, supports_response
self.platform_name,
name,
partial(
service.entity_service_call,
self.hass,
self.domain_entities,
service_func,
required_features=required_features,
),
schema,
supports_response,
)
async def _update_entity_states(self, now: datetime) -> None:

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable, Coroutine, Iterable
from collections.abc import Awaitable, Callable, Iterable
import dataclasses
from enum import Enum
from functools import cache, partial, wraps
@ -29,6 +29,7 @@ from homeassistant.const import (
from homeassistant.core import (
Context,
EntityServiceResponse,
HassJob,
HomeAssistant,
ServiceCall,
ServiceResponse,
@ -191,11 +192,14 @@ class ServiceParams(TypedDict):
class ServiceTargetSelector:
"""Class to hold a target selector for a service."""
__slots__ = ("entity_ids", "device_ids", "area_ids")
def __init__(self, service_call: ServiceCall) -> None:
"""Extract ids from service call data."""
entity_ids: str | list | None = service_call.data.get(ATTR_ENTITY_ID)
device_ids: str | list | None = service_call.data.get(ATTR_DEVICE_ID)
area_ids: str | list | None = service_call.data.get(ATTR_AREA_ID)
service_call_data = service_call.data
entity_ids: str | list | None = service_call_data.get(ATTR_ENTITY_ID)
device_ids: str | list | None = service_call_data.get(ATTR_DEVICE_ID)
area_ids: str | list | None = service_call_data.get(ATTR_AREA_ID)
self.entity_ids = (
set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set()
@ -790,7 +794,7 @@ def _get_permissible_entity_candidates(
async def entity_service_call(
hass: HomeAssistant,
registered_entities: dict[str, Entity],
func: str | Callable[..., Coroutine[Any, Any, ServiceResponse]],
func: str | HassJob,
call: ServiceCall,
required_features: Iterable[int] | None = None,
) -> EntityServiceResponse | None:
@ -926,7 +930,7 @@ async def entity_service_call(
async def _handle_entity_call(
hass: HomeAssistant,
entity: Entity,
func: str | Callable[..., Coroutine[Any, Any, ServiceResponse]],
func: str | HassJob,
data: dict | ServiceCall,
context: Context,
) -> ServiceResponse:
@ -935,11 +939,11 @@ async def _handle_entity_call(
task: asyncio.Future[ServiceResponse] | None
if isinstance(func, str):
task = hass.async_run_job(
partial(getattr(entity, func), **data) # type: ignore[arg-type]
task = hass.async_run_hass_job(
HassJob(partial(getattr(entity, func), **data)) # type: ignore[arg-type]
)
else:
task = hass.async_run_job(func, entity, data)
task = hass.async_run_hass_job(func, entity, data)
# Guard because callback functions do not return a task when passed to
# async_run_job.

View File

@ -19,7 +19,13 @@ from homeassistant.const import (
STATE_ON,
EntityCategory,
)
from homeassistant.core import Context, HomeAssistant, ServiceCall, SupportsResponse
from homeassistant.core import (
Context,
HassJob,
HomeAssistant,
ServiceCall,
SupportsResponse,
)
from homeassistant.helpers import (
device_registry as dr,
entity_registry as er,
@ -803,7 +809,7 @@ async def test_call_with_required_features(hass: HomeAssistant, mock_entities) -
await service.entity_service_call(
hass,
mock_entities,
test_service_mock,
HassJob(test_service_mock),
ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
required_features=[SUPPORT_A],
)
@ -822,7 +828,7 @@ async def test_call_with_required_features(hass: HomeAssistant, mock_entities) -
await service.entity_service_call(
hass,
mock_entities,
test_service_mock,
HassJob(test_service_mock),
ServiceCall(
"test_domain", "test_service", {"entity_id": "light.living_room"}
),
@ -839,7 +845,7 @@ async def test_call_with_both_required_features(
await service.entity_service_call(
hass,
mock_entities,
test_service_mock,
HassJob(test_service_mock),
ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
required_features=[SUPPORT_A | SUPPORT_B],
)
@ -858,7 +864,7 @@ async def test_call_with_one_of_required_features(
await service.entity_service_call(
hass,
mock_entities,
test_service_mock,
HassJob(test_service_mock),
ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
required_features=[SUPPORT_A, SUPPORT_C],
)
@ -879,7 +885,7 @@ async def test_call_with_sync_func(hass: HomeAssistant, mock_entities) -> None:
await service.entity_service_call(
hass,
mock_entities,
test_service_mock,
HassJob(test_service_mock),
ServiceCall("test_domain", "test_service", {"entity_id": "light.kitchen"}),
)
assert test_service_mock.call_count == 1