diff --git a/homeassistant/components/device_automation/__init__.py b/homeassistant/components/device_automation/__init__.py index 0be1c3eb1dd..80e64033295 100644 --- a/homeassistant/components/device_automation/__init__.py +++ b/homeassistant/components/device_automation/__init__.py @@ -156,7 +156,11 @@ async def _async_get_device_automation_capabilities(hass, automation_type, autom # The device automation has no capabilities return {} - capabilities = await getattr(platform, function_name)(hass, automation) + try: + capabilities = await getattr(platform, function_name)(hass, automation) + except InvalidDeviceAutomationConfig: + return {} + capabilities = capabilities.copy() extra_fields = capabilities.get("extra_fields") diff --git a/homeassistant/components/sensor/device_condition.py b/homeassistant/components/sensor/device_condition.py index 26479807991..259fb5dbab9 100644 --- a/homeassistant/components/sensor/device_condition.py +++ b/homeassistant/components/sensor/device_condition.py @@ -2,6 +2,9 @@ from typing import Dict, List import voluptuous as vol +from homeassistant.components.device_automation.exceptions import ( + InvalidDeviceAutomationConfig, +) from homeassistant.core import HomeAssistant from homeassistant.const import ( ATTR_DEVICE_CLASS, @@ -141,3 +144,27 @@ def async_condition_from_config( numeric_state_config[condition.CONF_BELOW] = config[CONF_BELOW] return condition.async_numeric_state_from_config(numeric_state_config) + + +async def async_get_condition_capabilities(hass, config): + """List condition capabilities.""" + state = hass.states.get(config[CONF_ENTITY_ID]) + unit_of_measurement = ( + state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) if state else None + ) + + if not state or not unit_of_measurement: + raise InvalidDeviceAutomationConfig + + return { + "extra_fields": vol.Schema( + { + vol.Optional( + CONF_ABOVE, description={"suffix": unit_of_measurement} + ): vol.Coerce(float), + vol.Optional( + CONF_BELOW, description={"suffix": unit_of_measurement} + ): vol.Coerce(float), + } + ) + } diff --git a/homeassistant/components/sensor/device_trigger.py b/homeassistant/components/sensor/device_trigger.py index b462124165a..73e55340da9 100644 --- a/homeassistant/components/sensor/device_trigger.py +++ b/homeassistant/components/sensor/device_trigger.py @@ -3,6 +3,9 @@ import voluptuous as vol import homeassistant.components.automation.numeric_state as numeric_state_automation from homeassistant.components.device_automation import TRIGGER_BASE_SCHEMA +from homeassistant.components.device_automation.exceptions import ( + InvalidDeviceAutomationConfig, +) from homeassistant.const import ( ATTR_DEVICE_CLASS, ATTR_UNIT_OF_MEASUREMENT, @@ -146,9 +149,12 @@ async def async_get_trigger_capabilities(hass, config): """List trigger capabilities.""" state = hass.states.get(config[CONF_ENTITY_ID]) unit_of_measurement = ( - state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) if state else "" + state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) if state else None ) + if not state or not unit_of_measurement: + raise InvalidDeviceAutomationConfig + return { "extra_fields": vol.Schema( { diff --git a/tests/components/sensor/test_device_condition.py b/tests/components/sensor/test_device_condition.py index e28e487f4ef..f3ff15c3ad9 100644 --- a/tests/components/sensor/test_device_condition.py +++ b/tests/components/sensor/test_device_condition.py @@ -14,6 +14,7 @@ from tests.common import ( mock_device_registry, mock_registry, async_get_device_automations, + async_get_device_automation_capabilities, ) from tests.testing_config.custom_components.test.sensor import DEVICE_CLASSES @@ -73,6 +74,86 @@ async def test_get_conditions(hass, device_reg, entity_reg): assert conditions == expected_conditions +async def test_get_condition_capabilities(hass, device_reg, entity_reg): + """Test we get the expected capabilities from a sensor condition.""" + platform = getattr(hass.components, f"test.{DOMAIN}") + platform.init() + + config_entry = MockConfigEntry(domain="test", data={}) + config_entry.add_to_hass(hass) + device_entry = device_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + ) + entity_reg.async_get_or_create( + DOMAIN, + "test", + platform.ENTITIES["battery"].unique_id, + device_id=device_entry.id, + ) + + assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}}) + + expected_capabilities = { + "extra_fields": [ + { + "description": {"suffix": "%"}, + "name": "above", + "optional": True, + "type": "float", + }, + { + "description": {"suffix": "%"}, + "name": "below", + "optional": True, + "type": "float", + }, + ] + } + conditions = await async_get_device_automations(hass, "condition", device_entry.id) + assert len(conditions) == 1 + for condition in conditions: + capabilities = await async_get_device_automation_capabilities( + hass, "condition", condition + ) + assert capabilities == expected_capabilities + + +async def test_get_condition_capabilities_none(hass, device_reg, entity_reg): + """Test we get the expected capabilities from a sensor condition.""" + platform = getattr(hass.components, f"test.{DOMAIN}") + platform.init() + + config_entry = MockConfigEntry(domain="test", data={}) + config_entry.add_to_hass(hass) + + assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}}) + + conditions = [ + { + "condition": "device", + "device_id": "8770c43885354d5fa27604db6817f63f", + "domain": "sensor", + "entity_id": "sensor.beer", + "type": "is_battery_level", + }, + { + "condition": "device", + "device_id": "8770c43885354d5fa27604db6817f63f", + "domain": "sensor", + "entity_id": platform.ENTITIES["none"].entity_id, + "type": "is_battery_level", + }, + ] + + expected_capabilities = {} + for condition in conditions: + capabilities = await async_get_device_automation_capabilities( + hass, "condition", condition + ) + assert capabilities == expected_capabilities + + async def test_if_state_not_above_below(hass, calls, caplog): """Test for bad value conditions.""" platform = getattr(hass.components, f"test.{DOMAIN}") diff --git a/tests/components/sensor/test_device_trigger.py b/tests/components/sensor/test_device_trigger.py index a21839fcebc..b7a921fff18 100644 --- a/tests/components/sensor/test_device_trigger.py +++ b/tests/components/sensor/test_device_trigger.py @@ -124,6 +124,41 @@ async def test_get_trigger_capabilities(hass, device_reg, entity_reg): assert capabilities == expected_capabilities +async def test_get_trigger_capabilities_none(hass, device_reg, entity_reg): + """Test we get the expected capabilities from a sensor trigger.""" + platform = getattr(hass.components, f"test.{DOMAIN}") + platform.init() + + config_entry = MockConfigEntry(domain="test", data={}) + config_entry.add_to_hass(hass) + + assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}}) + + triggers = [ + { + "platform": "device", + "device_id": "8770c43885354d5fa27604db6817f63f", + "domain": "sensor", + "entity_id": "sensor.beer", + "type": "is_battery_level", + }, + { + "platform": "device", + "device_id": "8770c43885354d5fa27604db6817f63f", + "domain": "sensor", + "entity_id": platform.ENTITIES["none"].entity_id, + "type": "is_battery_level", + }, + ] + + expected_capabilities = {} + for trigger in triggers: + capabilities = await async_get_device_automation_capabilities( + hass, "trigger", trigger + ) + assert capabilities == expected_capabilities + + async def test_if_fires_not_on_above_below(hass, calls, caplog): """Test for value triggers firing.""" platform = getattr(hass.components, f"test.{DOMAIN}")