diff --git a/homeassistant/components/shelly/device_trigger.py b/homeassistant/components/shelly/device_trigger.py index b7cf1120949..97938040543 100644 --- a/homeassistant/components/shelly/device_trigger.py +++ b/homeassistant/components/shelly/device_trigger.py @@ -27,6 +27,7 @@ from .const import ( DOMAIN, EVENT_SHELLY_CLICK, INPUTS_EVENTS_SUBTYPES, + SHBTN_1_INPUTS_EVENTS_TYPES, SUPPORTED_INPUTS_EVENTS_TYPES, ) from .utils import get_device_wrapper, get_input_triggers @@ -45,7 +46,7 @@ async def async_validate_trigger_config(hass, config): # if device is available verify parameters against device capabilities wrapper = get_device_wrapper(hass, config[CONF_DEVICE_ID]) - if not wrapper: + if not wrapper or not wrapper.device.initialized: return config trigger = (config[CONF_TYPE], config[CONF_SUBTYPE]) @@ -68,6 +69,19 @@ async def async_get_triggers(hass: HomeAssistant, device_id: str) -> list[dict]: if not wrapper: raise InvalidDeviceAutomationConfig(f"Device not found: {device_id}") + if wrapper.model in ("SHBTN-1", "SHBTN-2"): + for trigger in SHBTN_1_INPUTS_EVENTS_TYPES: + triggers.append( + { + CONF_PLATFORM: "device", + CONF_DEVICE_ID: device_id, + CONF_DOMAIN: DOMAIN, + CONF_TYPE: trigger, + CONF_SUBTYPE: "button", + } + ) + return triggers + for block in wrapper.device.blocks: input_triggers = get_input_triggers(wrapper.device, block) diff --git a/tests/components/shelly/test_device_trigger.py b/tests/components/shelly/test_device_trigger.py index a725f5a1f30..bedf4abc0f2 100644 --- a/tests/components/shelly/test_device_trigger.py +++ b/tests/components/shelly/test_device_trigger.py @@ -1,4 +1,6 @@ """The tests for Shelly device triggers.""" +from unittest.mock import AsyncMock, Mock + import pytest from homeassistant import setup @@ -6,10 +8,13 @@ from homeassistant.components import automation from homeassistant.components.device_automation.exceptions import ( InvalidDeviceAutomationConfig, ) +from homeassistant.components.shelly import ShellyDeviceWrapper from homeassistant.components.shelly.const import ( ATTR_CHANNEL, ATTR_CLICK_TYPE, + COAP, CONF_SUBTYPE, + DATA_CONFIG_ENTRY, DOMAIN, EVENT_SHELLY_CLICK, ) @@ -52,6 +57,71 @@ async def test_get_triggers(hass, coap_wrapper): assert_lists_same(triggers, expected_triggers) +async def test_get_triggers_button(hass): + """Test we get the expected triggers from a shelly button.""" + await async_setup_component(hass, "shelly", {}) + + config_entry = MockConfigEntry( + domain=DOMAIN, + data={"sleep_period": 43200, "model": "SHBTN-1"}, + unique_id="12345678", + ) + config_entry.add_to_hass(hass) + + device = Mock( + blocks=None, + settings=None, + shelly=None, + update=AsyncMock(), + initialized=False, + ) + + hass.data[DOMAIN] = {DATA_CONFIG_ENTRY: {}} + hass.data[DOMAIN][DATA_CONFIG_ENTRY][config_entry.entry_id] = {} + coap_wrapper = hass.data[DOMAIN][DATA_CONFIG_ENTRY][config_entry.entry_id][ + COAP + ] = ShellyDeviceWrapper(hass, config_entry, device) + + await coap_wrapper.async_setup() + + expected_triggers = [ + { + CONF_PLATFORM: "device", + CONF_DEVICE_ID: coap_wrapper.device_id, + CONF_DOMAIN: DOMAIN, + CONF_TYPE: "single", + CONF_SUBTYPE: "button", + }, + { + CONF_PLATFORM: "device", + CONF_DEVICE_ID: coap_wrapper.device_id, + CONF_DOMAIN: DOMAIN, + CONF_TYPE: "double", + CONF_SUBTYPE: "button", + }, + { + CONF_PLATFORM: "device", + CONF_DEVICE_ID: coap_wrapper.device_id, + CONF_DOMAIN: DOMAIN, + CONF_TYPE: "triple", + CONF_SUBTYPE: "button", + }, + { + CONF_PLATFORM: "device", + CONF_DEVICE_ID: coap_wrapper.device_id, + CONF_DOMAIN: DOMAIN, + CONF_TYPE: "long", + CONF_SUBTYPE: "button", + }, + ] + + triggers = await async_get_device_automations( + hass, "trigger", coap_wrapper.device_id + ) + + assert_lists_same(triggers, expected_triggers) + + async def test_get_triggers_for_invalid_device_id(hass, device_reg, coap_wrapper): """Test error raised for invalid shelly device_id.""" assert coap_wrapper