diff --git a/homeassistant/helpers/condition.py b/homeassistant/helpers/condition.py index cea79c4fc8f..6f5e7c40d22 100644 --- a/homeassistant/helpers/condition.py +++ b/homeassistant/helpers/condition.py @@ -972,6 +972,8 @@ async def async_validate_condition_config( platform = await async_get_device_automation_platform( hass, config[CONF_DOMAIN], "condition" ) + if hasattr(platform, "async_validate_condition_config"): + return await platform.async_validate_condition_config(hass, config) # type: ignore return cast(ConfigType, platform.CONDITION_SCHEMA(config)) # type: ignore return config diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index ea3635888bb..156ceb8e612 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -256,7 +256,10 @@ async def async_validate_action_config( platform = await device_automation.async_get_device_automation_platform( hass, config[CONF_DOMAIN], "action" ) - config = platform.ACTION_SCHEMA(config) # type: ignore + if hasattr(platform, "async_validate_action_config"): + config = await platform.async_validate_action_config(hass, config) # type: ignore + else: + config = platform.ACTION_SCHEMA(config) # type: ignore elif action_type == cv.SCRIPT_ACTION_CHECK_CONDITION: if config[CONF_CONDITION] == "device": diff --git a/tests/helpers/test_condition.py b/tests/helpers/test_condition.py index 38a9367e36d..b1cbff83e33 100644 --- a/tests/helpers/test_condition.py +++ b/tests/helpers/test_condition.py @@ -1,13 +1,20 @@ """Test the condition helper.""" from datetime import datetime -from unittest.mock import patch +from unittest.mock import AsyncMock, patch import pytest from homeassistant.components import sun import homeassistant.components.automation as automation from homeassistant.components.sensor import DEVICE_CLASS_TIMESTAMP -from homeassistant.const import ATTR_DEVICE_CLASS, SUN_EVENT_SUNRISE, SUN_EVENT_SUNSET +from homeassistant.const import ( + ATTR_DEVICE_CLASS, + CONF_CONDITION, + CONF_DEVICE_ID, + CONF_DOMAIN, + SUN_EVENT_SUNRISE, + SUN_EVENT_SUNSET, +) from homeassistant.exceptions import ConditionError, HomeAssistantError from homeassistant.helpers import condition, trace from homeassistant.helpers.template import Template @@ -2843,3 +2850,16 @@ async def test_trigger(hass): assert not test(hass, {"other_var": "123456"}) assert not test(hass, {"trigger": {"trigger_id": "123456"}}) assert test(hass, {"trigger": {"id": "123456"}}) + + +async def test_platform_async_validate_condition_config(hass): + """Test platform.async_validate_condition_config will be called if it exists.""" + config = {CONF_DEVICE_ID: "test", CONF_DOMAIN: "test", CONF_CONDITION: "device"} + platform = AsyncMock() + with patch( + "homeassistant.helpers.condition.async_get_device_automation_platform", + return_value=platform, + ): + platform.async_validate_condition_config.return_value = config + await condition.async_validate_condition_config(hass, config) + platform.async_validate_condition_config.assert_awaited() diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index 0af8ff7d431..dfa5ce34ce7 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -6,7 +6,7 @@ from datetime import timedelta import logging from types import MappingProxyType from unittest import mock -from unittest.mock import patch +from unittest.mock import AsyncMock, patch from async_timeout import timeout import pytest @@ -15,7 +15,12 @@ import voluptuous as vol # Otherwise can't test just this file (import order issue) from homeassistant import exceptions import homeassistant.components.scene as scene -from homeassistant.const import ATTR_ENTITY_ID, SERVICE_TURN_ON +from homeassistant.const import ( + ATTR_ENTITY_ID, + CONF_DEVICE_ID, + CONF_DOMAIN, + SERVICE_TURN_ON, +) from homeassistant.core import SERVICE_CALL_LIMIT, Context, CoreState, callback from homeassistant.exceptions import ConditionError, ServiceNotFound from homeassistant.helpers import config_validation as cv, script, trace @@ -3130,3 +3135,16 @@ async def test_breakpoints_2(hass): assert not script_obj.is_running assert script_obj.runs == 0 assert len(events) == 1 + + +async def test_platform_async_validate_action_config(hass): + """Test platform.async_validate_action_config will be called if it exists.""" + config = {CONF_DEVICE_ID: "test", CONF_DOMAIN: "test"} + platform = AsyncMock() + with patch( + "homeassistant.helpers.script.device_automation.async_get_device_automation_platform", + return_value=platform, + ): + platform.async_validate_action_config.return_value = config + await script.async_validate_action_config(hass, config) + platform.async_validate_action_config.assert_awaited()