Prevent deleting blueprints which are in use (#78444)

pull/78703/head
Erik Montnemery 2022-09-14 16:47:08 +02:00 committed by Paulus Schoutsen
parent a4749178f1
commit 40c5689507
10 changed files with 193 additions and 9 deletions

View File

@ -9,6 +9,7 @@ import voluptuous as vol
from voluptuous.humanize import humanize_error from voluptuous.humanize import humanize_error
from homeassistant.components import blueprint from homeassistant.components import blueprint
from homeassistant.components.blueprint import CONF_USE_BLUEPRINT
from homeassistant.const import ( from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_ENTITY_ID,
ATTR_MODE, ATTR_MODE,
@ -20,6 +21,7 @@ from homeassistant.const import (
CONF_EVENT_DATA, CONF_EVENT_DATA,
CONF_ID, CONF_ID,
CONF_MODE, CONF_MODE,
CONF_PATH,
CONF_PLATFORM, CONF_PLATFORM,
CONF_VARIABLES, CONF_VARIABLES,
CONF_ZONE, CONF_ZONE,
@ -224,6 +226,21 @@ def areas_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]:
return list(automation_entity.referenced_areas) return list(automation_entity.referenced_areas)
@callback
def automations_with_blueprint(hass: HomeAssistant, blueprint_path: str) -> list[str]:
"""Return all automations that reference the blueprint."""
if DOMAIN not in hass.data:
return []
component = hass.data[DOMAIN]
return [
automation_entity.entity_id
for automation_entity in component.entities
if automation_entity.referenced_blueprint == blueprint_path
]
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up all automations.""" """Set up all automations."""
hass.data[DOMAIN] = component = EntityComponent(LOGGER, DOMAIN, hass) hass.data[DOMAIN] = component = EntityComponent(LOGGER, DOMAIN, hass)
@ -346,7 +363,14 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
return self.action_script.referenced_areas return self.action_script.referenced_areas
@property @property
def referenced_devices(self): def referenced_blueprint(self) -> str | None:
"""Return referenced blueprint or None."""
if self._blueprint_inputs is None:
return None
return cast(str, self._blueprint_inputs[CONF_USE_BLUEPRINT][CONF_PATH])
@property
def referenced_devices(self) -> set[str]:
"""Return a set of referenced devices.""" """Return a set of referenced devices."""
if self._referenced_devices is not None: if self._referenced_devices is not None:
return self._referenced_devices return self._referenced_devices

View File

@ -8,8 +8,15 @@ from .const import DOMAIN, LOGGER
DATA_BLUEPRINTS = "automation_blueprints" DATA_BLUEPRINTS = "automation_blueprints"
def _blueprint_in_use(hass: HomeAssistant, blueprint_path: str) -> bool:
"""Return True if any automation references the blueprint."""
from . import automations_with_blueprint # pylint: disable=import-outside-toplevel
return len(automations_with_blueprint(hass, blueprint_path)) > 0
@singleton(DATA_BLUEPRINTS) @singleton(DATA_BLUEPRINTS)
@callback @callback
def async_get_blueprints(hass: HomeAssistant) -> blueprint.DomainBlueprints: def async_get_blueprints(hass: HomeAssistant) -> blueprint.DomainBlueprints:
"""Get automation blueprints.""" """Get automation blueprints."""
return blueprint.DomainBlueprints(hass, DOMAIN, LOGGER) return blueprint.DomainBlueprints(hass, DOMAIN, LOGGER, _blueprint_in_use)

View File

@ -3,7 +3,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from . import websocket_api from . import websocket_api
from .const import DOMAIN # noqa: F401 from .const import CONF_USE_BLUEPRINT, DOMAIN # noqa: F401
from .errors import ( # noqa: F401 from .errors import ( # noqa: F401
BlueprintException, BlueprintException,
BlueprintWithNameException, BlueprintWithNameException,

View File

@ -91,3 +91,11 @@ class FileAlreadyExists(BlueprintWithNameException):
def __init__(self, domain: str, blueprint_name: str) -> None: def __init__(self, domain: str, blueprint_name: str) -> None:
"""Initialize blueprint exception.""" """Initialize blueprint exception."""
super().__init__(domain, blueprint_name, "Blueprint already exists") super().__init__(domain, blueprint_name, "Blueprint already exists")
class BlueprintInUse(BlueprintWithNameException):
"""Error when a blueprint is in use."""
def __init__(self, domain: str, blueprint_name: str) -> None:
"""Initialize blueprint exception."""
super().__init__(domain, blueprint_name, "Blueprint in use")

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable
import logging import logging
import pathlib import pathlib
import shutil import shutil
@ -35,6 +36,7 @@ from .const import (
) )
from .errors import ( from .errors import (
BlueprintException, BlueprintException,
BlueprintInUse,
FailedToLoad, FailedToLoad,
FileAlreadyExists, FileAlreadyExists,
InvalidBlueprint, InvalidBlueprint,
@ -183,11 +185,13 @@ class DomainBlueprints:
hass: HomeAssistant, hass: HomeAssistant,
domain: str, domain: str,
logger: logging.Logger, logger: logging.Logger,
blueprint_in_use: Callable[[HomeAssistant, str], bool],
) -> None: ) -> None:
"""Initialize a domain blueprints instance.""" """Initialize a domain blueprints instance."""
self.hass = hass self.hass = hass
self.domain = domain self.domain = domain
self.logger = logger self.logger = logger
self._blueprint_in_use = blueprint_in_use
self._blueprints: dict[str, Blueprint | None] = {} self._blueprints: dict[str, Blueprint | None] = {}
self._load_lock = asyncio.Lock() self._load_lock = asyncio.Lock()
@ -302,6 +306,8 @@ class DomainBlueprints:
async def async_remove_blueprint(self, blueprint_path: str) -> None: async def async_remove_blueprint(self, blueprint_path: str) -> None:
"""Remove a blueprint file.""" """Remove a blueprint file."""
if self._blueprint_in_use(self.hass, blueprint_path):
raise BlueprintInUse(self.domain, blueprint_path)
path = self.blueprint_folder / blueprint_path path = self.blueprint_folder / blueprint_path
await self.hass.async_add_executor_job(path.unlink) await self.hass.async_add_executor_job(path.unlink)
self._blueprints[blueprint_path] = None self._blueprints[blueprint_path] = None

View File

@ -8,7 +8,7 @@ from typing import Any, cast
import voluptuous as vol import voluptuous as vol
from voluptuous.humanize import humanize_error from voluptuous.humanize import humanize_error
from homeassistant.components.blueprint import BlueprintInputs from homeassistant.components.blueprint import CONF_USE_BLUEPRINT, BlueprintInputs
from homeassistant.const import ( from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_ENTITY_ID,
ATTR_MODE, ATTR_MODE,
@ -18,6 +18,7 @@ from homeassistant.const import (
CONF_ICON, CONF_ICON,
CONF_MODE, CONF_MODE,
CONF_NAME, CONF_NAME,
CONF_PATH,
CONF_SEQUENCE, CONF_SEQUENCE,
CONF_VARIABLES, CONF_VARIABLES,
SERVICE_RELOAD, SERVICE_RELOAD,
@ -165,6 +166,21 @@ def areas_in_script(hass: HomeAssistant, entity_id: str) -> list[str]:
return list(script_entity.script.referenced_areas) return list(script_entity.script.referenced_areas)
@callback
def scripts_with_blueprint(hass: HomeAssistant, blueprint_path: str) -> list[str]:
"""Return all scripts that reference the blueprint."""
if DOMAIN not in hass.data:
return []
component = hass.data[DOMAIN]
return [
script_entity.entity_id
for script_entity in component.entities
if script_entity.referenced_blueprint == blueprint_path
]
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Load the scripts from the configuration.""" """Load the scripts from the configuration."""
hass.data[DOMAIN] = component = EntityComponent(LOGGER, DOMAIN, hass) hass.data[DOMAIN] = component = EntityComponent(LOGGER, DOMAIN, hass)
@ -372,6 +388,13 @@ class ScriptEntity(ToggleEntity, RestoreEntity):
"""Return true if script is on.""" """Return true if script is on."""
return self.script.is_running return self.script.is_running
@property
def referenced_blueprint(self):
"""Return referenced blueprint or None."""
if self._blueprint_inputs is None:
return None
return self._blueprint_inputs[CONF_USE_BLUEPRINT][CONF_PATH]
@callback @callback
def async_change_listener(self): def async_change_listener(self):
"""Update state.""" """Update state."""

View File

@ -8,8 +8,15 @@ from .const import DOMAIN, LOGGER
DATA_BLUEPRINTS = "script_blueprints" DATA_BLUEPRINTS = "script_blueprints"
def _blueprint_in_use(hass: HomeAssistant, blueprint_path: str) -> bool:
"""Return True if any script references the blueprint."""
from . import scripts_with_blueprint # pylint: disable=import-outside-toplevel
return len(scripts_with_blueprint(hass, blueprint_path)) > 0
@singleton(DATA_BLUEPRINTS) @singleton(DATA_BLUEPRINTS)
@callback @callback
def async_get_blueprints(hass: HomeAssistant) -> DomainBlueprints: def async_get_blueprints(hass: HomeAssistant) -> DomainBlueprints:
"""Get script blueprints.""" """Get script blueprints."""
return DomainBlueprints(hass, DOMAIN, LOGGER) return DomainBlueprints(hass, DOMAIN, LOGGER, _blueprint_in_use)

View File

@ -47,7 +47,9 @@ def blueprint_2():
@pytest.fixture @pytest.fixture
def domain_bps(hass): def domain_bps(hass):
"""Domain blueprints fixture.""" """Domain blueprints fixture."""
return models.DomainBlueprints(hass, "automation", logging.getLogger(__name__)) return models.DomainBlueprints(
hass, "automation", logging.getLogger(__name__), None
)
def test_blueprint_model_init(): def test_blueprint_model_init():

View File

@ -8,13 +8,26 @@ from homeassistant.setup import async_setup_component
from homeassistant.util.yaml import parse_yaml from homeassistant.util.yaml import parse_yaml
@pytest.fixture
def automation_config():
"""Automation config."""
return {}
@pytest.fixture
def script_config():
"""Script config."""
return {}
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
async def setup_bp(hass): async def setup_bp(hass, automation_config, script_config):
"""Fixture to set up the blueprint component.""" """Fixture to set up the blueprint component."""
assert await async_setup_component(hass, "blueprint", {}) assert await async_setup_component(hass, "blueprint", {})
# Trigger registration of automation blueprints # Trigger registration of automation and script blueprints
await async_setup_component(hass, "automation", {}) await async_setup_component(hass, "automation", automation_config)
await async_setup_component(hass, "script", script_config)
async def test_list_blueprints(hass, hass_ws_client): async def test_list_blueprints(hass, hass_ws_client):
@ -251,3 +264,89 @@ async def test_delete_non_exist_file_blueprint(hass, aioclient_mock, hass_ws_cli
assert msg["id"] == 9 assert msg["id"] == 9
assert not msg["success"] assert not msg["success"]
@pytest.mark.parametrize(
"automation_config",
(
{
"automation": {
"use_blueprint": {
"path": "test_event_service.yaml",
"input": {
"trigger_event": "blueprint_event",
"service_to_call": "test.automation",
"a_number": 5,
},
}
}
},
),
)
async def test_delete_blueprint_in_use_by_automation(
hass, aioclient_mock, hass_ws_client
):
"""Test deleting a blueprint which is in use."""
with patch("pathlib.Path.unlink", return_value=Mock()) as unlink_mock:
client = await hass_ws_client(hass)
await client.send_json(
{
"id": 9,
"type": "blueprint/delete",
"path": "test_event_service.yaml",
"domain": "automation",
}
)
msg = await client.receive_json()
assert not unlink_mock.mock_calls
assert msg["id"] == 9
assert not msg["success"]
assert msg["error"] == {
"code": "unknown_error",
"message": "Blueprint in use",
}
@pytest.mark.parametrize(
"script_config",
(
{
"script": {
"test_script": {
"use_blueprint": {
"path": "test_service.yaml",
"input": {
"service_to_call": "test.automation",
},
}
}
}
},
),
)
async def test_delete_blueprint_in_use_by_script(hass, aioclient_mock, hass_ws_client):
"""Test deleting a blueprint which is in use."""
with patch("pathlib.Path.unlink", return_value=Mock()) as unlink_mock:
client = await hass_ws_client(hass)
await client.send_json(
{
"id": 9,
"type": "blueprint/delete",
"path": "test_service.yaml",
"domain": "script",
}
)
msg = await client.receive_json()
assert not unlink_mock.mock_calls
assert msg["id"] == 9
assert not msg["success"]
assert msg["error"] == {
"code": "unknown_error",
"message": "Blueprint in use",
}

View File

@ -0,0 +1,8 @@
blueprint:
name: "Call service"
domain: script
input:
service_to_call:
sequence:
service: !input service_to_call
entity_id: light.kitchen