Speed up fetching states by domain (#99467)

pull/92668/head
J. Nick Koston 2023-09-03 09:30:39 -05:00 committed by GitHub
parent b752419f25
commit 186e796e25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 19 deletions

View File

@ -1261,7 +1261,7 @@ class State:
"State max length is 255 characters." "State max length is 255 characters."
) )
self.entity_id = entity_id.lower() self.entity_id = entity_id
self.state = state self.state = state
self.attributes = ReadOnlyDict(attributes or {}) self.attributes = ReadOnlyDict(attributes or {})
self.last_updated = last_updated or dt_util.utcnow() self.last_updated = last_updated or dt_util.utcnow()
@ -1412,11 +1412,12 @@ class State:
class StateMachine: class StateMachine:
"""Helper class that tracks the state of different entities.""" """Helper class that tracks the state of different entities."""
__slots__ = ("_states", "_reservations", "_bus", "_loop") __slots__ = ("_states", "_domain_index", "_reservations", "_bus", "_loop")
def __init__(self, bus: EventBus, loop: asyncio.events.AbstractEventLoop) -> None: def __init__(self, bus: EventBus, loop: asyncio.events.AbstractEventLoop) -> None:
"""Initialize state machine.""" """Initialize state machine."""
self._states: dict[str, State] = {} self._states: dict[str, State] = {}
self._domain_index: dict[str, dict[str, State]] = {}
self._reservations: set[str] = set() self._reservations: set[str] = set()
self._bus = bus self._bus = bus
self._loop = loop self._loop = loop
@ -1440,13 +1441,13 @@ class StateMachine:
return list(self._states) return list(self._states)
if isinstance(domain_filter, str): if isinstance(domain_filter, str):
domain_filter = (domain_filter.lower(),) return list(self._domain_index.get(domain_filter.lower(), ()))
return [ states: list[str] = []
state.entity_id for domain in domain_filter:
for state in self._states.values() if domain_index := self._domain_index.get(domain):
if state.domain in domain_filter states.extend(domain_index)
] return states
@callback @callback
def async_entity_ids_count( def async_entity_ids_count(
@ -1460,11 +1461,9 @@ class StateMachine:
return len(self._states) return len(self._states)
if isinstance(domain_filter, str): if isinstance(domain_filter, str):
domain_filter = (domain_filter.lower(),) return len(self._domain_index.get(domain_filter.lower(), ()))
return len( return sum(len(self._domain_index.get(domain, ())) for domain in domain_filter)
[None for state in self._states.values() if state.domain in domain_filter]
)
def all(self, domain_filter: str | Iterable[str] | None = None) -> list[State]: def all(self, domain_filter: str | Iterable[str] | None = None) -> list[State]:
"""Create a list of all states.""" """Create a list of all states."""
@ -1484,11 +1483,13 @@ class StateMachine:
return list(self._states.values()) return list(self._states.values())
if isinstance(domain_filter, str): if isinstance(domain_filter, str):
domain_filter = (domain_filter.lower(),) return list(self._domain_index.get(domain_filter.lower(), {}).values())
return [ states: list[State] = []
state for state in self._states.values() if state.domain in domain_filter for domain in domain_filter:
] if domain_index := self._domain_index.get(domain):
states.extend(domain_index.values())
return states
def get(self, entity_id: str) -> State | None: def get(self, entity_id: str) -> State | None:
"""Retrieve state of entity_id or None if not found. """Retrieve state of entity_id or None if not found.
@ -1524,13 +1525,12 @@ class StateMachine:
""" """
entity_id = entity_id.lower() entity_id = entity_id.lower()
old_state = self._states.pop(entity_id, None) old_state = self._states.pop(entity_id, None)
self._reservations.discard(entity_id)
if entity_id in self._reservations:
self._reservations.remove(entity_id)
if old_state is None: if old_state is None:
return False return False
self._domain_index[old_state.domain].pop(entity_id)
old_state.expire() old_state.expire()
self._bus.async_fire( self._bus.async_fire(
EVENT_STATE_CHANGED, EVENT_STATE_CHANGED,
@ -1652,6 +1652,10 @@ class StateMachine:
if old_state is not None: if old_state is not None:
old_state.expire() old_state.expire()
self._states[entity_id] = state self._states[entity_id] = state
if not (domain_index := self._domain_index.get(state.domain)):
domain_index = {}
self._domain_index[state.domain] = domain_index
domain_index[entity_id] = state
self._bus.async_fire( self._bus.async_fire(
EVENT_STATE_CHANGED, EVENT_STATE_CHANGED,
{"entity_id": entity_id, "old_state": old_state, "new_state": state}, {"entity_id": entity_id, "old_state": old_state, "new_state": state},

View File

@ -1938,6 +1938,7 @@ async def test_async_entity_ids_count(hass: HomeAssistant) -> None:
assert hass.states.async_entity_ids_count() == 5 assert hass.states.async_entity_ids_count() == 5
assert hass.states.async_entity_ids_count("light") == 3 assert hass.states.async_entity_ids_count("light") == 3
assert hass.states.async_entity_ids_count({"light", "vacuum"}) == 4
async def test_hassjob_forbid_coroutine() -> None: async def test_hassjob_forbid_coroutine() -> None: