From 0628f967130a2aaef8ca4fff7b6a81aa5550864c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 25 Aug 2024 09:21:15 -1000 Subject: [PATCH] Ensure all chars are polling when requesting manual update in homekit_controller (#124582) related issue #123963 --- .../homekit_controller/connection.py | 7 +++- .../homekit_controller/test_connection.py | 40 +++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/homeassistant/components/homekit_controller/connection.py b/homeassistant/components/homekit_controller/connection.py index 4da907daf3e..934e7e883ae 100644 --- a/homeassistant/components/homekit_controller/connection.py +++ b/homeassistant/components/homekit_controller/connection.py @@ -154,6 +154,7 @@ class HKDevice: self._pending_subscribes: set[tuple[int, int]] = set() self._subscribe_timer: CALLBACK_TYPE | None = None self._load_platforms_lock = asyncio.Lock() + self._full_update_requested: bool = False @property def entity_map(self) -> Accessories: @@ -841,6 +842,7 @@ class HKDevice: async def async_request_update(self, now: datetime | None = None) -> None: """Request an debounced update from the accessory.""" + self._full_update_requested = True await self._debounced_update.async_call() async def async_update(self, now: datetime | None = None) -> None: @@ -849,7 +851,8 @@ class HKDevice: accessories = self.entity_map.accessories if ( - len(accessories) == 1 + not self._full_update_requested + and len(accessories) == 1 and self.available and not (to_poll - self.watchable_characteristics) and self.pairing.is_available @@ -879,6 +882,8 @@ class HKDevice: firmware_iid = accessory_info[CharacteristicsTypes.FIRMWARE_REVISION].iid to_poll = {(first_accessory.aid, firmware_iid)} + self._full_update_requested = False + if not to_poll: self.async_update_available_state() _LOGGER.debug( diff --git a/tests/components/homekit_controller/test_connection.py b/tests/components/homekit_controller/test_connection.py index 503ff171533..7ea791f9a1e 100644 --- a/tests/components/homekit_controller/test_connection.py +++ b/tests/components/homekit_controller/test_connection.py @@ -12,6 +12,7 @@ from aiohomekit.testing import FakeController import pytest from homeassistant.components.homekit_controller.const import ( + DEBOUNCE_COOLDOWN, DOMAIN, IDENTIFIER_ACCESSORY_ID, IDENTIFIER_LEGACY_ACCESSORY_ID, @@ -22,12 +23,14 @@ from homeassistant.const import STATE_OFF, STATE_UNAVAILABLE from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import device_registry as dr, entity_registry as er +from homeassistant.helpers.entity_component import async_update_entity from .common import ( setup_accessories_from_file, setup_platform, setup_test_accessories, setup_test_component, + time_changed, ) from tests.common import MockConfigEntry @@ -399,3 +402,40 @@ async def test_poll_firmware_version_only_all_watchable_accessory_mode( state = await helper.poll_and_get_state() assert state.state == STATE_OFF assert mock_get_characteristics.call_count == 8 + + +async def test_manual_poll_all_chars( + hass: HomeAssistant, get_next_aid: Callable[[], int] +) -> None: + """Test that a manual poll will check all chars.""" + + def _create_accessory(accessory: Accessory) -> Service: + service = accessory.add_service(ServicesTypes.LIGHTBULB, name="TestDevice") + + on_char = service.add_char(CharacteristicsTypes.ON) + on_char.value = 0 + + brightness = service.add_char(CharacteristicsTypes.BRIGHTNESS) + brightness.value = 0 + + return service + + helper = await setup_test_component(hass, get_next_aid(), _create_accessory) + + with mock.patch.object( + helper.pairing, + "get_characteristics", + wraps=helper.pairing.get_characteristics, + ) as mock_get_characteristics: + # Initial state is that the light is off + await helper.poll_and_get_state() + # Verify only firmware version is polled + assert mock_get_characteristics.call_args_list[0][0][0] == {(1, 7)} + + # Now do a manual poll to ensure all chars are polled + mock_get_characteristics.reset_mock() + await async_update_entity(hass, helper.entity_id) + await time_changed(hass, 60) + await time_changed(hass, DEBOUNCE_COOLDOWN) + await hass.async_block_till_done() + assert len(mock_get_characteristics.call_args_list[0][0][0]) > 1