Reduce overhead to call entity services (#106908)
parent
9ad3c8dbc9
commit
d260ed938a
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
from collections.abc import Callable, Iterable
|
||||
from datetime import timedelta
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
import logging
|
||||
from types import ModuleType
|
||||
|
@ -20,8 +21,8 @@ from homeassistant.const import (
|
|||
EVENT_HOMEASSISTANT_STOP,
|
||||
)
|
||||
from homeassistant.core import (
|
||||
EntityServiceResponse,
|
||||
Event,
|
||||
HassJob,
|
||||
HomeAssistant,
|
||||
ServiceCall,
|
||||
ServiceResponse,
|
||||
|
@ -225,13 +226,16 @@ class EntityComponent(Generic[_EntityT]):
|
|||
if isinstance(schema, dict):
|
||||
schema = cv.make_entity_service_schema(schema)
|
||||
|
||||
service_func: str | HassJob[..., Any]
|
||||
service_func = func if isinstance(func, str) else HassJob(func)
|
||||
|
||||
async def handle_service(
|
||||
call: ServiceCall,
|
||||
) -> ServiceResponse:
|
||||
"""Handle the service."""
|
||||
|
||||
result = await service.entity_service_call(
|
||||
self.hass, self._entities, func, call, required_features
|
||||
self.hass, self._entities, service_func, call, required_features
|
||||
)
|
||||
|
||||
if result:
|
||||
|
@ -259,16 +263,21 @@ class EntityComponent(Generic[_EntityT]):
|
|||
if isinstance(schema, dict):
|
||||
schema = cv.make_entity_service_schema(schema)
|
||||
|
||||
async def handle_service(
|
||||
call: ServiceCall,
|
||||
) -> EntityServiceResponse | None:
|
||||
"""Handle the service."""
|
||||
return await service.entity_service_call(
|
||||
self.hass, self._entities, func, call, required_features
|
||||
)
|
||||
service_func: str | HassJob[..., Any]
|
||||
service_func = func if isinstance(func, str) else HassJob(func)
|
||||
|
||||
self.hass.services.async_register(
|
||||
self.domain, name, handle_service, schema, supports_response
|
||||
self.domain,
|
||||
name,
|
||||
partial(
|
||||
service.entity_service_call,
|
||||
self.hass,
|
||||
self._entities,
|
||||
service_func,
|
||||
required_features=required_features,
|
||||
),
|
||||
schema,
|
||||
supports_response,
|
||||
)
|
||||
|
||||
async def async_setup_platform(
|
||||
|
|
|
@ -5,6 +5,7 @@ import asyncio
|
|||
from collections.abc import Awaitable, Callable, Coroutine, Iterable
|
||||
from contextvars import ContextVar
|
||||
from datetime import datetime, timedelta
|
||||
from functools import partial
|
||||
from logging import Logger, getLogger
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
|
||||
|
@ -20,7 +21,7 @@ from homeassistant.core import (
|
|||
CALLBACK_TYPE,
|
||||
DOMAIN as HOMEASSISTANT_DOMAIN,
|
||||
CoreState,
|
||||
EntityServiceResponse,
|
||||
HassJob,
|
||||
HomeAssistant,
|
||||
ServiceCall,
|
||||
SupportsResponse,
|
||||
|
@ -833,18 +834,21 @@ class EntityPlatform:
|
|||
if isinstance(schema, dict):
|
||||
schema = cv.make_entity_service_schema(schema)
|
||||
|
||||
async def handle_service(call: ServiceCall) -> EntityServiceResponse | None:
|
||||
"""Handle the service."""
|
||||
return await service.entity_service_call(
|
||||
self.hass,
|
||||
self.domain_entities,
|
||||
func,
|
||||
call,
|
||||
required_features,
|
||||
)
|
||||
service_func: str | HassJob[..., Any]
|
||||
service_func = func if isinstance(func, str) else HassJob(func)
|
||||
|
||||
self.hass.services.async_register(
|
||||
self.platform_name, name, handle_service, schema, supports_response
|
||||
self.platform_name,
|
||||
name,
|
||||
partial(
|
||||
service.entity_service_call,
|
||||
self.hass,
|
||||
self.domain_entities,
|
||||
service_func,
|
||||
required_features=required_features,
|
||||
),
|
||||
schema,
|
||||
supports_response,
|
||||
)
|
||||
|
||||
async def _update_entity_states(self, now: datetime) -> None:
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable, Coroutine, Iterable
|
||||
from collections.abc import Awaitable, Callable, Iterable
|
||||
import dataclasses
|
||||
from enum import Enum
|
||||
from functools import cache, partial, wraps
|
||||
|
@ -29,6 +29,7 @@ from homeassistant.const import (
|
|||
from homeassistant.core import (
|
||||
Context,
|
||||
EntityServiceResponse,
|
||||
HassJob,
|
||||
HomeAssistant,
|
||||
ServiceCall,
|
||||
ServiceResponse,
|
||||
|
@ -191,11 +192,14 @@ class ServiceParams(TypedDict):
|
|||
class ServiceTargetSelector:
|
||||
"""Class to hold a target selector for a service."""
|
||||
|
||||
__slots__ = ("entity_ids", "device_ids", "area_ids")
|
||||
|
||||
def __init__(self, service_call: ServiceCall) -> None:
|
||||
"""Extract ids from service call data."""
|
||||
entity_ids: str | list | None = service_call.data.get(ATTR_ENTITY_ID)
|
||||
device_ids: str | list | None = service_call.data.get(ATTR_DEVICE_ID)
|
||||
area_ids: str | list | None = service_call.data.get(ATTR_AREA_ID)
|
||||
service_call_data = service_call.data
|
||||
entity_ids: str | list | None = service_call_data.get(ATTR_ENTITY_ID)
|
||||
device_ids: str | list | None = service_call_data.get(ATTR_DEVICE_ID)
|
||||
area_ids: str | list | None = service_call_data.get(ATTR_AREA_ID)
|
||||
|
||||
self.entity_ids = (
|
||||
set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set()
|
||||
|
@ -790,7 +794,7 @@ def _get_permissible_entity_candidates(
|
|||
async def entity_service_call(
|
||||
hass: HomeAssistant,
|
||||
registered_entities: dict[str, Entity],
|
||||
func: str | Callable[..., Coroutine[Any, Any, ServiceResponse]],
|
||||
func: str | HassJob,
|
||||
call: ServiceCall,
|
||||
required_features: Iterable[int] | None = None,
|
||||
) -> EntityServiceResponse | None:
|
||||
|
@ -926,7 +930,7 @@ async def entity_service_call(
|
|||
async def _handle_entity_call(
|
||||
hass: HomeAssistant,
|
||||
entity: Entity,
|
||||
func: str | Callable[..., Coroutine[Any, Any, ServiceResponse]],
|
||||
func: str | HassJob,
|
||||
data: dict | ServiceCall,
|
||||
context: Context,
|
||||
) -> ServiceResponse:
|
||||
|
@ -935,11 +939,11 @@ async def _handle_entity_call(
|
|||
|
||||
task: asyncio.Future[ServiceResponse] | None
|
||||
if isinstance(func, str):
|
||||
task = hass.async_run_job(
|
||||
partial(getattr(entity, func), **data) # type: ignore[arg-type]
|
||||
task = hass.async_run_hass_job(
|
||||
HassJob(partial(getattr(entity, func), **data)) # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
task = hass.async_run_job(func, entity, data)
|
||||
task = hass.async_run_hass_job(func, entity, data)
|
||||
|
||||
# Guard because callback functions do not return a task when passed to
|
||||
# async_run_job.
|
||||
|
|
|
@ -19,7 +19,13 @@ from homeassistant.const import (
|
|||
STATE_ON,
|
||||
EntityCategory,
|
||||
)
|
||||
from homeassistant.core import Context, HomeAssistant, ServiceCall, SupportsResponse
|
||||
from homeassistant.core import (
|
||||
Context,
|
||||
HassJob,
|
||||
HomeAssistant,
|
||||
ServiceCall,
|
||||
SupportsResponse,
|
||||
)
|
||||
from homeassistant.helpers import (
|
||||
device_registry as dr,
|
||||
entity_registry as er,
|
||||
|
@ -803,7 +809,7 @@ async def test_call_with_required_features(hass: HomeAssistant, mock_entities) -
|
|||
await service.entity_service_call(
|
||||
hass,
|
||||
mock_entities,
|
||||
test_service_mock,
|
||||
HassJob(test_service_mock),
|
||||
ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
|
||||
required_features=[SUPPORT_A],
|
||||
)
|
||||
|
@ -822,7 +828,7 @@ async def test_call_with_required_features(hass: HomeAssistant, mock_entities) -
|
|||
await service.entity_service_call(
|
||||
hass,
|
||||
mock_entities,
|
||||
test_service_mock,
|
||||
HassJob(test_service_mock),
|
||||
ServiceCall(
|
||||
"test_domain", "test_service", {"entity_id": "light.living_room"}
|
||||
),
|
||||
|
@ -839,7 +845,7 @@ async def test_call_with_both_required_features(
|
|||
await service.entity_service_call(
|
||||
hass,
|
||||
mock_entities,
|
||||
test_service_mock,
|
||||
HassJob(test_service_mock),
|
||||
ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
|
||||
required_features=[SUPPORT_A | SUPPORT_B],
|
||||
)
|
||||
|
@ -858,7 +864,7 @@ async def test_call_with_one_of_required_features(
|
|||
await service.entity_service_call(
|
||||
hass,
|
||||
mock_entities,
|
||||
test_service_mock,
|
||||
HassJob(test_service_mock),
|
||||
ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
|
||||
required_features=[SUPPORT_A, SUPPORT_C],
|
||||
)
|
||||
|
@ -879,7 +885,7 @@ async def test_call_with_sync_func(hass: HomeAssistant, mock_entities) -> None:
|
|||
await service.entity_service_call(
|
||||
hass,
|
||||
mock_entities,
|
||||
test_service_mock,
|
||||
HassJob(test_service_mock),
|
||||
ServiceCall("test_domain", "test_service", {"entity_id": "light.kitchen"}),
|
||||
)
|
||||
assert test_service_mock.call_count == 1
|
||||
|
|
Loading…
Reference in New Issue