Add domain filter support to async_all to match async_entity_ids (#39725)

This avoids copying all the states before applying
the filter
pull/39728/head
J. Nick Koston 2020-09-06 16:20:32 -05:00 committed by GitHub
parent 19818d96b7
commit 251d8919ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 45 additions and 16 deletions

View File

@ -41,8 +41,7 @@ class HumidityHandler(intent.IntentHandler):
hass = intent_obj.hass
slots = self.async_validate_slots(intent_obj.slots)
state = hass.helpers.intent.async_match_state(
slots["name"]["value"],
[state for state in hass.states.async_all() if state.domain == DOMAIN],
slots["name"]["value"], hass.states.async_all(DOMAIN)
)
service_data = {ATTR_ENTITY_ID: state.entity_id}
@ -87,7 +86,7 @@ class SetModeHandler(intent.IntentHandler):
slots = self.async_validate_slots(intent_obj.slots)
state = hass.helpers.intent.async_match_state(
slots["name"]["value"],
[state for state in hass.states.async_all() if state.domain == DOMAIN],
hass.states.async_all(DOMAIN),
)
service_data = {ATTR_ENTITY_ID: state.entity_id}

View File

@ -39,8 +39,7 @@ class SetIntentHandler(intent.IntentHandler):
hass = intent_obj.hass
slots = self.async_validate_slots(intent_obj.slots)
state = hass.helpers.intent.async_match_state(
slots["name"]["value"],
[state for state in hass.states.async_all() if state.domain == DOMAIN],
slots["name"]["value"], hass.states.async_all(DOMAIN)
)
service_data = {ATTR_ENTITY_ID: state.entity_id}

View File

@ -183,10 +183,7 @@ async def handle_webhook(hass, webhook_id, request):
response = []
for person in hass.states.async_all():
if person.domain != "person":
continue
for person in hass.states.async_all("person"):
if "latitude" in person.attributes and "longitude" in person.attributes:
response.append(
{

View File

@ -918,17 +918,29 @@ class StateMachine:
if state.domain in domain_filter
]
def all(self) -> List[State]:
def all(self, domain_filter: Optional[Union[str, Iterable]] = None) -> List[State]:
"""Create a list of all states."""
return run_callback_threadsafe(self._loop, self.async_all).result()
return run_callback_threadsafe(
self._loop, self.async_all, domain_filter
).result()
@callback
def async_all(self) -> List[State]:
"""Create a list of all states.
def async_all(
self, domain_filter: Optional[Union[str, Iterable]] = None
) -> List[State]:
"""Create a list of all states matching the filter.
This method must be run in the event loop.
"""
return list(self._states.values())
if domain_filter is None:
return list(self._states.values())
if isinstance(domain_filter, str):
domain_filter = (domain_filter.lower(),)
return [
state for state in self._states.values() if state.domain in domain_filter
]
def get(self, entity_id: str) -> Optional[State]:
"""Retrieve state of entity_id or None if not found.

View File

@ -459,8 +459,7 @@ class DomainStates:
sorted(
(
_wrap_state(self._hass, state)
for state in self._hass.states.async_all()
if state.domain == self._domain
for state in self._hass.states.async_all(self._domain)
),
key=lambda state: state.entity_id,
)

View File

@ -1454,3 +1454,26 @@ async def test_chained_logging_misses_log_timeout(hass, caplog):
await hass.async_block_till_done()
assert "_task_chain_" not in caplog.text
async def test_async_all(hass):
"""Test async_all."""
hass.states.async_set("switch.link", "on")
hass.states.async_set("light.bowl", "on")
hass.states.async_set("light.frog", "on")
hass.states.async_set("vacuum.floor", "on")
assert {state.entity_id for state in hass.states.async_all()} == {
"switch.link",
"light.bowl",
"light.frog",
"vacuum.floor",
}
assert {state.entity_id for state in hass.states.async_all("light")} == {
"light.bowl",
"light.frog",
}
assert {
state.entity_id for state in hass.states.async_all(["light", "switch"])
} == {"light.bowl", "light.frog", "switch.link"}