Add a service to reload config entries that can easily be called though automations (#46762)

pull/47767/head
J. Nick Koston 2021-03-17 18:27:21 -10:00 committed by GitHub
parent 6fb0e49335
commit 08db262972
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 228 additions and 72 deletions

View File

@ -21,15 +21,30 @@ from homeassistant.const import (
import homeassistant.core as ha
from homeassistant.exceptions import HomeAssistantError, Unauthorized, UnknownUser
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.service import async_extract_referenced_entity_ids
from homeassistant.helpers.service import (
async_extract_config_entry_ids,
async_extract_referenced_entity_ids,
)
ATTR_ENTRY_ID = "entry_id"
_LOGGER = logging.getLogger(__name__)
DOMAIN = ha.DOMAIN
SERVICE_RELOAD_CORE_CONFIG = "reload_core_config"
SERVICE_RELOAD_CONFIG_ENTRY = "reload_config_entry"
SERVICE_CHECK_CONFIG = "check_config"
SERVICE_UPDATE_ENTITY = "update_entity"
SERVICE_SET_LOCATION = "set_location"
SCHEMA_UPDATE_ENTITY = vol.Schema({ATTR_ENTITY_ID: cv.entity_ids})
SCHEMA_RELOAD_CONFIG_ENTRY = vol.All(
vol.Schema(
{
vol.Optional(ATTR_ENTRY_ID): str,
**cv.ENTITY_SERVICE_FIELDS,
},
),
cv.has_at_least_one_key(ATTR_ENTRY_ID, *cv.ENTITY_SERVICE_FIELDS),
)
async def async_setup(hass: ha.HomeAssistant, config: dict) -> bool:
@ -203,4 +218,26 @@ async def async_setup(hass: ha.HomeAssistant, config: dict) -> bool:
vol.Schema({ATTR_LATITUDE: cv.latitude, ATTR_LONGITUDE: cv.longitude}),
)
async def async_handle_reload_config_entry(call):
"""Service handler for reloading a config entry."""
reload_entries = set()
if ATTR_ENTRY_ID in call.data:
reload_entries.add(call.data[ATTR_ENTRY_ID])
reload_entries.update(await async_extract_config_entry_ids(hass, call))
if not reload_entries:
raise ValueError("There were no matching config entries to reload")
await asyncio.gather(
*[
hass.config_entries.async_reload(config_entry_id)
for config_entry_id in reload_entries
]
)
hass.helpers.service.async_register_admin_service(
ha.DOMAIN,
SERVICE_RELOAD_CONFIG_ENTRY,
async_handle_reload_config_entry,
schema=SCHEMA_RELOAD_CONFIG_ENTRY,
)
return True

View File

@ -58,3 +58,19 @@ update_entity:
description: Force one or more entities to update its data
target:
entity: {}
reload_config_entry:
name: Reload config entry
description: Reload a config entry that matches a target.
target:
entity: {}
device: {}
fields:
entry_id:
advanced: true
name: Config entry id
description: A configuration entry id
required: false
example: 8955375327824e14ba89e4b29cc3ec9a
selector:
text:

View File

@ -11,9 +11,9 @@ from typing import (
Awaitable,
Callable,
Iterable,
Tuple,
Optional,
TypedDict,
cast,
Union,
)
import voluptuous as vol
@ -78,6 +78,29 @@ class ServiceParams(TypedDict):
target: dict | None
class ServiceTargetSelector:
"""Class to hold a target selector for a service."""
def __init__(self, service_call: ha.ServiceCall):
"""Extract ids from service call data."""
entity_ids: Optional[Union[str, list]] = service_call.data.get(ATTR_ENTITY_ID)
device_ids: Optional[Union[str, list]] = service_call.data.get(ATTR_DEVICE_ID)
area_ids: Optional[Union[str, list]] = service_call.data.get(ATTR_AREA_ID)
self.entity_ids = (
set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set()
)
self.device_ids = (
set(cv.ensure_list(device_ids)) if _has_match(device_ids) else set()
)
self.area_ids = set(cv.ensure_list(area_ids)) if _has_match(area_ids) else set()
@property
def has_any_selector(self) -> bool:
"""Determine if any selectors are present."""
return bool(self.entity_ids or self.device_ids or self.area_ids)
@dataclasses.dataclass
class SelectedEntities:
"""Class to hold the selected entities."""
@ -93,6 +116,9 @@ class SelectedEntities:
missing_devices: set[str] = dataclasses.field(default_factory=set)
missing_areas: set[str] = dataclasses.field(default_factory=set)
# Referenced devices
referenced_devices: set[str] = dataclasses.field(default_factory=set)
def log_missing(self, missing_entities: set[str]) -> None:
"""Log about missing items."""
parts = []
@ -293,98 +319,88 @@ async def async_extract_entity_ids(
return referenced.referenced | referenced.indirectly_referenced
def _has_match(ids: Optional[Union[str, list]]) -> bool:
"""Check if ids can match anything."""
return ids not in (None, ENTITY_MATCH_NONE)
@bind_hass
async def async_extract_referenced_entity_ids(
hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True
) -> SelectedEntities:
"""Extract referenced entity IDs from a service call."""
entity_ids = service_call.data.get(ATTR_ENTITY_ID)
device_ids = service_call.data.get(ATTR_DEVICE_ID)
area_ids = service_call.data.get(ATTR_AREA_ID)
selects_entity_ids = entity_ids not in (None, ENTITY_MATCH_NONE)
selects_device_ids = device_ids not in (None, ENTITY_MATCH_NONE)
selects_area_ids = area_ids not in (None, ENTITY_MATCH_NONE)
selector = ServiceTargetSelector(service_call)
selected = SelectedEntities()
if not selects_entity_ids and not selects_device_ids and not selects_area_ids:
if not selector.has_any_selector:
return selected
if selects_entity_ids:
assert entity_ids is not None
entity_ids = selector.entity_ids
if expand_group:
entity_ids = hass.components.group.expand_entity_ids(entity_ids)
# Entity ID attr can be a list or a string
if isinstance(entity_ids, str):
entity_ids = [entity_ids]
selected.referenced.update(entity_ids)
if expand_group:
entity_ids = hass.components.group.expand_entity_ids(entity_ids)
selected.referenced.update(entity_ids)
if not selects_device_ids and not selects_area_ids:
if not selector.device_ids and not selector.area_ids:
return selected
area_reg, dev_reg, ent_reg = cast(
Tuple[
area_registry.AreaRegistry,
device_registry.DeviceRegistry,
entity_registry.EntityRegistry,
],
await asyncio.gather(
area_registry.async_get_registry(hass),
device_registry.async_get_registry(hass),
entity_registry.async_get_registry(hass),
),
)
ent_reg = entity_registry.async_get(hass)
dev_reg = device_registry.async_get(hass)
area_reg = area_registry.async_get(hass)
picked_devices = set()
for device_id in selector.device_ids:
if device_id not in dev_reg.devices:
selected.missing_devices.add(device_id)
if selects_device_ids:
if isinstance(device_ids, str):
picked_devices = {device_ids}
else:
assert isinstance(device_ids, list)
picked_devices = set(device_ids)
for area_id in selector.area_ids:
if area_id not in area_reg.areas:
selected.missing_areas.add(area_id)
for device_id in picked_devices:
if device_id not in dev_reg.devices:
selected.missing_devices.add(device_id)
# Find devices for this area
selected.referenced_devices.update(selector.device_ids)
for device_entry in dev_reg.devices.values():
if device_entry.area_id in selector.area_ids:
selected.referenced_devices.add(device_entry.id)
if selects_area_ids:
assert area_ids is not None
if isinstance(area_ids, str):
area_lookup = {area_ids}
else:
area_lookup = set(area_ids)
for area_id in area_lookup:
if area_id not in area_reg.areas:
selected.missing_areas.add(area_id)
continue
# Find entities tied to an area
for entity_entry in ent_reg.entities.values():
if entity_entry.area_id in area_lookup:
selected.indirectly_referenced.add(entity_entry.entity_id)
# Find devices for this area
for device_entry in dev_reg.devices.values():
if device_entry.area_id in area_lookup:
picked_devices.add(device_entry.id)
if not picked_devices:
if not selector.area_ids and not selected.referenced_devices:
return selected
for entity_entry in ent_reg.entities.values():
if not entity_entry.area_id and entity_entry.device_id in picked_devices:
selected.indirectly_referenced.add(entity_entry.entity_id)
for ent_entry in ent_reg.entities.values():
if ent_entry.area_id in selector.area_ids or (
not ent_entry.area_id and ent_entry.device_id in selected.referenced_devices
):
selected.indirectly_referenced.add(ent_entry.entity_id)
return selected
@bind_hass
async def async_extract_config_entry_ids(
hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True
) -> set:
"""Extract referenced config entry ids from a service call."""
referenced = await async_extract_referenced_entity_ids(
hass, service_call, expand_group
)
ent_reg = entity_registry.async_get(hass)
dev_reg = device_registry.async_get(hass)
config_entry_ids: set[str] = set()
# Some devices may have no entities
for device_id in referenced.referenced_devices:
if device_id in dev_reg.devices:
device = dev_reg.async_get(device_id)
if device is not None:
config_entry_ids.update(device.config_entries)
for entity_id in referenced.referenced | referenced.indirectly_referenced:
entry = ent_reg.async_get(entity_id)
if entry is not None and entry.config_entry_id is not None:
config_entry_ids.add(entry.config_entry_id)
return config_entry_ids
def _load_services_file(hass: HomeAssistantType, integration: Integration) -> JSON_TYPE:
"""Load services file for an integration."""
try:

View File

@ -11,6 +11,7 @@ import yaml
from homeassistant import config
import homeassistant.components as comps
from homeassistant.components.homeassistant import (
ATTR_ENTRY_ID,
SERVICE_CHECK_CONFIG,
SERVICE_RELOAD_CORE_CONFIG,
SERVICE_SET_LOCATION,
@ -34,9 +35,11 @@ from homeassistant.helpers import entity
from homeassistant.setup import async_setup_component
from tests.common import (
MockConfigEntry,
async_capture_events,
async_mock_service,
get_test_home_assistant,
mock_registry,
mock_service,
patch_yaml_files,
)
@ -385,3 +388,62 @@ async def test_not_allowing_recursion(hass, caplog):
f"Called service homeassistant.{service} with invalid entities homeassistant.light"
in caplog.text
), service
async def test_reload_config_entry_by_entity_id(hass):
"""Test being able to reload a config entry by entity_id."""
await async_setup_component(hass, "homeassistant", {})
entity_reg = mock_registry(hass)
entry1 = MockConfigEntry(domain="mockdomain")
entry1.add_to_hass(hass)
entry2 = MockConfigEntry(domain="mockdomain")
entry2.add_to_hass(hass)
reg_entity1 = entity_reg.async_get_or_create(
"binary_sensor", "powerwall", "battery_charging", config_entry=entry1
)
reg_entity2 = entity_reg.async_get_or_create(
"binary_sensor", "powerwall", "battery_status", config_entry=entry2
)
with patch(
"homeassistant.config_entries.ConfigEntries.async_reload",
return_value=None,
) as mock_reload:
await hass.services.async_call(
"homeassistant",
"reload_config_entry",
{"entity_id": f"{reg_entity1.entity_id},{reg_entity2.entity_id}"},
blocking=True,
)
assert len(mock_reload.mock_calls) == 2
assert {mock_reload.mock_calls[0][1][0], mock_reload.mock_calls[1][1][0]} == {
entry1.entry_id,
entry2.entry_id,
}
with pytest.raises(ValueError):
await hass.services.async_call(
"homeassistant",
"reload_config_entry",
{"entity_id": "unknown.entity_id"},
blocking=True,
)
async def test_reload_config_entry_by_entry_id(hass):
"""Test being able to reload a config entry by config entry id."""
await async_setup_component(hass, "homeassistant", {})
with patch(
"homeassistant.config_entries.ConfigEntries.async_reload",
return_value=None,
) as mock_reload:
await hass.services.async_call(
"homeassistant",
"reload_config_entry",
{ATTR_ENTRY_ID: "8955375327824e14ba89e4b29cc3ec9a"},
blocking=True,
)
assert len(mock_reload.mock_calls) == 1
assert mock_reload.mock_calls[0][1][0] == "8955375327824e14ba89e4b29cc3ec9a"

View File

@ -1015,3 +1015,28 @@ async def test_async_extract_entities_warn_referenced(hass, caplog):
"Unable to find referenced areas non-existent-area, devices non-existent-device, entities non.existent"
in caplog.text
)
async def test_async_extract_config_entry_ids(hass):
"""Test we can find devices that have no entities."""
device_no_entities = dev_reg.DeviceEntry(
id="device-no-entities", config_entries={"abc"}
)
call = ha.ServiceCall(
"homeassistant",
"reload_config_entry",
{
"device_id": "device-no-entities",
},
)
mock_device_registry(
hass,
{
device_no_entities.id: device_no_entities,
},
)
assert await service.async_extract_config_entry_ids(hass, call) == {"abc"}