From f9ade788eb170f7912151b30d73ca5975b261ca1 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Fri, 16 Aug 2024 10:01:12 +0200 Subject: [PATCH] Do sanity check EntityPlatform.async_register_entity_service schema (#123058) * Do a sanity check of schema passed to EntityPlatform.async_register_entity_service * Only attempt to check schema of Schema * Handle All/Any wrapped in schema * Clarify comment * Apply suggestions from code review Co-authored-by: Robert Resch --------- Co-authored-by: Robert Resch --- homeassistant/helpers/entity_platform.py | 16 ++++++++++++++++ tests/helpers/test_entity_platform.py | 23 +++++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index f3d5f5b076a..ec177fbf316 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -10,6 +10,8 @@ from functools import partial from logging import Logger, getLogger from typing import TYPE_CHECKING, Any, Protocol +import voluptuous as vol + from homeassistant import config_entries from homeassistant.const import ( ATTR_RESTORED, @@ -999,6 +1001,20 @@ class EntityPlatform: if schema is None or isinstance(schema, dict): schema = cv.make_entity_service_schema(schema) + # Do a sanity check to check this is a valid entity service schema, + # the check could be extended to require All/Any to have sub schema(s) + # with all entity service fields + elif ( + # Don't check All/Any + not isinstance(schema, (vol.All, vol.Any)) + # Don't check All/Any wrapped in schema + and not isinstance(schema.schema, (vol.All, vol.Any)) + and any(key not in schema.schema for key in cv.ENTITY_SERVICE_FIELDS) + ): + raise HomeAssistantError( + "The schema does not include all required keys: " + f"{", ".join(str(key) for key in cv.ENTITY_SERVICE_FIELDS)}" + ) service_func: str | HassJob[..., Any] service_func = func if isinstance(func, str) else HassJob(func) diff --git a/tests/helpers/test_entity_platform.py b/tests/helpers/test_entity_platform.py index be8ba998481..50180ecd844 100644 --- a/tests/helpers/test_entity_platform.py +++ b/tests/helpers/test_entity_platform.py @@ -8,6 +8,7 @@ from typing import Any from unittest.mock import ANY, AsyncMock, Mock, patch import pytest +import voluptuous as vol from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, PERCENTAGE, EntityCategory from homeassistant.core import ( @@ -1788,6 +1789,28 @@ async def test_register_entity_service_none_schema( assert entity2 in entities +async def test_register_entity_service_non_entity_service_schema( + hass: HomeAssistant, +) -> None: + """Test attempting to register a service with an incomplete schema.""" + entity_platform = MockEntityPlatform( + hass, domain="mock_integration", platform_name="mock_platform", platform=None + ) + + with pytest.raises( + HomeAssistantError, + match=( + "The schema does not include all required keys: entity_id, device_id, area_id, " + "floor_id, label_id" + ), + ): + entity_platform.async_register_entity_service( + "hello", + vol.Schema({"some": str}), + Mock(), + ) + + @pytest.mark.parametrize("update_before_add", [True, False]) async def test_invalid_entity_id( hass: HomeAssistant, caplog: pytest.LogCaptureFixture, update_before_add: bool