Wait for discovery to complete before starting apple_tv (#74133)

pull/73834/head
J. Nick Koston 2022-06-29 03:13:10 -05:00 committed by GitHub
parent 6a0ca2b36d
commit 99329ef04f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 63 additions and 42 deletions

View File

@ -23,6 +23,7 @@ from homeassistant.const import (
Platform,
)
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.dispatcher import (
@ -49,6 +50,13 @@ PLATFORMS = [Platform.MEDIA_PLAYER, Platform.REMOTE]
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up a config entry for Apple TV."""
manager = AppleTVManager(hass, entry)
if manager.is_on:
await manager.connect_once(raise_missing_credentials=True)
if not manager.atv:
address = entry.data[CONF_ADDRESS]
raise ConfigEntryNotReady(f"Not found at {address}, waiting for discovery")
hass.data.setdefault(DOMAIN, {})[entry.unique_id] = manager
async def on_hass_stop(event):
@ -148,14 +156,14 @@ class AppleTVManager:
self.config_entry = config_entry
self.hass = hass
self.atv = None
self._is_on = not config_entry.options.get(CONF_START_OFF, False)
self.is_on = not config_entry.options.get(CONF_START_OFF, False)
self._connection_attempts = 0
self._connection_was_lost = False
self._task = None
async def init(self):
"""Initialize power management."""
if self._is_on:
if self.is_on:
await self.connect()
def connection_lost(self, _):
@ -186,13 +194,13 @@ class AppleTVManager:
async def connect(self):
"""Connect to device."""
self._is_on = True
self.is_on = True
self._start_connect_loop()
async def disconnect(self):
"""Disconnect from device."""
_LOGGER.debug("Disconnecting from device")
self._is_on = False
self.is_on = False
try:
if self.atv:
self.atv.close()
@ -205,50 +213,53 @@ class AppleTVManager:
def _start_connect_loop(self):
"""Start background connect loop to device."""
if not self._task and self.atv is None and self._is_on:
if not self._task and self.atv is None and self.is_on:
self._task = asyncio.create_task(self._connect_loop())
else:
_LOGGER.debug(
"Not starting connect loop (%s, %s)", self.atv is None, self._is_on
"Not starting connect loop (%s, %s)", self.atv is None, self.is_on
)
async def connect_once(self, raise_missing_credentials):
"""Try to connect once."""
try:
if conf := await self._scan():
await self._connect(conf, raise_missing_credentials)
except exceptions.AuthenticationError:
self.config_entry.async_start_reauth(self.hass)
asyncio.create_task(self.disconnect())
_LOGGER.exception(
"Authentication failed for %s, try reconfiguring device",
self.config_entry.data[CONF_NAME],
)
return
except asyncio.CancelledError:
pass
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Failed to connect")
self.atv = None
async def _connect_loop(self):
"""Connect loop background task function."""
_LOGGER.debug("Starting connect loop")
# Try to find device and connect as long as the user has said that
# we are allowed to connect and we are not already connected.
while self._is_on and self.atv is None:
try:
conf = await self._scan()
if conf:
await self._connect(conf)
except exceptions.AuthenticationError:
self.config_entry.async_start_reauth(self.hass)
asyncio.create_task(self.disconnect())
_LOGGER.exception(
"Authentication failed for %s, try reconfiguring device",
self.config_entry.data[CONF_NAME],
)
while self.is_on and self.atv is None:
await self.connect_once(raise_missing_credentials=False)
if self.atv is not None:
break
except asyncio.CancelledError:
pass
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Failed to connect")
self.atv = None
self._connection_attempts += 1
backoff = min(
max(
BACKOFF_TIME_LOWER_LIMIT,
randrange(2**self._connection_attempts),
),
BACKOFF_TIME_UPPER_LIMIT,
)
if self.atv is None:
self._connection_attempts += 1
backoff = min(
max(
BACKOFF_TIME_LOWER_LIMIT,
randrange(2**self._connection_attempts),
),
BACKOFF_TIME_UPPER_LIMIT,
)
_LOGGER.debug("Reconnecting in %d seconds", backoff)
await asyncio.sleep(backoff)
_LOGGER.debug("Reconnecting in %d seconds", backoff)
await asyncio.sleep(backoff)
_LOGGER.debug("Connect loop ended")
self._task = None
@ -287,23 +298,33 @@ class AppleTVManager:
# it will update the address and reload the config entry when the device is found.
return None
async def _connect(self, conf):
async def _connect(self, conf, raise_missing_credentials):
"""Connect to device."""
credentials = self.config_entry.data[CONF_CREDENTIALS]
session = async_get_clientsession(self.hass)
name = self.config_entry.data[CONF_NAME]
missing_protocols = []
for protocol_int, creds in credentials.items():
protocol = Protocol(int(protocol_int))
if conf.get_service(protocol) is not None:
conf.set_credentials(protocol, creds)
else:
_LOGGER.warning(
"Protocol %s not found for %s, functionality will be reduced",
protocol.name,
self.config_entry.data[CONF_NAME],
missing_protocols.append(protocol.name)
if missing_protocols:
missing_protocols_str = ", ".join(missing_protocols)
if raise_missing_credentials:
raise ConfigEntryNotReady(
f"Protocol(s) {missing_protocols_str} not yet found for {name}, waiting for discovery."
)
_LOGGER.info(
"Protocol(s) %s not yet found for %s, trying later",
missing_protocols_str,
name,
)
return
_LOGGER.debug("Connecting to device %s", self.config_entry.data[CONF_NAME])
session = async_get_clientsession(self.hass)
self.atv = await connect(conf, self.hass.loop, session=session)
self.atv.listener = self