Speed up fetching states by domain (#99467)
parent
b752419f25
commit
186e796e25
|
@ -1261,7 +1261,7 @@ class State:
|
||||||
"State max length is 255 characters."
|
"State max length is 255 characters."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.entity_id = entity_id.lower()
|
self.entity_id = entity_id
|
||||||
self.state = state
|
self.state = state
|
||||||
self.attributes = ReadOnlyDict(attributes or {})
|
self.attributes = ReadOnlyDict(attributes or {})
|
||||||
self.last_updated = last_updated or dt_util.utcnow()
|
self.last_updated = last_updated or dt_util.utcnow()
|
||||||
|
@ -1412,11 +1412,12 @@ class State:
|
||||||
class StateMachine:
|
class StateMachine:
|
||||||
"""Helper class that tracks the state of different entities."""
|
"""Helper class that tracks the state of different entities."""
|
||||||
|
|
||||||
__slots__ = ("_states", "_reservations", "_bus", "_loop")
|
__slots__ = ("_states", "_domain_index", "_reservations", "_bus", "_loop")
|
||||||
|
|
||||||
def __init__(self, bus: EventBus, loop: asyncio.events.AbstractEventLoop) -> None:
|
def __init__(self, bus: EventBus, loop: asyncio.events.AbstractEventLoop) -> None:
|
||||||
"""Initialize state machine."""
|
"""Initialize state machine."""
|
||||||
self._states: dict[str, State] = {}
|
self._states: dict[str, State] = {}
|
||||||
|
self._domain_index: dict[str, dict[str, State]] = {}
|
||||||
self._reservations: set[str] = set()
|
self._reservations: set[str] = set()
|
||||||
self._bus = bus
|
self._bus = bus
|
||||||
self._loop = loop
|
self._loop = loop
|
||||||
|
@ -1440,13 +1441,13 @@ class StateMachine:
|
||||||
return list(self._states)
|
return list(self._states)
|
||||||
|
|
||||||
if isinstance(domain_filter, str):
|
if isinstance(domain_filter, str):
|
||||||
domain_filter = (domain_filter.lower(),)
|
return list(self._domain_index.get(domain_filter.lower(), ()))
|
||||||
|
|
||||||
return [
|
states: list[str] = []
|
||||||
state.entity_id
|
for domain in domain_filter:
|
||||||
for state in self._states.values()
|
if domain_index := self._domain_index.get(domain):
|
||||||
if state.domain in domain_filter
|
states.extend(domain_index)
|
||||||
]
|
return states
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_entity_ids_count(
|
def async_entity_ids_count(
|
||||||
|
@ -1460,11 +1461,9 @@ class StateMachine:
|
||||||
return len(self._states)
|
return len(self._states)
|
||||||
|
|
||||||
if isinstance(domain_filter, str):
|
if isinstance(domain_filter, str):
|
||||||
domain_filter = (domain_filter.lower(),)
|
return len(self._domain_index.get(domain_filter.lower(), ()))
|
||||||
|
|
||||||
return len(
|
return sum(len(self._domain_index.get(domain, ())) for domain in domain_filter)
|
||||||
[None for state in self._states.values() if state.domain in domain_filter]
|
|
||||||
)
|
|
||||||
|
|
||||||
def all(self, domain_filter: str | Iterable[str] | None = None) -> list[State]:
|
def all(self, domain_filter: str | Iterable[str] | None = None) -> list[State]:
|
||||||
"""Create a list of all states."""
|
"""Create a list of all states."""
|
||||||
|
@ -1484,11 +1483,13 @@ class StateMachine:
|
||||||
return list(self._states.values())
|
return list(self._states.values())
|
||||||
|
|
||||||
if isinstance(domain_filter, str):
|
if isinstance(domain_filter, str):
|
||||||
domain_filter = (domain_filter.lower(),)
|
return list(self._domain_index.get(domain_filter.lower(), {}).values())
|
||||||
|
|
||||||
return [
|
states: list[State] = []
|
||||||
state for state in self._states.values() if state.domain in domain_filter
|
for domain in domain_filter:
|
||||||
]
|
if domain_index := self._domain_index.get(domain):
|
||||||
|
states.extend(domain_index.values())
|
||||||
|
return states
|
||||||
|
|
||||||
def get(self, entity_id: str) -> State | None:
|
def get(self, entity_id: str) -> State | None:
|
||||||
"""Retrieve state of entity_id or None if not found.
|
"""Retrieve state of entity_id or None if not found.
|
||||||
|
@ -1524,13 +1525,12 @@ class StateMachine:
|
||||||
"""
|
"""
|
||||||
entity_id = entity_id.lower()
|
entity_id = entity_id.lower()
|
||||||
old_state = self._states.pop(entity_id, None)
|
old_state = self._states.pop(entity_id, None)
|
||||||
|
self._reservations.discard(entity_id)
|
||||||
if entity_id in self._reservations:
|
|
||||||
self._reservations.remove(entity_id)
|
|
||||||
|
|
||||||
if old_state is None:
|
if old_state is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
self._domain_index[old_state.domain].pop(entity_id)
|
||||||
old_state.expire()
|
old_state.expire()
|
||||||
self._bus.async_fire(
|
self._bus.async_fire(
|
||||||
EVENT_STATE_CHANGED,
|
EVENT_STATE_CHANGED,
|
||||||
|
@ -1652,6 +1652,10 @@ class StateMachine:
|
||||||
if old_state is not None:
|
if old_state is not None:
|
||||||
old_state.expire()
|
old_state.expire()
|
||||||
self._states[entity_id] = state
|
self._states[entity_id] = state
|
||||||
|
if not (domain_index := self._domain_index.get(state.domain)):
|
||||||
|
domain_index = {}
|
||||||
|
self._domain_index[state.domain] = domain_index
|
||||||
|
domain_index[entity_id] = state
|
||||||
self._bus.async_fire(
|
self._bus.async_fire(
|
||||||
EVENT_STATE_CHANGED,
|
EVENT_STATE_CHANGED,
|
||||||
{"entity_id": entity_id, "old_state": old_state, "new_state": state},
|
{"entity_id": entity_id, "old_state": old_state, "new_state": state},
|
||||||
|
|
|
@ -1938,6 +1938,7 @@ async def test_async_entity_ids_count(hass: HomeAssistant) -> None:
|
||||||
|
|
||||||
assert hass.states.async_entity_ids_count() == 5
|
assert hass.states.async_entity_ids_count() == 5
|
||||||
assert hass.states.async_entity_ids_count("light") == 3
|
assert hass.states.async_entity_ids_count("light") == 3
|
||||||
|
assert hass.states.async_entity_ids_count({"light", "vacuum"}) == 4
|
||||||
|
|
||||||
|
|
||||||
async def test_hassjob_forbid_coroutine() -> None:
|
async def test_hassjob_forbid_coroutine() -> None:
|
||||||
|
|
Loading…
Reference in New Issue