Add domain filter support to async_all to match async_entity_ids (#39725)
This avoids copying all the states before applying the filterpull/39728/head
parent
19818d96b7
commit
251d8919ea
|
@ -41,8 +41,7 @@ class HumidityHandler(intent.IntentHandler):
|
|||
hass = intent_obj.hass
|
||||
slots = self.async_validate_slots(intent_obj.slots)
|
||||
state = hass.helpers.intent.async_match_state(
|
||||
slots["name"]["value"],
|
||||
[state for state in hass.states.async_all() if state.domain == DOMAIN],
|
||||
slots["name"]["value"], hass.states.async_all(DOMAIN)
|
||||
)
|
||||
|
||||
service_data = {ATTR_ENTITY_ID: state.entity_id}
|
||||
|
@ -87,7 +86,7 @@ class SetModeHandler(intent.IntentHandler):
|
|||
slots = self.async_validate_slots(intent_obj.slots)
|
||||
state = hass.helpers.intent.async_match_state(
|
||||
slots["name"]["value"],
|
||||
[state for state in hass.states.async_all() if state.domain == DOMAIN],
|
||||
hass.states.async_all(DOMAIN),
|
||||
)
|
||||
|
||||
service_data = {ATTR_ENTITY_ID: state.entity_id}
|
||||
|
|
|
@ -39,8 +39,7 @@ class SetIntentHandler(intent.IntentHandler):
|
|||
hass = intent_obj.hass
|
||||
slots = self.async_validate_slots(intent_obj.slots)
|
||||
state = hass.helpers.intent.async_match_state(
|
||||
slots["name"]["value"],
|
||||
[state for state in hass.states.async_all() if state.domain == DOMAIN],
|
||||
slots["name"]["value"], hass.states.async_all(DOMAIN)
|
||||
)
|
||||
|
||||
service_data = {ATTR_ENTITY_ID: state.entity_id}
|
||||
|
|
|
@ -183,10 +183,7 @@ async def handle_webhook(hass, webhook_id, request):
|
|||
|
||||
response = []
|
||||
|
||||
for person in hass.states.async_all():
|
||||
if person.domain != "person":
|
||||
continue
|
||||
|
||||
for person in hass.states.async_all("person"):
|
||||
if "latitude" in person.attributes and "longitude" in person.attributes:
|
||||
response.append(
|
||||
{
|
||||
|
|
|
@ -918,17 +918,29 @@ class StateMachine:
|
|||
if state.domain in domain_filter
|
||||
]
|
||||
|
||||
def all(self) -> List[State]:
|
||||
def all(self, domain_filter: Optional[Union[str, Iterable]] = None) -> List[State]:
|
||||
"""Create a list of all states."""
|
||||
return run_callback_threadsafe(self._loop, self.async_all).result()
|
||||
return run_callback_threadsafe(
|
||||
self._loop, self.async_all, domain_filter
|
||||
).result()
|
||||
|
||||
@callback
|
||||
def async_all(self) -> List[State]:
|
||||
"""Create a list of all states.
|
||||
def async_all(
|
||||
self, domain_filter: Optional[Union[str, Iterable]] = None
|
||||
) -> List[State]:
|
||||
"""Create a list of all states matching the filter.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
return list(self._states.values())
|
||||
if domain_filter is None:
|
||||
return list(self._states.values())
|
||||
|
||||
if isinstance(domain_filter, str):
|
||||
domain_filter = (domain_filter.lower(),)
|
||||
|
||||
return [
|
||||
state for state in self._states.values() if state.domain in domain_filter
|
||||
]
|
||||
|
||||
def get(self, entity_id: str) -> Optional[State]:
|
||||
"""Retrieve state of entity_id or None if not found.
|
||||
|
|
|
@ -459,8 +459,7 @@ class DomainStates:
|
|||
sorted(
|
||||
(
|
||||
_wrap_state(self._hass, state)
|
||||
for state in self._hass.states.async_all()
|
||||
if state.domain == self._domain
|
||||
for state in self._hass.states.async_all(self._domain)
|
||||
),
|
||||
key=lambda state: state.entity_id,
|
||||
)
|
||||
|
|
|
@ -1454,3 +1454,26 @@ async def test_chained_logging_misses_log_timeout(hass, caplog):
|
|||
await hass.async_block_till_done()
|
||||
|
||||
assert "_task_chain_" not in caplog.text
|
||||
|
||||
|
||||
async def test_async_all(hass):
|
||||
"""Test async_all."""
|
||||
|
||||
hass.states.async_set("switch.link", "on")
|
||||
hass.states.async_set("light.bowl", "on")
|
||||
hass.states.async_set("light.frog", "on")
|
||||
hass.states.async_set("vacuum.floor", "on")
|
||||
|
||||
assert {state.entity_id for state in hass.states.async_all()} == {
|
||||
"switch.link",
|
||||
"light.bowl",
|
||||
"light.frog",
|
||||
"vacuum.floor",
|
||||
}
|
||||
assert {state.entity_id for state in hass.states.async_all("light")} == {
|
||||
"light.bowl",
|
||||
"light.frog",
|
||||
}
|
||||
assert {
|
||||
state.entity_id for state in hass.states.async_all(["light", "switch"])
|
||||
} == {"light.bowl", "light.frog", "switch.link"}
|
||||
|
|
Loading…
Reference in New Issue