Fix race in Alexa async_enable_proactive_mode (#92785)

pull/92869/head
Erik Montnemery 2023-05-09 19:58:00 +02:00 committed by GitHub
parent 67c1051305
commit 7d29d584fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 10 deletions

View File

@ -17,11 +17,12 @@ _LOGGER = logging.getLogger(__name__)
class AbstractConfig(ABC):
"""Hold the configuration for Alexa."""
_unsub_proactive_report: asyncio.Task[CALLBACK_TYPE] | None = None
_unsub_proactive_report: CALLBACK_TYPE | None = None
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize abstract config."""
self.hass = hass
self._enable_proactive_mode_lock = asyncio.Lock()
self._store = None
async def async_initialize(self):
@ -67,20 +68,17 @@ class AbstractConfig(ABC):
async def async_enable_proactive_mode(self):
"""Enable proactive mode."""
_LOGGER.debug("Enable proactive mode")
if self._unsub_proactive_report is None:
self._unsub_proactive_report = self.hass.async_create_task(
async_enable_proactive_mode(self.hass, self)
async with self._enable_proactive_mode_lock:
if self._unsub_proactive_report is not None:
return
self._unsub_proactive_report = await async_enable_proactive_mode(
self.hass, self
)
try:
await self._unsub_proactive_report
except Exception:
self._unsub_proactive_report = None
raise
async def async_disable_proactive_mode(self):
"""Disable proactive mode."""
_LOGGER.debug("Disable proactive mode")
if unsub_func := await self._unsub_proactive_report:
if unsub_func := self._unsub_proactive_report:
unsub_func()
self._unsub_proactive_report = None

View File

@ -0,0 +1,21 @@
"""Test config."""
import asyncio
from unittest.mock import patch
from homeassistant.core import HomeAssistant
from .test_common import get_default_config
async def test_enable_proactive_mode_in_parallel(hass: HomeAssistant) -> None:
"""Test enabling proactive mode does not happen in parallel."""
config = get_default_config(hass)
with patch(
"homeassistant.components.alexa.config.async_enable_proactive_mode"
) as mock_enable_proactive_mode:
await asyncio.gather(
config.async_enable_proactive_mode(), config.async_enable_proactive_mode()
)
mock_enable_proactive_mode.assert_awaited_once()