Search for areas used in automations and scripts (#48499)

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
Co-authored-by: Bram Kragten <mail@bramkragten.nl>
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
pull/48524/head
Bram Kragten 2021-03-31 00:01:56 +02:00 committed by GitHub
parent 309c3a8d82
commit d1a1e70726
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 172 additions and 8 deletions

View File

@ -176,6 +176,37 @@ def devices_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]:
return list(automation_entity.referenced_devices)
@callback
def automations_with_area(hass: HomeAssistant, area_id: str) -> list[str]:
"""Return all automations that reference the area."""
if DOMAIN not in hass.data:
return []
component = hass.data[DOMAIN]
return [
automation_entity.entity_id
for automation_entity in component.entities
if area_id in automation_entity.referenced_areas
]
@callback
def areas_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]:
"""Return all areas in an automation."""
if DOMAIN not in hass.data:
return []
component = hass.data[DOMAIN]
automation_entity = component.get_entity(entity_id)
if automation_entity is None:
return []
return list(automation_entity.referenced_areas)
async def async_setup(hass, config):
"""Set up all automations."""
# Local import to avoid circular import
@ -293,6 +324,11 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
"""Return True if entity is on."""
return self._async_detach_triggers is not None or self._is_enabled
@property
def referenced_areas(self):
"""Return a set of referenced areas."""
return self.action_script.referenced_areas
@property
def referenced_devices(self):
"""Return a set of referenced devices."""

View File

@ -165,6 +165,37 @@ def devices_in_script(hass: HomeAssistant, entity_id: str) -> list[str]:
return list(script_entity.script.referenced_devices)
@callback
def scripts_with_area(hass: HomeAssistant, area_id: str) -> list[str]:
"""Return all scripts that reference the area."""
if DOMAIN not in hass.data:
return []
component = hass.data[DOMAIN]
return [
script_entity.entity_id
for script_entity in component.entities
if area_id in script_entity.script.referenced_areas
]
@callback
def areas_in_script(hass: HomeAssistant, entity_id: str) -> list[str]:
"""Return all areas in a script."""
if DOMAIN not in hass.data:
return []
component = hass.data[DOMAIN]
script_entity = component.get_entity(entity_id)
if script_entity is None:
return []
return list(script_entity.script.referenced_areas)
async def async_setup(hass, config):
"""Load the scripts from the configuration."""
hass.data[DOMAIN] = component = EntityComponent(_LOGGER, DOMAIN, hass)

View File

@ -38,12 +38,12 @@ async def async_setup(hass: HomeAssistant, config: dict):
vol.Required("item_id"): str,
}
)
async def websocket_search_related(hass, connection, msg):
def websocket_search_related(hass, connection, msg):
"""Handle search."""
searcher = Searcher(
hass,
await device_registry.async_get_registry(hass),
await entity_registry.async_get_registry(hass),
device_registry.async_get(hass),
entity_registry.async_get(hass),
)
connection.send_result(
msg["id"], searcher.async_search(msg["item_type"], msg["item_id"])
@ -127,6 +127,12 @@ class Searcher:
):
self._add_or_resolve("entity", entity_entry.entity_id)
for entity_id in script.scripts_with_area(self.hass, area_id):
self._add_or_resolve("entity", entity_id)
for entity_id in automation.automations_with_area(self.hass, area_id):
self._add_or_resolve("entity", entity_id)
@callback
def _resolve_device(self, device_id) -> None:
"""Resolve a device."""
@ -198,6 +204,9 @@ class Searcher:
for device in automation.devices_in_automation(self.hass, automation_entity_id):
self._add_or_resolve("device", device)
for area in automation.areas_in_automation(self.hass, automation_entity_id):
self._add_or_resolve("area", area)
@callback
def _resolve_script(self, script_entity_id) -> None:
"""Resolve a script.
@ -210,6 +219,9 @@ class Searcher:
for device in script.devices_in_script(self.hass, script_entity_id):
self._add_or_resolve("device", device)
for area in script.areas_in_script(self.hass, script_entity_id):
self._add_or_resolve("area", area)
@callback
def _resolve_group(self, group_entity_id) -> None:
"""Resolve a group.

View File

@ -17,6 +17,7 @@ from homeassistant import exceptions
from homeassistant.components import device_automation, scene
from homeassistant.components.logger import LOGSEVERITY
from homeassistant.const import (
ATTR_AREA_ID,
ATTR_DEVICE_ID,
ATTR_ENTITY_ID,
CONF_ALIAS,
@ -900,10 +901,10 @@ def _referenced_extract_ids(data: dict[str, Any], key: str, found: set[str]) ->
return
if isinstance(item_ids, str):
item_ids = [item_ids]
for item_id in item_ids:
found.add(item_id)
found.add(item_ids)
else:
for item_id in item_ids:
found.add(item_id)
class Script:
@ -970,6 +971,7 @@ class Script:
self._choose_data: dict[int, dict[str, Any]] = {}
self._referenced_entities: set[str] | None = None
self._referenced_devices: set[str] | None = None
self._referenced_areas: set[str] | None = None
self.variables = variables
self._variables_dynamic = template.is_complex(variables)
if self._variables_dynamic:
@ -1031,6 +1033,28 @@ class Script:
"""Return true if the current mode support max."""
return self.script_mode in (SCRIPT_MODE_PARALLEL, SCRIPT_MODE_QUEUED)
@property
def referenced_areas(self):
"""Return a set of referenced areas."""
if self._referenced_areas is not None:
return self._referenced_areas
referenced: set[str] = set()
for step in self.sequence:
action = cv.determine_script_action(step)
if action == cv.SCRIPT_ACTION_CALL_SERVICE:
for data in (
step.get(CONF_TARGET),
step.get(service.CONF_SERVICE_DATA),
step.get(service.CONF_SERVICE_DATA_TEMPLATE),
):
_referenced_extract_ids(data, ATTR_AREA_ID, referenced)
self._referenced_areas = referenced
return referenced
@property
def referenced_devices(self):
"""Return a set of referenced devices."""
@ -1044,7 +1068,6 @@ class Script:
if action == cv.SCRIPT_ACTION_CALL_SERVICE:
for data in (
step,
step.get(CONF_TARGET),
step.get(service.CONF_SERVICE_DATA),
step.get(service.CONF_SERVICE_DATA_TEMPLATE),

View File

@ -193,6 +193,10 @@ async def test_search(hass):
},
)
# Ensure automations set up correctly.
assert hass.states.get("automation.wled_entity") is not None
assert hass.states.get("automation.wled_device") is not None
# Explore the graph from every node and make sure we find the same results
expected = {
"config_entry": {wled_config_entry.entry_id},
@ -276,6 +280,64 @@ async def test_search(hass):
assert searcher.async_search(search_type, search_id) == {}
async def test_area_lookup(hass):
"""Test area based lookup."""
area_reg = ar.async_get(hass)
device_reg = dr.async_get(hass)
entity_reg = er.async_get(hass)
living_room_area = area_reg.async_create("Living Room")
await async_setup_component(
hass,
"script",
{
"script": {
"wled": {
"sequence": [
{
"service": "light.turn_on",
"target": {"area_id": living_room_area.id},
},
]
},
}
},
)
assert await async_setup_component(
hass,
"automation",
{
"automation": [
{
"alias": "area_turn_on",
"trigger": {"platform": "template", "value_template": "true"},
"action": [
{
"service": "light.turn_on",
"data": {
"area_id": living_room_area.id,
},
},
],
},
]
},
)
searcher = search.Searcher(hass, device_reg, entity_reg)
assert searcher.async_search("area", living_room_area.id) == {
"script": {"script.wled"},
"automation": {"automation.area_turn_on"},
}
searcher = search.Searcher(hass, device_reg, entity_reg)
assert searcher.async_search("automation", "automation.area_turn_on") == {
"area": {living_room_area.id},
}
async def test_ws_api(hass, hass_ws_client):
"""Test WS API."""
assert await async_setup_component(hass, "search", {})