diff --git a/homeassistant/components/binary_sensor/device_condition.py b/homeassistant/components/binary_sensor/device_condition.py index 8c506634200..eed5c3f5896 100644 --- a/homeassistant/components/binary_sensor/device_condition.py +++ b/homeassistant/components/binary_sensor/device_condition.py @@ -4,9 +4,10 @@ from __future__ import annotations import voluptuous as vol from homeassistant.components.device_automation.const import CONF_IS_OFF, CONF_IS_ON -from homeassistant.const import ATTR_DEVICE_CLASS, CONF_ENTITY_ID, CONF_FOR, CONF_TYPE +from homeassistant.const import CONF_ENTITY_ID, CONF_FOR, CONF_TYPE from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import condition, config_validation as cv +from homeassistant.helpers.entity import get_device_class from homeassistant.helpers.entity_registry import ( async_entries_for_device, async_get_registry, @@ -216,10 +217,7 @@ async def async_get_conditions( ] for entry in entries: - device_class = DEVICE_CLASS_NONE - state = hass.states.get(entry.entity_id) - if state and ATTR_DEVICE_CLASS in state.attributes: - device_class = state.attributes[ATTR_DEVICE_CLASS] + device_class = get_device_class(hass, entry.entity_id) or DEVICE_CLASS_NONE templates = ENTITY_CONDITIONS.get( device_class, ENTITY_CONDITIONS[DEVICE_CLASS_NONE] diff --git a/tests/components/binary_sensor/test_device_condition.py b/tests/components/binary_sensor/test_device_condition.py index 5d8673825fc..3d1b694c7ce 100644 --- a/tests/components/binary_sensor/test_device_condition.py +++ b/tests/components/binary_sensor/test_device_condition.py @@ -78,6 +78,41 @@ async def test_get_conditions(hass, device_reg, entity_reg, enable_custom_integr assert conditions == expected_conditions +async def test_get_conditions_no_state(hass, device_reg, entity_reg): + """Test we get the expected conditions from a binary_sensor.""" + config_entry = MockConfigEntry(domain="test", data={}) + config_entry.add_to_hass(hass) + device_entry = device_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + ) + entity_ids = {} + for device_class in DEVICE_CLASSES: + entity_ids[device_class] = entity_reg.async_get_or_create( + DOMAIN, + "test", + f"5678_{device_class}", + device_id=device_entry.id, + device_class=device_class, + ).entity_id + + await hass.async_block_till_done() + + expected_conditions = [ + { + "condition": "device", + "domain": DOMAIN, + "type": condition["type"], + "device_id": device_entry.id, + "entity_id": entity_ids[device_class], + } + for device_class in DEVICE_CLASSES + for condition in ENTITY_CONDITIONS[device_class] + ] + conditions = await async_get_device_automations(hass, "condition", device_entry.id) + assert conditions == expected_conditions + + async def test_get_condition_capabilities(hass, device_reg, entity_reg): """Test we get the expected capabilities from a binary_sensor condition.""" config_entry = MockConfigEntry(domain="test", data={}) diff --git a/tests/components/binary_sensor/test_device_trigger.py b/tests/components/binary_sensor/test_device_trigger.py index 1dbed7d19e1..8bd80be6524 100644 --- a/tests/components/binary_sensor/test_device_trigger.py +++ b/tests/components/binary_sensor/test_device_trigger.py @@ -78,9 +78,7 @@ async def test_get_triggers(hass, device_reg, entity_reg, enable_custom_integrat assert triggers == expected_triggers -async def test_get_triggers_no_state( - hass, device_reg, entity_reg, enable_custom_integrations -): +async def test_get_triggers_no_state(hass, device_reg, entity_reg): """Test we get the expected triggers from a binary_sensor.""" platform = getattr(hass.components, f"test.{DOMAIN}") platform.init() @@ -96,7 +94,7 @@ async def test_get_triggers_no_state( entity_ids[device_class] = entity_reg.async_get_or_create( DOMAIN, "test", - platform.ENTITIES[device_class].unique_id, + f"5678_{device_class}", device_id=device_entry.id, device_class=device_class, ).entity_id