diff --git a/homeassistant/components/timer/reproduce_state.py b/homeassistant/components/timer/reproduce_state.py new file mode 100644 index 00000000000..c765ed7da9c --- /dev/null +++ b/homeassistant/components/timer/reproduce_state.py @@ -0,0 +1,70 @@ +"""Reproduce an Timer state.""" +import asyncio +import logging +from typing import Iterable, Optional + +from homeassistant.const import ATTR_ENTITY_ID +from homeassistant.core import Context, State +from homeassistant.helpers.typing import HomeAssistantType + +from . import ( + ATTR_DURATION, + DOMAIN, + SERVICE_CANCEL, + SERVICE_PAUSE, + SERVICE_START, + STATUS_ACTIVE, + STATUS_IDLE, + STATUS_PAUSED, +) + +_LOGGER = logging.getLogger(__name__) + +VALID_STATES = {STATUS_IDLE, STATUS_ACTIVE, STATUS_PAUSED} + + +async def _async_reproduce_state( + hass: HomeAssistantType, state: State, context: Optional[Context] = None +) -> None: + """Reproduce a single state.""" + cur_state = hass.states.get(state.entity_id) + + if cur_state is None: + _LOGGER.warning("Unable to find entity %s", state.entity_id) + return + + if state.state not in VALID_STATES: + _LOGGER.warning( + "Invalid state specified for %s: %s", state.entity_id, state.state + ) + return + + # Return if we are already at the right state. + if cur_state.state == state.state and cur_state.attributes.get( + ATTR_DURATION + ) == state.attributes.get(ATTR_DURATION): + return + + service_data = {ATTR_ENTITY_ID: state.entity_id} + + if state.state == STATUS_ACTIVE: + service = SERVICE_START + if ATTR_DURATION in state.attributes: + service_data[ATTR_DURATION] = state.attributes[ATTR_DURATION] + elif state.state == STATUS_PAUSED: + service = SERVICE_PAUSE + elif state.state == STATUS_IDLE: + service = SERVICE_CANCEL + + await hass.services.async_call( + DOMAIN, service, service_data, context=context, blocking=True + ) + + +async def async_reproduce_states( + hass: HomeAssistantType, states: Iterable[State], context: Optional[Context] = None +) -> None: + """Reproduce Timer states.""" + await asyncio.gather( + *(_async_reproduce_state(hass, state, context) for state in states) + ) diff --git a/tests/components/timer/test_reproduce_state.py b/tests/components/timer/test_reproduce_state.py new file mode 100644 index 00000000000..5539d8610c3 --- /dev/null +++ b/tests/components/timer/test_reproduce_state.py @@ -0,0 +1,84 @@ +"""Test reproduce state for Timer.""" +from homeassistant.components.timer import ( + ATTR_DURATION, + SERVICE_CANCEL, + SERVICE_PAUSE, + SERVICE_START, + STATUS_ACTIVE, + STATUS_IDLE, + STATUS_PAUSED, +) +from homeassistant.core import State +from tests.common import async_mock_service + + +async def test_reproducing_states(hass, caplog): + """Test reproducing Timer states.""" + hass.states.async_set("timer.entity_idle", STATUS_IDLE, {}) + hass.states.async_set("timer.entity_paused", STATUS_PAUSED, {}) + hass.states.async_set("timer.entity_active", STATUS_ACTIVE, {}) + hass.states.async_set( + "timer.entity_active_attr", STATUS_ACTIVE, {ATTR_DURATION: "00:01:00"} + ) + + start_calls = async_mock_service(hass, "timer", SERVICE_START) + pause_calls = async_mock_service(hass, "timer", SERVICE_PAUSE) + cancel_calls = async_mock_service(hass, "timer", SERVICE_CANCEL) + + # These calls should do nothing as entities already in desired state + await hass.helpers.state.async_reproduce_state( + [ + State("timer.entity_idle", STATUS_IDLE), + State("timer.entity_paused", STATUS_PAUSED), + State("timer.entity_active", STATUS_ACTIVE), + State( + "timer.entity_active_attr", STATUS_ACTIVE, {ATTR_DURATION: "00:01:00"} + ), + ], + blocking=True, + ) + + assert len(start_calls) == 0 + assert len(pause_calls) == 0 + assert len(cancel_calls) == 0 + + # Test invalid state is handled + await hass.helpers.state.async_reproduce_state( + [State("timer.entity_idle", "not_supported")], blocking=True + ) + + assert "not_supported" in caplog.text + assert len(start_calls) == 0 + assert len(pause_calls) == 0 + assert len(cancel_calls) == 0 + + # Make sure correct services are called + await hass.helpers.state.async_reproduce_state( + [ + State("timer.entity_idle", STATUS_ACTIVE, {ATTR_DURATION: "00:01:00"}), + State("timer.entity_paused", STATUS_ACTIVE), + State("timer.entity_active", STATUS_IDLE), + State("timer.entity_active_attr", STATUS_PAUSED), + # Should not raise + State("timer.non_existing", "on"), + ], + blocking=True, + ) + + valid_start_calls = [ + {"entity_id": "timer.entity_idle", ATTR_DURATION: "00:01:00"}, + {"entity_id": "timer.entity_paused"}, + ] + assert len(start_calls) == 2 + for call in start_calls: + assert call.domain == "timer" + assert call.data in valid_start_calls + valid_start_calls.remove(call.data) + + assert len(pause_calls) == 1 + assert pause_calls[0].domain == "timer" + assert pause_calls[0].data == {"entity_id": "timer.entity_active_attr"} + + assert len(cancel_calls) == 1 + assert cancel_calls[0].domain == "timer" + assert cancel_calls[0].data == {"entity_id": "timer.entity_active"}