Add additional coverage to the ESPHome manager (#114265)

pull/114354/head
J. Nick Koston 2024-03-27 20:52:45 -10:00 committed by GitHub
parent ae0b41f7a7
commit bec45dacf0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 83 additions and 2 deletions

View File

@ -170,7 +170,7 @@ class ESPHomeManager:
self.entry_data = entry_data
async def on_stop(self, event: Event) -> None:
"""Cleanup the socket client on HA stop."""
"""Cleanup the socket client on HA close."""
await cleanup_instance(self.hass, self.entry)
@property

View File

@ -188,6 +188,7 @@ class MockESPHomeDevice:
self.service_call_callback: Callable[[HomeassistantServiceCall], None]
self.on_disconnect: Callable[[bool], None]
self.on_connect: Callable[[bool], None]
self.on_connect_error: Callable[[Exception], None]
self.home_assistant_state_subscription_callback: Callable[
[str, str | None], None
]
@ -222,10 +223,20 @@ class MockESPHomeDevice:
"""Set the connect callback."""
self.on_connect = on_connect
def set_on_connect_error(
self, on_connect_error: Callable[[Exception], None]
) -> None:
"""Set the connect error callback."""
self.on_connect_error = on_connect_error
async def mock_connect(self) -> None:
"""Mock connecting."""
await self.on_connect()
async def mock_connect_error(self, exc: Exception) -> None:
"""Mock connect error."""
await self.on_connect_error(exc)
def set_home_assistant_state_subscription_callback(
self,
on_state_sub: Callable[[str, str | None], None],
@ -309,6 +320,7 @@ async def _mock_generic_device_entry(
super().__init__(*args, **kwargs)
mock_device.set_on_disconnect(kwargs["on_disconnect"])
mock_device.set_on_connect(kwargs["on_connect"])
mock_device.set_on_connect_error(kwargs["on_connect_error"])
self._try_connect = self.mock_try_connect
async def mock_try_connect(self):

View File

@ -11,6 +11,9 @@ from aioesphomeapi import (
EntityInfo,
EntityState,
HomeassistantServiceCall,
InvalidAuthAPIError,
InvalidEncryptionKeyAPIError,
RequiresEncryptionAPIError,
UserService,
UserServiceArg,
UserServiceArgType,
@ -25,7 +28,12 @@ from homeassistant.components.esphome.const import (
DOMAIN,
STABLE_BLE_VERSION_STR,
)
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT
from homeassistant.const import (
CONF_HOST,
CONF_PASSWORD,
CONF_PORT,
EVENT_HOMEASSISTANT_CLOSE,
)
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.data_entry_flow import FlowResultType
from homeassistant.helpers import device_registry as dr, issue_registry as ir
@ -1083,3 +1091,64 @@ async def test_esphome_device_with_compilation_time(
connections={(dr.CONNECTION_NETWORK_MAC, entry.unique_id)}
)
assert "comp_time" in dev.sw_version
async def test_disconnects_at_close_event(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test the device is disconnected at the close event."""
await mock_esphome_device(
mock_client=mock_client,
entity_info=[],
user_service=[],
device_info={"compilation_time": "comp_time"},
states=[],
)
await hass.async_block_till_done()
assert mock_client.disconnect.call_count == 0
hass.bus.async_fire(EVENT_HOMEASSISTANT_CLOSE)
await hass.async_block_till_done()
assert mock_client.disconnect.call_count == 1
@pytest.mark.parametrize(
"error",
[
RequiresEncryptionAPIError,
InvalidEncryptionKeyAPIError,
InvalidAuthAPIError,
],
)
async def test_start_reauth(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
error: Exception,
) -> None:
"""Test exceptions on connect error trigger reauth."""
device = await mock_esphome_device(
mock_client=mock_client,
entity_info=[],
user_service=[],
device_info={"compilation_time": "comp_time"},
states=[],
)
await hass.async_block_till_done()
await device.mock_connect_error(error("fail"))
await hass.async_block_till_done()
flows = hass.config_entries.flow.async_progress(DOMAIN)
assert len(flows) == 1
flow = flows[0]
assert flow["context"]["source"] == "reauth"