Add a lock to homekit_controller platform loads (#116539)

pull/116512/head
J. Nick Koston 2024-05-01 19:23:43 -05:00 committed by GitHub
parent 713ce0dd17
commit 657c9ec25b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 18 additions and 14 deletions

View File

@ -153,6 +153,7 @@ class HKDevice:
self._subscriptions: dict[tuple[int, int], set[CALLBACK_TYPE]] = {} self._subscriptions: dict[tuple[int, int], set[CALLBACK_TYPE]] = {}
self._pending_subscribes: set[tuple[int, int]] = set() self._pending_subscribes: set[tuple[int, int]] = set()
self._subscribe_timer: CALLBACK_TYPE | None = None self._subscribe_timer: CALLBACK_TYPE | None = None
self._load_platforms_lock = asyncio.Lock()
@property @property
def entity_map(self) -> Accessories: def entity_map(self) -> Accessories:
@ -327,7 +328,8 @@ class HKDevice:
) )
# BLE devices always get an RSSI sensor as well # BLE devices always get an RSSI sensor as well
if "sensor" not in self.platforms: if "sensor" not in self.platforms:
await self._async_load_platforms({"sensor"}) async with self._load_platforms_lock:
await self._async_load_platforms({"sensor"})
@callback @callback
def _async_start_polling(self) -> None: def _async_start_polling(self) -> None:
@ -804,6 +806,7 @@ class HKDevice:
async def _async_load_platforms(self, platforms: set[str]) -> None: async def _async_load_platforms(self, platforms: set[str]) -> None:
"""Load a group of platforms.""" """Load a group of platforms."""
assert self._load_platforms_lock.locked(), "Must be called with lock held"
if not (to_load := platforms - self.platforms): if not (to_load := platforms - self.platforms):
return return
self.platforms.update(to_load) self.platforms.update(to_load)
@ -813,22 +816,23 @@ class HKDevice:
async def async_load_platforms(self) -> None: async def async_load_platforms(self) -> None:
"""Load any platforms needed by this HomeKit device.""" """Load any platforms needed by this HomeKit device."""
to_load: set[str] = set() async with self._load_platforms_lock:
for accessory in self.entity_map.accessories: to_load: set[str] = set()
for service in accessory.services: for accessory in self.entity_map.accessories:
if service.type in HOMEKIT_ACCESSORY_DISPATCH: for service in accessory.services:
platform = HOMEKIT_ACCESSORY_DISPATCH[service.type] if service.type in HOMEKIT_ACCESSORY_DISPATCH:
if platform not in self.platforms: platform = HOMEKIT_ACCESSORY_DISPATCH[service.type]
to_load.add(platform)
for char in service.characteristics:
if char.type in CHARACTERISTIC_PLATFORMS:
platform = CHARACTERISTIC_PLATFORMS[char.type]
if platform not in self.platforms: if platform not in self.platforms:
to_load.add(platform) to_load.add(platform)
if to_load: for char in service.characteristics:
await self._async_load_platforms(to_load) if char.type in CHARACTERISTIC_PLATFORMS:
platform = CHARACTERISTIC_PLATFORMS[char.type]
if platform not in self.platforms:
to_load.add(platform)
if to_load:
await self._async_load_platforms(to_load)
@callback @callback
def async_update_available_state(self, *_: Any) -> None: def async_update_available_state(self, *_: Any) -> None: