Add a service to reload config entries that can easily be called though automations (#46762)
parent
6fb0e49335
commit
08db262972
|
@ -21,15 +21,30 @@ from homeassistant.const import (
|
|||
import homeassistant.core as ha
|
||||
from homeassistant.exceptions import HomeAssistantError, Unauthorized, UnknownUser
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.service import async_extract_referenced_entity_ids
|
||||
from homeassistant.helpers.service import (
|
||||
async_extract_config_entry_ids,
|
||||
async_extract_referenced_entity_ids,
|
||||
)
|
||||
|
||||
ATTR_ENTRY_ID = "entry_id"
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
DOMAIN = ha.DOMAIN
|
||||
SERVICE_RELOAD_CORE_CONFIG = "reload_core_config"
|
||||
SERVICE_RELOAD_CONFIG_ENTRY = "reload_config_entry"
|
||||
SERVICE_CHECK_CONFIG = "check_config"
|
||||
SERVICE_UPDATE_ENTITY = "update_entity"
|
||||
SERVICE_SET_LOCATION = "set_location"
|
||||
SCHEMA_UPDATE_ENTITY = vol.Schema({ATTR_ENTITY_ID: cv.entity_ids})
|
||||
SCHEMA_RELOAD_CONFIG_ENTRY = vol.All(
|
||||
vol.Schema(
|
||||
{
|
||||
vol.Optional(ATTR_ENTRY_ID): str,
|
||||
**cv.ENTITY_SERVICE_FIELDS,
|
||||
},
|
||||
),
|
||||
cv.has_at_least_one_key(ATTR_ENTRY_ID, *cv.ENTITY_SERVICE_FIELDS),
|
||||
)
|
||||
|
||||
|
||||
async def async_setup(hass: ha.HomeAssistant, config: dict) -> bool:
|
||||
|
@ -203,4 +218,26 @@ async def async_setup(hass: ha.HomeAssistant, config: dict) -> bool:
|
|||
vol.Schema({ATTR_LATITUDE: cv.latitude, ATTR_LONGITUDE: cv.longitude}),
|
||||
)
|
||||
|
||||
async def async_handle_reload_config_entry(call):
|
||||
"""Service handler for reloading a config entry."""
|
||||
reload_entries = set()
|
||||
if ATTR_ENTRY_ID in call.data:
|
||||
reload_entries.add(call.data[ATTR_ENTRY_ID])
|
||||
reload_entries.update(await async_extract_config_entry_ids(hass, call))
|
||||
if not reload_entries:
|
||||
raise ValueError("There were no matching config entries to reload")
|
||||
await asyncio.gather(
|
||||
*[
|
||||
hass.config_entries.async_reload(config_entry_id)
|
||||
for config_entry_id in reload_entries
|
||||
]
|
||||
)
|
||||
|
||||
hass.helpers.service.async_register_admin_service(
|
||||
ha.DOMAIN,
|
||||
SERVICE_RELOAD_CONFIG_ENTRY,
|
||||
async_handle_reload_config_entry,
|
||||
schema=SCHEMA_RELOAD_CONFIG_ENTRY,
|
||||
)
|
||||
|
||||
return True
|
||||
|
|
|
@ -58,3 +58,19 @@ update_entity:
|
|||
description: Force one or more entities to update its data
|
||||
target:
|
||||
entity: {}
|
||||
|
||||
reload_config_entry:
|
||||
name: Reload config entry
|
||||
description: Reload a config entry that matches a target.
|
||||
target:
|
||||
entity: {}
|
||||
device: {}
|
||||
fields:
|
||||
entry_id:
|
||||
advanced: true
|
||||
name: Config entry id
|
||||
description: A configuration entry id
|
||||
required: false
|
||||
example: 8955375327824e14ba89e4b29cc3ec9a
|
||||
selector:
|
||||
text:
|
||||
|
|
|
@ -11,9 +11,9 @@ from typing import (
|
|||
Awaitable,
|
||||
Callable,
|
||||
Iterable,
|
||||
Tuple,
|
||||
Optional,
|
||||
TypedDict,
|
||||
cast,
|
||||
Union,
|
||||
)
|
||||
|
||||
import voluptuous as vol
|
||||
|
@ -78,6 +78,29 @@ class ServiceParams(TypedDict):
|
|||
target: dict | None
|
||||
|
||||
|
||||
class ServiceTargetSelector:
|
||||
"""Class to hold a target selector for a service."""
|
||||
|
||||
def __init__(self, service_call: ha.ServiceCall):
|
||||
"""Extract ids from service call data."""
|
||||
entity_ids: Optional[Union[str, list]] = service_call.data.get(ATTR_ENTITY_ID)
|
||||
device_ids: Optional[Union[str, list]] = service_call.data.get(ATTR_DEVICE_ID)
|
||||
area_ids: Optional[Union[str, list]] = service_call.data.get(ATTR_AREA_ID)
|
||||
|
||||
self.entity_ids = (
|
||||
set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set()
|
||||
)
|
||||
self.device_ids = (
|
||||
set(cv.ensure_list(device_ids)) if _has_match(device_ids) else set()
|
||||
)
|
||||
self.area_ids = set(cv.ensure_list(area_ids)) if _has_match(area_ids) else set()
|
||||
|
||||
@property
|
||||
def has_any_selector(self) -> bool:
|
||||
"""Determine if any selectors are present."""
|
||||
return bool(self.entity_ids or self.device_ids or self.area_ids)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SelectedEntities:
|
||||
"""Class to hold the selected entities."""
|
||||
|
@ -93,6 +116,9 @@ class SelectedEntities:
|
|||
missing_devices: set[str] = dataclasses.field(default_factory=set)
|
||||
missing_areas: set[str] = dataclasses.field(default_factory=set)
|
||||
|
||||
# Referenced devices
|
||||
referenced_devices: set[str] = dataclasses.field(default_factory=set)
|
||||
|
||||
def log_missing(self, missing_entities: set[str]) -> None:
|
||||
"""Log about missing items."""
|
||||
parts = []
|
||||
|
@ -293,98 +319,88 @@ async def async_extract_entity_ids(
|
|||
return referenced.referenced | referenced.indirectly_referenced
|
||||
|
||||
|
||||
def _has_match(ids: Optional[Union[str, list]]) -> bool:
|
||||
"""Check if ids can match anything."""
|
||||
return ids not in (None, ENTITY_MATCH_NONE)
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_extract_referenced_entity_ids(
|
||||
hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True
|
||||
) -> SelectedEntities:
|
||||
"""Extract referenced entity IDs from a service call."""
|
||||
entity_ids = service_call.data.get(ATTR_ENTITY_ID)
|
||||
device_ids = service_call.data.get(ATTR_DEVICE_ID)
|
||||
area_ids = service_call.data.get(ATTR_AREA_ID)
|
||||
|
||||
selects_entity_ids = entity_ids not in (None, ENTITY_MATCH_NONE)
|
||||
selects_device_ids = device_ids not in (None, ENTITY_MATCH_NONE)
|
||||
selects_area_ids = area_ids not in (None, ENTITY_MATCH_NONE)
|
||||
|
||||
selector = ServiceTargetSelector(service_call)
|
||||
selected = SelectedEntities()
|
||||
|
||||
if not selects_entity_ids and not selects_device_ids and not selects_area_ids:
|
||||
if not selector.has_any_selector:
|
||||
return selected
|
||||
|
||||
if selects_entity_ids:
|
||||
assert entity_ids is not None
|
||||
entity_ids = selector.entity_ids
|
||||
if expand_group:
|
||||
entity_ids = hass.components.group.expand_entity_ids(entity_ids)
|
||||
|
||||
# Entity ID attr can be a list or a string
|
||||
if isinstance(entity_ids, str):
|
||||
entity_ids = [entity_ids]
|
||||
selected.referenced.update(entity_ids)
|
||||
|
||||
if expand_group:
|
||||
entity_ids = hass.components.group.expand_entity_ids(entity_ids)
|
||||
|
||||
selected.referenced.update(entity_ids)
|
||||
|
||||
if not selects_device_ids and not selects_area_ids:
|
||||
if not selector.device_ids and not selector.area_ids:
|
||||
return selected
|
||||
|
||||
area_reg, dev_reg, ent_reg = cast(
|
||||
Tuple[
|
||||
area_registry.AreaRegistry,
|
||||
device_registry.DeviceRegistry,
|
||||
entity_registry.EntityRegistry,
|
||||
],
|
||||
await asyncio.gather(
|
||||
area_registry.async_get_registry(hass),
|
||||
device_registry.async_get_registry(hass),
|
||||
entity_registry.async_get_registry(hass),
|
||||
),
|
||||
)
|
||||
ent_reg = entity_registry.async_get(hass)
|
||||
dev_reg = device_registry.async_get(hass)
|
||||
area_reg = area_registry.async_get(hass)
|
||||
|
||||
picked_devices = set()
|
||||
for device_id in selector.device_ids:
|
||||
if device_id not in dev_reg.devices:
|
||||
selected.missing_devices.add(device_id)
|
||||
|
||||
if selects_device_ids:
|
||||
if isinstance(device_ids, str):
|
||||
picked_devices = {device_ids}
|
||||
else:
|
||||
assert isinstance(device_ids, list)
|
||||
picked_devices = set(device_ids)
|
||||
for area_id in selector.area_ids:
|
||||
if area_id not in area_reg.areas:
|
||||
selected.missing_areas.add(area_id)
|
||||
|
||||
for device_id in picked_devices:
|
||||
if device_id not in dev_reg.devices:
|
||||
selected.missing_devices.add(device_id)
|
||||
# Find devices for this area
|
||||
selected.referenced_devices.update(selector.device_ids)
|
||||
for device_entry in dev_reg.devices.values():
|
||||
if device_entry.area_id in selector.area_ids:
|
||||
selected.referenced_devices.add(device_entry.id)
|
||||
|
||||
if selects_area_ids:
|
||||
assert area_ids is not None
|
||||
|
||||
if isinstance(area_ids, str):
|
||||
area_lookup = {area_ids}
|
||||
else:
|
||||
area_lookup = set(area_ids)
|
||||
|
||||
for area_id in area_lookup:
|
||||
if area_id not in area_reg.areas:
|
||||
selected.missing_areas.add(area_id)
|
||||
continue
|
||||
|
||||
# Find entities tied to an area
|
||||
for entity_entry in ent_reg.entities.values():
|
||||
if entity_entry.area_id in area_lookup:
|
||||
selected.indirectly_referenced.add(entity_entry.entity_id)
|
||||
|
||||
# Find devices for this area
|
||||
for device_entry in dev_reg.devices.values():
|
||||
if device_entry.area_id in area_lookup:
|
||||
picked_devices.add(device_entry.id)
|
||||
|
||||
if not picked_devices:
|
||||
if not selector.area_ids and not selected.referenced_devices:
|
||||
return selected
|
||||
|
||||
for entity_entry in ent_reg.entities.values():
|
||||
if not entity_entry.area_id and entity_entry.device_id in picked_devices:
|
||||
selected.indirectly_referenced.add(entity_entry.entity_id)
|
||||
for ent_entry in ent_reg.entities.values():
|
||||
if ent_entry.area_id in selector.area_ids or (
|
||||
not ent_entry.area_id and ent_entry.device_id in selected.referenced_devices
|
||||
):
|
||||
selected.indirectly_referenced.add(ent_entry.entity_id)
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_extract_config_entry_ids(
|
||||
hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True
|
||||
) -> set:
|
||||
"""Extract referenced config entry ids from a service call."""
|
||||
referenced = await async_extract_referenced_entity_ids(
|
||||
hass, service_call, expand_group
|
||||
)
|
||||
ent_reg = entity_registry.async_get(hass)
|
||||
dev_reg = device_registry.async_get(hass)
|
||||
config_entry_ids: set[str] = set()
|
||||
|
||||
# Some devices may have no entities
|
||||
for device_id in referenced.referenced_devices:
|
||||
if device_id in dev_reg.devices:
|
||||
device = dev_reg.async_get(device_id)
|
||||
if device is not None:
|
||||
config_entry_ids.update(device.config_entries)
|
||||
|
||||
for entity_id in referenced.referenced | referenced.indirectly_referenced:
|
||||
entry = ent_reg.async_get(entity_id)
|
||||
if entry is not None and entry.config_entry_id is not None:
|
||||
config_entry_ids.add(entry.config_entry_id)
|
||||
|
||||
return config_entry_ids
|
||||
|
||||
|
||||
def _load_services_file(hass: HomeAssistantType, integration: Integration) -> JSON_TYPE:
|
||||
"""Load services file for an integration."""
|
||||
try:
|
||||
|
|
|
@ -11,6 +11,7 @@ import yaml
|
|||
from homeassistant import config
|
||||
import homeassistant.components as comps
|
||||
from homeassistant.components.homeassistant import (
|
||||
ATTR_ENTRY_ID,
|
||||
SERVICE_CHECK_CONFIG,
|
||||
SERVICE_RELOAD_CORE_CONFIG,
|
||||
SERVICE_SET_LOCATION,
|
||||
|
@ -34,9 +35,11 @@ from homeassistant.helpers import entity
|
|||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import (
|
||||
MockConfigEntry,
|
||||
async_capture_events,
|
||||
async_mock_service,
|
||||
get_test_home_assistant,
|
||||
mock_registry,
|
||||
mock_service,
|
||||
patch_yaml_files,
|
||||
)
|
||||
|
@ -385,3 +388,62 @@ async def test_not_allowing_recursion(hass, caplog):
|
|||
f"Called service homeassistant.{service} with invalid entities homeassistant.light"
|
||||
in caplog.text
|
||||
), service
|
||||
|
||||
|
||||
async def test_reload_config_entry_by_entity_id(hass):
|
||||
"""Test being able to reload a config entry by entity_id."""
|
||||
await async_setup_component(hass, "homeassistant", {})
|
||||
entity_reg = mock_registry(hass)
|
||||
entry1 = MockConfigEntry(domain="mockdomain")
|
||||
entry1.add_to_hass(hass)
|
||||
entry2 = MockConfigEntry(domain="mockdomain")
|
||||
entry2.add_to_hass(hass)
|
||||
reg_entity1 = entity_reg.async_get_or_create(
|
||||
"binary_sensor", "powerwall", "battery_charging", config_entry=entry1
|
||||
)
|
||||
reg_entity2 = entity_reg.async_get_or_create(
|
||||
"binary_sensor", "powerwall", "battery_status", config_entry=entry2
|
||||
)
|
||||
with patch(
|
||||
"homeassistant.config_entries.ConfigEntries.async_reload",
|
||||
return_value=None,
|
||||
) as mock_reload:
|
||||
await hass.services.async_call(
|
||||
"homeassistant",
|
||||
"reload_config_entry",
|
||||
{"entity_id": f"{reg_entity1.entity_id},{reg_entity2.entity_id}"},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
assert len(mock_reload.mock_calls) == 2
|
||||
assert {mock_reload.mock_calls[0][1][0], mock_reload.mock_calls[1][1][0]} == {
|
||||
entry1.entry_id,
|
||||
entry2.entry_id,
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await hass.services.async_call(
|
||||
"homeassistant",
|
||||
"reload_config_entry",
|
||||
{"entity_id": "unknown.entity_id"},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
|
||||
async def test_reload_config_entry_by_entry_id(hass):
|
||||
"""Test being able to reload a config entry by config entry id."""
|
||||
await async_setup_component(hass, "homeassistant", {})
|
||||
|
||||
with patch(
|
||||
"homeassistant.config_entries.ConfigEntries.async_reload",
|
||||
return_value=None,
|
||||
) as mock_reload:
|
||||
await hass.services.async_call(
|
||||
"homeassistant",
|
||||
"reload_config_entry",
|
||||
{ATTR_ENTRY_ID: "8955375327824e14ba89e4b29cc3ec9a"},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
assert len(mock_reload.mock_calls) == 1
|
||||
assert mock_reload.mock_calls[0][1][0] == "8955375327824e14ba89e4b29cc3ec9a"
|
||||
|
|
|
@ -1015,3 +1015,28 @@ async def test_async_extract_entities_warn_referenced(hass, caplog):
|
|||
"Unable to find referenced areas non-existent-area, devices non-existent-device, entities non.existent"
|
||||
in caplog.text
|
||||
)
|
||||
|
||||
|
||||
async def test_async_extract_config_entry_ids(hass):
|
||||
"""Test we can find devices that have no entities."""
|
||||
|
||||
device_no_entities = dev_reg.DeviceEntry(
|
||||
id="device-no-entities", config_entries={"abc"}
|
||||
)
|
||||
|
||||
call = ha.ServiceCall(
|
||||
"homeassistant",
|
||||
"reload_config_entry",
|
||||
{
|
||||
"device_id": "device-no-entities",
|
||||
},
|
||||
)
|
||||
|
||||
mock_device_registry(
|
||||
hass,
|
||||
{
|
||||
device_no_entities.id: device_no_entities,
|
||||
},
|
||||
)
|
||||
|
||||
assert await service.async_extract_config_entry_ids(hass, call) == {"abc"}
|
||||
|
|
Loading…
Reference in New Issue