Speed up fetching states by domain (#99467)
parent
b752419f25
commit
186e796e25
|
@ -1261,7 +1261,7 @@ class State:
|
|||
"State max length is 255 characters."
|
||||
)
|
||||
|
||||
self.entity_id = entity_id.lower()
|
||||
self.entity_id = entity_id
|
||||
self.state = state
|
||||
self.attributes = ReadOnlyDict(attributes or {})
|
||||
self.last_updated = last_updated or dt_util.utcnow()
|
||||
|
@ -1412,11 +1412,12 @@ class State:
|
|||
class StateMachine:
|
||||
"""Helper class that tracks the state of different entities."""
|
||||
|
||||
__slots__ = ("_states", "_reservations", "_bus", "_loop")
|
||||
__slots__ = ("_states", "_domain_index", "_reservations", "_bus", "_loop")
|
||||
|
||||
def __init__(self, bus: EventBus, loop: asyncio.events.AbstractEventLoop) -> None:
|
||||
"""Initialize state machine."""
|
||||
self._states: dict[str, State] = {}
|
||||
self._domain_index: dict[str, dict[str, State]] = {}
|
||||
self._reservations: set[str] = set()
|
||||
self._bus = bus
|
||||
self._loop = loop
|
||||
|
@ -1440,13 +1441,13 @@ class StateMachine:
|
|||
return list(self._states)
|
||||
|
||||
if isinstance(domain_filter, str):
|
||||
domain_filter = (domain_filter.lower(),)
|
||||
return list(self._domain_index.get(domain_filter.lower(), ()))
|
||||
|
||||
return [
|
||||
state.entity_id
|
||||
for state in self._states.values()
|
||||
if state.domain in domain_filter
|
||||
]
|
||||
states: list[str] = []
|
||||
for domain in domain_filter:
|
||||
if domain_index := self._domain_index.get(domain):
|
||||
states.extend(domain_index)
|
||||
return states
|
||||
|
||||
@callback
|
||||
def async_entity_ids_count(
|
||||
|
@ -1460,11 +1461,9 @@ class StateMachine:
|
|||
return len(self._states)
|
||||
|
||||
if isinstance(domain_filter, str):
|
||||
domain_filter = (domain_filter.lower(),)
|
||||
return len(self._domain_index.get(domain_filter.lower(), ()))
|
||||
|
||||
return len(
|
||||
[None for state in self._states.values() if state.domain in domain_filter]
|
||||
)
|
||||
return sum(len(self._domain_index.get(domain, ())) for domain in domain_filter)
|
||||
|
||||
def all(self, domain_filter: str | Iterable[str] | None = None) -> list[State]:
|
||||
"""Create a list of all states."""
|
||||
|
@ -1484,11 +1483,13 @@ class StateMachine:
|
|||
return list(self._states.values())
|
||||
|
||||
if isinstance(domain_filter, str):
|
||||
domain_filter = (domain_filter.lower(),)
|
||||
return list(self._domain_index.get(domain_filter.lower(), {}).values())
|
||||
|
||||
return [
|
||||
state for state in self._states.values() if state.domain in domain_filter
|
||||
]
|
||||
states: list[State] = []
|
||||
for domain in domain_filter:
|
||||
if domain_index := self._domain_index.get(domain):
|
||||
states.extend(domain_index.values())
|
||||
return states
|
||||
|
||||
def get(self, entity_id: str) -> State | None:
|
||||
"""Retrieve state of entity_id or None if not found.
|
||||
|
@ -1524,13 +1525,12 @@ class StateMachine:
|
|||
"""
|
||||
entity_id = entity_id.lower()
|
||||
old_state = self._states.pop(entity_id, None)
|
||||
|
||||
if entity_id in self._reservations:
|
||||
self._reservations.remove(entity_id)
|
||||
self._reservations.discard(entity_id)
|
||||
|
||||
if old_state is None:
|
||||
return False
|
||||
|
||||
self._domain_index[old_state.domain].pop(entity_id)
|
||||
old_state.expire()
|
||||
self._bus.async_fire(
|
||||
EVENT_STATE_CHANGED,
|
||||
|
@ -1652,6 +1652,10 @@ class StateMachine:
|
|||
if old_state is not None:
|
||||
old_state.expire()
|
||||
self._states[entity_id] = state
|
||||
if not (domain_index := self._domain_index.get(state.domain)):
|
||||
domain_index = {}
|
||||
self._domain_index[state.domain] = domain_index
|
||||
domain_index[entity_id] = state
|
||||
self._bus.async_fire(
|
||||
EVENT_STATE_CHANGED,
|
||||
{"entity_id": entity_id, "old_state": old_state, "new_state": state},
|
||||
|
|
|
@ -1938,6 +1938,7 @@ async def test_async_entity_ids_count(hass: HomeAssistant) -> None:
|
|||
|
||||
assert hass.states.async_entity_ids_count() == 5
|
||||
assert hass.states.async_entity_ids_count("light") == 3
|
||||
assert hass.states.async_entity_ids_count({"light", "vacuum"}) == 4
|
||||
|
||||
|
||||
async def test_hassjob_forbid_coroutine() -> None:
|
||||
|
|
Loading…
Reference in New Issue