diff --git a/tests/components/tradfri/common.py b/tests/components/tradfri/common.py index 81e21524eb0..9c636e14ee6 100644 --- a/tests/components/tradfri/common.py +++ b/tests/components/tradfri/common.py @@ -1,5 +1,12 @@ """Common tools used for the Tradfri test suite.""" +from copy import deepcopy +from typing import Any +from unittest.mock import Mock + +from pytradfri.device import Device + from homeassistant.components import tradfri +from homeassistant.core import HomeAssistant from . import GATEWAY_ID @@ -23,3 +30,47 @@ async def setup_integration(hass): await hass.async_block_till_done() return entry + + +def modify_state( + state: dict[str, Any], partial_state: dict[str, Any] +) -> dict[str, Any]: + """Modify a state with a partial state.""" + for key, value in partial_state.items(): + if isinstance(value, list): + for index, item in enumerate(value): + state[key][index] = modify_state(state[key][index], item) + elif isinstance(value, dict): + state[key] = modify_state(state[key], value) + else: + state[key] = value + + return state + + +async def trigger_observe_callback( + hass: HomeAssistant, + mock_gateway: Mock, + device: Device, + new_device_state: dict[str, Any] | None = None, +) -> None: + """Trigger the observe callback.""" + observe_command = next( + ( + command + for command in mock_gateway.mock_commands + if command.path == device.path and command.observe + ), + None, + ) + assert observe_command + + if new_device_state is not None: + mock_gateway.mock_responses.append(new_device_state) + + device_state = deepcopy(device.raw) + new_state = mock_gateway.mock_responses[-1] + device_state = modify_state(device_state, new_state) + observe_command.process_result(device_state) + + await hass.async_block_till_done() diff --git a/tests/components/tradfri/conftest.py b/tests/components/tradfri/conftest.py index e474765f60a..18d62c9a194 100644 --- a/tests/components/tradfri/conftest.py +++ b/tests/components/tradfri/conftest.py @@ -1,5 +1,8 @@ """Common tradfri test fixtures.""" -from unittest.mock import Mock, PropertyMock, patch +from __future__ import annotations + +from collections.abc import Generator +from unittest.mock import MagicMock, Mock, PropertyMock, patch import pytest @@ -43,6 +46,7 @@ def mock_gateway_fixture(): get_devices=get_devices, get_groups=get_groups, get_gateway_info=get_gateway_info, + mock_commands=[], mock_devices=[], mock_groups=[], mock_responses=[], @@ -62,13 +66,14 @@ def mock_api_fixture(mock_gateway): # Store the data for "real" command objects. if hasattr(command, "_data") and not isinstance(command, Mock): mock_gateway.mock_responses.append(command._data) + mock_gateway.mock_commands.append(command) return command return api @pytest.fixture -def mock_api_factory(mock_api): +def mock_api_factory(mock_api) -> Generator[MagicMock, None, None]: """Mock pytradfri api factory.""" with patch(f"{TRADFRI_PATH}.APIFactory", autospec=True) as factory: factory.init.return_value = factory.return_value diff --git a/tests/components/tradfri/fixtures/outlet.json b/tests/components/tradfri/fixtures/outlet.json new file mode 100644 index 00000000000..56caeb328a9 --- /dev/null +++ b/tests/components/tradfri/fixtures/outlet.json @@ -0,0 +1,18 @@ +{ + "9001": "Test", + "9002": 1536968250, + "9020": 1536968280, + "9003": 65548, + "9054": 0, + "5750": 3, + "9019": 1, + "9084": " 43 86 6e b5 6a df dc da d6 ce 9c 5a b4 63 a4 2a", + "3": { + "0": "IKEA of Sweden", + "1": "TRADFRI control outlet", + "3": "1.4.020", + "2": "", + "6": 1 + }, + "3312": [{ "9003": 0, "5850": 0 }] +} diff --git a/tests/components/tradfri/test_switch.py b/tests/components/tradfri/test_switch.py index 11903dc9a42..ab621f68579 100644 --- a/tests/components/tradfri/test_switch.py +++ b/tests/components/tradfri/test_switch.py @@ -1,160 +1,102 @@ """Tradfri switch (recognised as sockets in the IKEA ecosystem) platform tests.""" +from __future__ import annotations -from unittest.mock import MagicMock, Mock, PropertyMock, patch +import json +from typing import Any +from unittest.mock import MagicMock, Mock import pytest +from pytradfri.const import ATTR_REACHABLE_STATE from pytradfri.device import Device from pytradfri.device.socket import Socket -from pytradfri.device.socket_control import SocketControl -from .common import setup_integration +from homeassistant.components.switch import DOMAIN as SWITCH_DOMAIN +from homeassistant.components.tradfri.const import DOMAIN +from homeassistant.const import STATE_OFF, STATE_ON, STATE_UNAVAILABLE +from homeassistant.core import HomeAssistant + +from .common import setup_integration, trigger_observe_callback + +from tests.common import load_fixture -@pytest.fixture(autouse=True, scope="module") -def setup(request): - """Set up patches for pytradfri methods.""" - with patch( - "pytradfri.device.SocketControl.raw", - new_callable=PropertyMock, - return_value=[{"mock": "mock"}], - ), patch( - "pytradfri.device.SocketControl.sockets", - ): - yield +@pytest.fixture(scope="module") +def outlet() -> dict[str, Any]: + """Return an outlet response.""" + return json.loads(load_fixture("outlet.json", DOMAIN)) -def mock_switch(test_features=None, test_state=None, device_number=0): - """Mock a tradfri switch/socket.""" - if test_features is None: - test_features = {} - if test_state is None: - test_state = {} - mock_switch_data = Mock(**test_state) - - dev_info_mock = MagicMock() - dev_info_mock.manufacturer = "manufacturer" - dev_info_mock.model_number = "model" - dev_info_mock.firmware_version = "1.2.3" - _mock_switch = Mock( - id=f"mock-switch-id-{device_number}", - reachable=True, - observe=Mock(), - device_info=dev_info_mock, - has_light_control=False, - has_socket_control=True, - has_blind_control=False, - has_signal_repeater_control=False, - has_air_purifier_control=False, - ) - _mock_switch.name = f"tradfri_switch_{device_number}" - socket_control = SocketControl(_mock_switch) - - # Store the initial state. - setattr(socket_control, "sockets", [mock_switch_data]) - _mock_switch.socket_control = socket_control - return _mock_switch +@pytest.fixture +def socket(outlet: dict[str, Any]) -> Socket: + """Return socket.""" + device = Device(outlet) + socket_control = device.socket_control + assert socket_control + return socket_control.sockets[0] -async def test_switch(hass, mock_gateway, mock_api_factory): - """Test that switches are correctly added.""" - state = { - "state": True, - } - - mock_gateway.mock_devices.append(mock_switch(test_state=state)) - await setup_integration(hass) - - switch_1 = hass.states.get("switch.tradfri_switch_0") - assert switch_1 is not None - assert switch_1.state == "on" - - -async def test_switch_observed(hass, mock_gateway, mock_api_factory): - """Test that switches are correctly observed.""" - state = { - "state": True, - } - - switch = mock_switch(test_state=state) - mock_gateway.mock_devices.append(switch) - await setup_integration(hass) - assert len(switch.observe.mock_calls) > 0 - - -async def test_switch_available(hass, mock_gateway, mock_api_factory): +async def test_switch_available( + hass: HomeAssistant, + mock_gateway: Mock, + mock_api_factory: MagicMock, + socket: Socket, +) -> None: """Test switch available property.""" - - switch = mock_switch(test_state={"state": True}, device_number=1) - switch.reachable = True - - switch2 = mock_switch(test_state={"state": True}, device_number=2) - switch2.reachable = False - - mock_gateway.mock_devices.append(switch) - mock_gateway.mock_devices.append(switch2) + entity_id = "switch.test" + device = socket.device + mock_gateway.mock_devices.append(device) await setup_integration(hass) - assert hass.states.get("switch.tradfri_switch_1").state == "on" - assert hass.states.get("switch.tradfri_switch_2").state == "unavailable" + state = hass.states.get(entity_id) + assert state + assert state.state == STATE_OFF + + await trigger_observe_callback( + hass, mock_gateway, device, {ATTR_REACHABLE_STATE: 0} + ) + + state = hass.states.get(entity_id) + assert state + assert state.state == STATE_UNAVAILABLE @pytest.mark.parametrize( - "test_data, expected_result", + "service, expected_state", [ - ( - "turn_on", - "on", - ), - ("turn_off", "off"), + ("turn_on", STATE_ON), + ("turn_off", STATE_OFF), ], ) async def test_turn_on_off( - hass, - mock_gateway, - mock_api_factory, - test_data, - expected_result, -): + hass: HomeAssistant, + mock_gateway: Mock, + mock_api_factory: MagicMock, + socket: Socket, + service: str, + expected_state: str, +) -> None: """Test turning switch on/off.""" - # Note pytradfri style, not hass. Values not really important. - initial_state = { - "state": True, - } - - # Setup the gateway with a mock switch. - switch = mock_switch(test_state=initial_state, device_number=0) - mock_gateway.mock_devices.append(switch) + entity_id = "switch.test" + device = socket.device + mock_gateway.mock_devices.append(device) await setup_integration(hass) - # Use the turn_on/turn_off service call to change the switch state. + state = hass.states.get(entity_id) + assert state + assert state.state == STATE_OFF + await hass.services.async_call( - "switch", - test_data, + SWITCH_DOMAIN, + service, { - "entity_id": "switch.tradfri_switch_0", + "entity_id": entity_id, }, blocking=True, ) await hass.async_block_till_done() - # Check that the switch is observed. - mock_func = switch.observe - assert len(mock_func.mock_calls) > 0 - _, callkwargs = mock_func.call_args - assert "callback" in callkwargs - # Callback function to refresh switch state. - callback = callkwargs["callback"] + await trigger_observe_callback(hass, mock_gateway, device) - responses = mock_gateway.mock_responses - mock_gateway_response = responses[0] - - # Use the callback function to update the switch state. - dev = Device(mock_gateway_response) - switch_data = Socket(dev, 0) - switch.socket_control.sockets[0] = switch_data - callback(switch) - await hass.async_block_till_done() - - # Check that the state is correct. - state = hass.states.get("switch.tradfri_switch_0") - assert state.state == expected_result + state = hass.states.get(entity_id) + assert state + assert state.state == expected_state