Avoid enumerating the whole state machine on api service calls (#103147)

pull/103166/head
J. Nick Koston 2023-11-01 04:25:02 -05:00 committed by GitHub
parent daee5baef6
commit 78e546b35a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 20 additions and 7 deletions

View File

@ -1,9 +1,11 @@
"""Rest API for Home Assistant."""
import asyncio
from asyncio import timeout
from collections.abc import Collection
from functools import lru_cache
from http import HTTPStatus
import logging
from typing import Any
from aiohttp import web
from aiohttp.web_exceptions import HTTPBadRequest
@ -16,6 +18,7 @@ from homeassistant.components.http import HomeAssistantView, require_admin
from homeassistant.const import (
CONTENT_TYPE_JSON,
EVENT_HOMEASSISTANT_STOP,
EVENT_STATE_CHANGED,
MATCH_ALL,
URL_API,
URL_API_COMPONENTS,
@ -38,10 +41,12 @@ from homeassistant.exceptions import (
Unauthorized,
)
from homeassistant.helpers import config_validation as cv, template
from homeassistant.helpers.event import EventStateChangedData
from homeassistant.helpers.json import json_dumps
from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.helpers.typing import ConfigType
from homeassistant.helpers.typing import ConfigType, EventType
from homeassistant.util.json import json_loads
from homeassistant.util.read_only_dict import ReadOnlyDict
_LOGGER = logging.getLogger(__name__)
@ -369,6 +374,18 @@ class APIDomainServicesView(HomeAssistantView):
)
context = self.context(request)
changed_states: list[ReadOnlyDict[str, Collection[Any]]] = []
@ha.callback
def _async_save_changed_entities(
event: EventType[EventStateChangedData],
) -> None:
if event.context == context and (state := event.data["new_state"]):
changed_states.append(state.as_dict())
cancel_listen = hass.bus.async_listen(
EVENT_STATE_CHANGED, _async_save_changed_entities, run_immediately=True
)
try:
await hass.services.async_call(
@ -376,12 +393,8 @@ class APIDomainServicesView(HomeAssistantView):
)
except (vol.Invalid, ServiceNotFound) as ex:
raise HTTPBadRequest() from ex
changed_states = []
for state in hass.states.async_all():
if state.context is context:
changed_states.append(state)
finally:
cancel_listen()
return self.json(changed_states)