diff --git a/tests/components/shelly/__init__.py b/tests/components/shelly/__init__.py index 26040e13557..daf96db13d3 100644 --- a/tests/components/shelly/__init__.py +++ b/tests/components/shelly/__init__.py @@ -74,7 +74,7 @@ def mutate_rpc_device_status( def inject_rpc_device_event( monkeypatch: pytest.MonkeyPatch, mock_rpc_device: Mock, - event: dict[str, dict[str, Any]], + event: Mapping[str, list[dict[str, Any]] | float], ) -> None: """Inject event for rpc device.""" monkeypatch.setattr(mock_rpc_device, "event", event) @@ -121,6 +121,13 @@ def register_entity( return f"{domain}.{object_id}" +def get_entity_state(hass: HomeAssistant, entity_id: str) -> str: + """Return entity state.""" + entity = hass.states.get(entity_id) + assert entity + return entity.state + + def register_device(device_reg, config_entry: ConfigEntry): """Register Shelly device.""" device_reg.async_get_or_create( diff --git a/tests/components/shelly/conftest.py b/tests/components/shelly/conftest.py index 8a863a852f5..af373f33c23 100644 --- a/tests/components/shelly/conftest.py +++ b/tests/components/shelly/conftest.py @@ -12,6 +12,7 @@ from homeassistant.components.shelly.const import ( EVENT_SHELLY_CLICK, REST_SENSORS_UPDATE_INTERVAL, ) +from homeassistant.core import HomeAssistant from . import MOCK_MAC @@ -252,19 +253,19 @@ def mock_ws_server(): @pytest.fixture -def device_reg(hass): +def device_reg(hass: HomeAssistant): """Return an empty, loaded, registry.""" return mock_device_registry(hass) @pytest.fixture -def calls(hass): +def calls(hass: HomeAssistant): """Track calls to a mock service.""" return async_mock_service(hass, "test", "automation") @pytest.fixture -def events(hass): +def events(hass: HomeAssistant): """Yield caught shelly_click events.""" return async_capture_events(hass, EVENT_SHELLY_CLICK) diff --git a/tests/components/shelly/test_coordinator.py b/tests/components/shelly/test_coordinator.py index 27aa8710621..8e288ba1687 100644 --- a/tests/components/shelly/test_coordinator.py +++ b/tests/components/shelly/test_coordinator.py @@ -33,6 +33,7 @@ import homeassistant.helpers.issue_registry as ir from . import ( MOCK_MAC, + get_entity_state, init_integration, inject_rpc_device_event, mock_polling_rpc_update, @@ -196,14 +197,14 @@ async def test_block_polling_connection_error( ) await init_integration(hass, 1) - assert hass.states.get("switch.test_name_channel_1").state == STATE_ON + assert get_entity_state(hass, "switch.test_name_channel_1") == STATE_ON # Move time to generate polling freezer.tick(timedelta(seconds=UPDATE_PERIOD_MULTIPLIER * 15)) async_fire_time_changed(hass) await hass.async_block_till_done() - assert hass.states.get("switch.test_name_channel_1").state == STATE_UNAVAILABLE + assert get_entity_state(hass, "switch.test_name_channel_1") == STATE_UNAVAILABLE async def test_block_rest_update_connection_error( @@ -216,7 +217,7 @@ async def test_block_rest_update_connection_error( await init_integration(hass, 1) await mock_rest_update(hass, freezer) - assert hass.states.get(entity_id).state == STATE_ON + assert get_entity_state(hass, entity_id) == STATE_ON monkeypatch.setattr( mock_block_device, @@ -225,7 +226,7 @@ async def test_block_rest_update_connection_error( ) await mock_rest_update(hass, freezer) - assert hass.states.get(entity_id).state == STATE_UNAVAILABLE + assert get_entity_state(hass, entity_id) == STATE_UNAVAILABLE async def test_block_sleeping_device_no_periodic_updates( @@ -239,14 +240,14 @@ async def test_block_sleeping_device_no_periodic_updates( mock_block_device.mock_update() await hass.async_block_till_done() - assert hass.states.get(entity_id).state == "22.1" + assert get_entity_state(hass, entity_id) == "22.1" # Move time to generate polling freezer.tick(timedelta(seconds=UPDATE_PERIOD_MULTIPLIER * 1000)) async_fire_time_changed(hass) await hass.async_block_till_done() - assert hass.states.get(entity_id).state == STATE_UNAVAILABLE + assert get_entity_state(hass, entity_id) == STATE_UNAVAILABLE async def test_block_device_push_updates_failure( @@ -496,14 +497,14 @@ async def test_rpc_sleeping_device_no_periodic_updates( mock_rpc_device.mock_update() await hass.async_block_till_done() - assert hass.states.get(entity_id).state == "22.9" + assert get_entity_state(hass, entity_id) == "22.9" # Move time to generate polling freezer.tick(timedelta(seconds=SLEEP_PERIOD_MULTIPLIER * 1000)) async_fire_time_changed(hass) await hass.async_block_till_done() - assert hass.states.get(entity_id).state == STATE_UNAVAILABLE + assert get_entity_state(hass, entity_id) == STATE_UNAVAILABLE async def test_rpc_reconnect_auth_error( @@ -581,7 +582,7 @@ async def test_rpc_reconnect_error( """Test RPC reconnect error.""" await init_integration(hass, 2) - assert hass.states.get("switch.test_switch_0").state == STATE_ON + assert get_entity_state(hass, "switch.test_switch_0") == STATE_ON monkeypatch.setattr(mock_rpc_device, "connected", False) monkeypatch.setattr( @@ -597,7 +598,7 @@ async def test_rpc_reconnect_error( async_fire_time_changed(hass) await hass.async_block_till_done() - assert hass.states.get("switch.test_switch_0").state == STATE_UNAVAILABLE + assert get_entity_state(hass, "switch.test_switch_0") == STATE_UNAVAILABLE async def test_rpc_polling_connection_error( @@ -615,11 +616,11 @@ async def test_rpc_polling_connection_error( ), ) - assert hass.states.get(entity_id).state == "-63" + assert get_entity_state(hass, entity_id) == "-63" await mock_polling_rpc_update(hass, freezer) - assert hass.states.get(entity_id).state == STATE_UNAVAILABLE + assert get_entity_state(hass, entity_id) == STATE_UNAVAILABLE async def test_rpc_polling_disconnected( @@ -631,11 +632,11 @@ async def test_rpc_polling_disconnected( monkeypatch.setattr(mock_rpc_device, "connected", False) - assert hass.states.get(entity_id).state == "-63" + assert get_entity_state(hass, entity_id) == "-63" await mock_polling_rpc_update(hass, freezer) - assert hass.states.get(entity_id).state == STATE_UNAVAILABLE + assert get_entity_state(hass, entity_id) == STATE_UNAVAILABLE async def test_rpc_update_entry_fw_ver( @@ -649,11 +650,12 @@ async def test_rpc_update_entry_fw_ver( mock_rpc_device.mock_update() await hass.async_block_till_done() + assert entry.unique_id device = dev_reg.async_get_device( identifiers={(DOMAIN, entry.entry_id)}, connections={(CONNECTION_NETWORK_MAC, format_mac(entry.unique_id))}, ) - + assert device assert device.sw_version == "some fw string" monkeypatch.setattr(mock_rpc_device, "firmware_version", "99.0.0") @@ -665,5 +667,5 @@ async def test_rpc_update_entry_fw_ver( identifiers={(DOMAIN, entry.entry_id)}, connections={(CONNECTION_NETWORK_MAC, format_mac(entry.unique_id))}, ) - + assert device assert device.sw_version == "99.0.0"