diff --git a/homeassistant/components/automation/state.py b/homeassistant/components/automation/state.py index fe49e1cf532..9d504d40de5 100644 --- a/homeassistant/components/automation/state.py +++ b/homeassistant/components/automation/state.py @@ -73,16 +73,13 @@ async def async_attach_trigger( from_s = event.data.get("old_state") to_s = event.data.get("new_state") + old_state = getattr(from_s, "state", None) + new_state = getattr(to_s, "state", None) if ( - (from_s is not None and not match_from_state(from_s.state)) - or (to_s is not None and not match_to_state(to_s.state)) - or ( - not match_all - and from_s is not None - and to_s is not None - and from_s.state == to_s.state - ) + not match_from_state(old_state) + or not match_to_state(new_state) + or (not match_all and old_state == new_state) ): return @@ -104,15 +101,6 @@ async def async_attach_trigger( ) ) - # Ignore changes to state attributes if from/to is in use - if ( - not match_all - and from_s is not None - and to_s is not None - and from_s.state == to_s.state - ): - return - if not time_delta: call_action() return diff --git a/tests/components/automation/test_state.py b/tests/components/automation/test_state.py index 13165da8488..9842818efab 100644 --- a/tests/components/automation/test_state.py +++ b/tests/components/automation/test_state.py @@ -519,28 +519,69 @@ async def test_if_fires_on_entity_change_with_for(hass, calls): assert 1 == len(calls) -async def test_if_fires_on_entity_removal(hass, calls): - """Test for firing on entity removal, when new_state is None.""" - context = Context() - hass.states.async_set("test.entity", "hello") - await hass.async_block_till_done() - +async def test_if_fires_on_entity_creation_and_removal(hass, calls): + """Test for firing on entity creation and removal, with to/from constraints.""" + # set automations for multiple combinations to/from assert await async_setup_component( hass, automation.DOMAIN, { - automation.DOMAIN: { - "trigger": {"platform": "state", "entity_id": "test.entity"}, - "action": {"service": "test.automation"}, - } + automation.DOMAIN: [ + { + "trigger": {"platform": "state", "entity_id": "test.entity_0"}, + "action": {"service": "test.automation"}, + }, + { + "trigger": { + "platform": "state", + "from": "hello", + "entity_id": "test.entity_1", + }, + "action": {"service": "test.automation"}, + }, + { + "trigger": { + "platform": "state", + "to": "world", + "entity_id": "test.entity_2", + }, + "action": {"service": "test.automation"}, + }, + ], }, ) await hass.async_block_till_done() - assert hass.states.async_remove("test.entity", context=context) + # use contexts to identify trigger entities + context_0 = Context() + context_1 = Context() + context_2 = Context() + + # automation with match_all triggers on creation + hass.states.async_set("test.entity_0", "any", context=context_0) await hass.async_block_till_done() assert len(calls) == 1 - assert calls[0].context.parent_id == context.id + assert calls[0].context.parent_id == context_0.id + + # create entities, trigger on test.entity_2 ('to' matches, no 'from') + hass.states.async_set("test.entity_1", "hello", context=context_1) + hass.states.async_set("test.entity_2", "world", context=context_2) + await hass.async_block_till_done() + assert len(calls) == 2 + assert calls[1].context.parent_id == context_2.id + + # removal of both, trigger on test.entity_1 ('from' matches, no 'to') + assert hass.states.async_remove("test.entity_1", context=context_1) + assert hass.states.async_remove("test.entity_2", context=context_2) + await hass.async_block_till_done() + assert len(calls) == 3 + assert calls[2].context.parent_id == context_1.id + + # automation with match_all triggers on removal + assert hass.states.async_remove("test.entity_0", context=context_0) + await hass.async_block_till_done() + assert len(calls) == 4 + assert calls[3].context.parent_id == context_0.id async def test_if_fires_on_for_condition(hass, calls):