Use entity sources to find related entities in Search (#51966)

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
pull/52014/head
Bram Kragten 2021-06-19 13:25:26 +02:00 committed by GitHub
parent 1d941284ff
commit 34a44b9bec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 43 additions and 5 deletions

View File

@ -8,6 +8,7 @@ from homeassistant.components import automation, group, script, websocket_api
from homeassistant.components.homeassistant import scene from homeassistant.components.homeassistant import scene
from homeassistant.core import HomeAssistant, callback, split_entity_id from homeassistant.core import HomeAssistant, callback, split_entity_id
from homeassistant.helpers import device_registry, entity_registry from homeassistant.helpers import device_registry, entity_registry
from homeassistant.helpers.entity import entity_sources as get_entity_sources
DOMAIN = "search" DOMAIN = "search"
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -44,6 +45,7 @@ def websocket_search_related(hass, connection, msg):
hass, hass,
device_registry.async_get(hass), device_registry.async_get(hass),
entity_registry.async_get(hass), entity_registry.async_get(hass),
get_entity_sources(hass),
) )
connection.send_result( connection.send_result(
msg["id"], searcher.async_search(msg["item_type"], msg["item_id"]) msg["id"], searcher.async_search(msg["item_type"], msg["item_id"])
@ -69,11 +71,13 @@ class Searcher:
hass: HomeAssistant, hass: HomeAssistant,
device_reg: device_registry.DeviceRegistry, device_reg: device_registry.DeviceRegistry,
entity_reg: entity_registry.EntityRegistry, entity_reg: entity_registry.EntityRegistry,
entity_sources: "dict[str, dict[str, str]]",
) -> None: ) -> None:
"""Search results.""" """Search results."""
self.hass = hass self.hass = hass
self._device_reg = device_reg self._device_reg = device_reg
self._entity_reg = entity_reg self._entity_reg = entity_reg
self._sources = entity_sources
self.results = defaultdict(set) self.results = defaultdict(set)
self._to_resolve = deque() self._to_resolve = deque()
@ -184,6 +188,10 @@ class Searcher:
if entity_entry.config_entry_id is not None: if entity_entry.config_entry_id is not None:
self._add_or_resolve("config_entry", entity_entry.config_entry_id) self._add_or_resolve("config_entry", entity_entry.config_entry_id)
else:
source = self._sources.get(entity_id)
if source is not None and "config_entry" in source:
self._add_or_resolve("config_entry", source["config_entry"])
domain = split_entity_id(entity_id)[0] domain = split_entity_id(entity_id)[0]

View File

@ -3,6 +3,7 @@ from homeassistant.components import search
from homeassistant.helpers import ( from homeassistant.helpers import (
area_registry as ar, area_registry as ar,
device_registry as dr, device_registry as dr,
entity,
entity_registry as er, entity_registry as er,
) )
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@ -10,6 +11,18 @@ from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
from tests.components.blueprint.conftest import stub_blueprint_populate # noqa: F401 from tests.components.blueprint.conftest import stub_blueprint_populate # noqa: F401
MOCK_ENTITY_SOURCES = {
"light.platform_config_source": {
"source": entity.SOURCE_PLATFORM_CONFIG,
"domain": "wled",
},
"light.config_entry_source": {
"source": entity.SOURCE_CONFIG_ENTRY,
"config_entry": "config_entry_id",
"domain": "wled",
},
}
async def test_search(hass): async def test_search(hass):
"""Test that search works.""" """Test that search works."""
@ -48,6 +61,18 @@ async def test_search(hass):
device_id=wled_device.id, device_id=wled_device.id,
) )
entity_sources = {
"light.wled_platform_config_source": {
"source": entity.SOURCE_PLATFORM_CONFIG,
"domain": "wled",
},
"light.wled_config_entry_source": {
"source": entity.SOURCE_CONFIG_ENTRY,
"config_entry": wled_config_entry.entry_id,
"domain": "wled",
},
}
# Non related info. # Non related info.
kitchen_area = area_reg.async_create("Kitchen") kitchen_area = area_reg.async_create("Kitchen")
@ -221,7 +246,7 @@ async def test_search(hass):
("automation", "automation.wled_entity"), ("automation", "automation.wled_entity"),
("automation", "automation.wled_device"), ("automation", "automation.wled_device"),
): ):
searcher = search.Searcher(hass, device_reg, entity_reg) searcher = search.Searcher(hass, device_reg, entity_reg, entity_sources)
results = searcher.async_search(search_type, search_id) results = searcher.async_search(search_type, search_id)
# Add the item we searched for, it's omitted from results # Add the item we searched for, it's omitted from results
results.setdefault(search_type, set()).add(search_id) results.setdefault(search_type, set()).add(search_id)
@ -254,7 +279,7 @@ async def test_search(hass):
("scene", "scene.scene_wled_hue"), ("scene", "scene.scene_wled_hue"),
("group", "group.wled_hue"), ("group", "group.wled_hue"),
): ):
searcher = search.Searcher(hass, device_reg, entity_reg) searcher = search.Searcher(hass, device_reg, entity_reg, entity_sources)
results = searcher.async_search(search_type, search_id) results = searcher.async_search(search_type, search_id)
# Add the item we searched for, it's omitted from results # Add the item we searched for, it's omitted from results
results.setdefault(search_type, set()).add(search_id) results.setdefault(search_type, set()).add(search_id)
@ -276,9 +301,14 @@ async def test_search(hass):
("script", "script.non_existing"), ("script", "script.non_existing"),
("automation", "automation.non_existing"), ("automation", "automation.non_existing"),
): ):
searcher = search.Searcher(hass, device_reg, entity_reg) searcher = search.Searcher(hass, device_reg, entity_reg, entity_sources)
assert searcher.async_search(search_type, search_id) == {} assert searcher.async_search(search_type, search_id) == {}
searcher = search.Searcher(hass, device_reg, entity_reg, entity_sources)
assert searcher.async_search("entity", "light.wled_config_entry_source") == {
"config_entry": {wled_config_entry.entry_id},
}
async def test_area_lookup(hass): async def test_area_lookup(hass):
"""Test area based lookup.""" """Test area based lookup."""
@ -326,13 +356,13 @@ async def test_area_lookup(hass):
}, },
) )
searcher = search.Searcher(hass, device_reg, entity_reg) searcher = search.Searcher(hass, device_reg, entity_reg, MOCK_ENTITY_SOURCES)
assert searcher.async_search("area", living_room_area.id) == { assert searcher.async_search("area", living_room_area.id) == {
"script": {"script.wled"}, "script": {"script.wled"},
"automation": {"automation.area_turn_on"}, "automation": {"automation.area_turn_on"},
} }
searcher = search.Searcher(hass, device_reg, entity_reg) searcher = search.Searcher(hass, device_reg, entity_reg, MOCK_ENTITY_SOURCES)
assert searcher.async_search("automation", "automation.area_turn_on") == { assert searcher.async_search("automation", "automation.area_turn_on") == {
"area": {living_room_area.id}, "area": {living_room_area.id},
} }