From 7e4be921a81dd0f0415c47f4ceac12189a89979f Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Sat, 10 Apr 2021 08:19:16 +0200 Subject: [PATCH] Add helper to get an entity's supported features (#48825) * Add helper to check entity's supported features * Move get_supported_features to helpers/entity.py, add tests * Fix error handling and improve tests --- .../components/light/device_action.py | 32 +--- homeassistant/helpers/entity.py | 18 ++ tests/components/light/test_device_action.py | 166 +++++++++++------- tests/helpers/test_entity.py | 30 +++- 4 files changed, 158 insertions(+), 88 deletions(-) diff --git a/homeassistant/components/light/device_action.py b/homeassistant/components/light/device_action.py index 4c37647f168..9cdb5764d70 100644 --- a/homeassistant/components/light/device_action.py +++ b/homeassistant/components/light/device_action.py @@ -11,15 +11,10 @@ from homeassistant.components.light import ( VALID_BRIGHTNESS_PCT, VALID_FLASH, ) -from homeassistant.const import ( - ATTR_ENTITY_ID, - ATTR_SUPPORTED_FEATURES, - CONF_DOMAIN, - CONF_TYPE, - SERVICE_TURN_ON, -) -from homeassistant.core import Context, HomeAssistant +from homeassistant.const import ATTR_ENTITY_ID, CONF_DOMAIN, CONF_TYPE, SERVICE_TURN_ON +from homeassistant.core import Context, HomeAssistant, HomeAssistantError from homeassistant.helpers import config_validation as cv, entity_registry +from homeassistant.helpers.entity import get_supported_features from homeassistant.helpers.typing import ConfigType, TemplateVarsType from . import ATTR_BRIGHTNESS_PCT, ATTR_BRIGHTNESS_STEP_PCT, DOMAIN, SUPPORT_BRIGHTNESS @@ -88,12 +83,7 @@ async def async_get_actions(hass: HomeAssistant, device_id: str) -> list[dict]: if entry.domain != DOMAIN: continue - state = hass.states.get(entry.entity_id) - - if state: - supported_features = state.attributes.get(ATTR_SUPPORTED_FEATURES, 0) - else: - supported_features = entry.supported_features + supported_features = get_supported_features(hass, entry.entity_id) if supported_features & SUPPORT_BRIGHTNESS: actions.extend( @@ -133,16 +123,10 @@ async def async_get_action_capabilities(hass: HomeAssistant, config: dict) -> di if config[CONF_TYPE] != toggle_entity.CONF_TURN_ON: return {} - registry = await entity_registry.async_get_registry(hass) - entry = registry.async_get(config[ATTR_ENTITY_ID]) - state = hass.states.get(config[ATTR_ENTITY_ID]) - - supported_features = 0 - - if state: - supported_features = state.attributes.get(ATTR_SUPPORTED_FEATURES, 0) - elif entry: - supported_features = entry.supported_features + try: + supported_features = get_supported_features(hass, config[ATTR_ENTITY_ID]) + except HomeAssistantError: + supported_features = 0 extra_fields = {} diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 0074c0ba5e8..f30832479c2 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -29,6 +29,7 @@ from homeassistant.const import ( ) from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError, NoEntitySpecifiedError +from homeassistant.helpers import entity_registry as er from homeassistant.helpers.entity_platform import EntityPlatform from homeassistant.helpers.entity_registry import RegistryEntry from homeassistant.helpers.event import Event, async_track_entity_registry_updated_event @@ -86,6 +87,23 @@ def async_generate_entity_id( return test_string +def get_supported_features(hass: HomeAssistant, entity_id: str) -> int: + """Get supported features for an entity. + + First try the statemachine, then entity registry. + """ + state = hass.states.get(entity_id) + if state: + return state.attributes.get(ATTR_SUPPORTED_FEATURES, 0) + + entity_registry = er.async_get(hass) + entry = entity_registry.async_get(entity_id) + if not entry: + raise HomeAssistantError(f"Unknown entity {entity_id}") + + return entry.supported_features or 0 + + class Entity(ABC): """An abstract class for Home Assistant entities.""" diff --git a/tests/components/light/test_device_action.py b/tests/components/light/test_device_action.py index 4760dfd1c53..5d6ca2f4a2c 100644 --- a/tests/components/light/test_device_action.py +++ b/tests/components/light/test_device_action.py @@ -107,13 +107,13 @@ async def test_get_action_capabilities(hass, device_reg, entity_reg): config_entry_id=config_entry.entry_id, connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, ) - entity_reg.async_get_or_create( + # Test with entity without optional capabilities + entity_id = entity_reg.async_get_or_create( DOMAIN, "test", "5678", device_id=device_entry.id, - ) - + ).entity_id actions = await async_get_device_automations(hass, "action", device_entry.id) assert len(actions) == 3 for action in actions: @@ -122,8 +122,96 @@ async def test_get_action_capabilities(hass, device_reg, entity_reg): ) assert capabilities == {"extra_fields": []} + # Test without entity + entity_reg.async_remove(entity_id) + for action in actions: + capabilities = await async_get_device_automation_capabilities( + hass, "action", action + ) + assert capabilities == {"extra_fields": []} -async def test_get_action_capabilities_brightness(hass, device_reg, entity_reg): + +@pytest.mark.parametrize( + "set_state,num_actions,supported_features_reg,supported_features_state,expected_capabilities", + [ + ( + False, + 5, + SUPPORT_BRIGHTNESS, + 0, + { + "turn_on": [ + { + "name": "brightness_pct", + "optional": True, + "type": "float", + "valueMax": 100, + "valueMin": 0, + } + ] + }, + ), + ( + True, + 5, + 0, + SUPPORT_BRIGHTNESS, + { + "turn_on": [ + { + "name": "brightness_pct", + "optional": True, + "type": "float", + "valueMax": 100, + "valueMin": 0, + } + ] + }, + ), + ( + False, + 4, + SUPPORT_FLASH, + 0, + { + "turn_on": [ + { + "name": "flash", + "optional": True, + "type": "select", + "options": [("short", "short"), ("long", "long")], + } + ] + }, + ), + ( + True, + 4, + 0, + SUPPORT_FLASH, + { + "turn_on": [ + { + "name": "flash", + "optional": True, + "type": "select", + "options": [("short", "short"), ("long", "long")], + } + ] + }, + ), + ], +) +async def test_get_action_capabilities_features( + hass, + device_reg, + entity_reg, + set_state, + num_actions, + supported_features_reg, + supported_features_state, + expected_capabilities, +): """Test we get the expected capabilities from a light action.""" config_entry = MockConfigEntry(domain="test", data={}) config_entry.add_to_hass(hass) @@ -131,74 +219,26 @@ async def test_get_action_capabilities_brightness(hass, device_reg, entity_reg): config_entry_id=config_entry.entry_id, connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, ) - entity_reg.async_get_or_create( + entity_id = entity_reg.async_get_or_create( DOMAIN, "test", "5678", device_id=device_entry.id, - supported_features=SUPPORT_BRIGHTNESS, - ) + supported_features=supported_features_reg, + ).entity_id + if set_state: + hass.states.async_set( + entity_id, None, {"supported_features": supported_features_state} + ) - expected_capabilities = { - "extra_fields": [ - { - "name": "brightness_pct", - "optional": True, - "type": "float", - "valueMax": 100, - "valueMin": 0, - } - ] - } actions = await async_get_device_automations(hass, "action", device_entry.id) - assert len(actions) == 5 + assert len(actions) == num_actions for action in actions: capabilities = await async_get_device_automation_capabilities( hass, "action", action ) - if action["type"] == "turn_on": - assert capabilities == expected_capabilities - else: - assert capabilities == {"extra_fields": []} - - -async def test_get_action_capabilities_flash(hass, device_reg, entity_reg): - """Test we get the expected capabilities from a light action.""" - config_entry = MockConfigEntry(domain="test", data={}) - config_entry.add_to_hass(hass) - device_entry = device_reg.async_get_or_create( - config_entry_id=config_entry.entry_id, - connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, - ) - entity_reg.async_get_or_create( - DOMAIN, - "test", - "5678", - device_id=device_entry.id, - supported_features=SUPPORT_FLASH, - ) - - expected_capabilities = { - "extra_fields": [ - { - "name": "flash", - "optional": True, - "type": "select", - "options": [("short", "short"), ("long", "long")], - } - ] - } - - actions = await async_get_device_automations(hass, "action", device_entry.id) - assert len(actions) == 4 - for action in actions: - capabilities = await async_get_device_automation_capabilities( - hass, "action", action - ) - if action["type"] == "turn_on": - assert capabilities == expected_capabilities - else: - assert capabilities == {"extra_fields": []} + expected = {"extra_fields": expected_capabilities.get(action["type"], [])} + assert capabilities == expected async def test_action(hass, calls): @@ -209,7 +249,7 @@ async def test_action(hass, calls): assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}}) await hass.async_block_till_done() - ent1, ent2, ent3 = platform.ENTITIES + ent1 = platform.ENTITIES[0] assert await async_setup_component( hass, diff --git a/tests/helpers/test_entity.py b/tests/helpers/test_entity.py index b8d0fc7dc9c..6eeabb59eba 100644 --- a/tests/helpers/test_entity.py +++ b/tests/helpers/test_entity.py @@ -8,7 +8,7 @@ from unittest.mock import MagicMock, PropertyMock, patch import pytest from homeassistant.const import ATTR_DEVICE_CLASS, STATE_UNAVAILABLE, STATE_UNKNOWN -from homeassistant.core import Context +from homeassistant.core import Context, HomeAssistantError from homeassistant.helpers import entity, entity_registry from tests.common import ( @@ -744,3 +744,31 @@ async def test_removing_entity_unavailable(hass): state = hass.states.get("hello.world") assert state is not None assert state.state == STATE_UNAVAILABLE + + +async def test_get_supported_features_entity_registry(hass): + """Test get_supported_features falls back to entity registry.""" + entity_reg = mock_registry(hass) + entity_id = entity_reg.async_get_or_create( + "hello", "world", "5678", supported_features=456 + ).entity_id + assert entity.get_supported_features(hass, entity_id) == 456 + + +async def test_get_supported_features_prioritize_state(hass): + """Test get_supported_features gives priority to state.""" + entity_reg = mock_registry(hass) + entity_id = entity_reg.async_get_or_create( + "hello", "world", "5678", supported_features=456 + ).entity_id + assert entity.get_supported_features(hass, entity_id) == 456 + + hass.states.async_set(entity_id, None, {"supported_features": 123}) + + assert entity.get_supported_features(hass, entity_id) == 123 + + +async def test_get_supported_features_raises_on_unknown(hass): + """Test get_supported_features raises on unknown entity_id.""" + with pytest.raises(HomeAssistantError): + entity.get_supported_features(hass, "hello.world")