Wait for discovery to complete before starting apple_tv (#74133)
parent
6a0ca2b36d
commit
99329ef04f
|
@ -23,6 +23,7 @@ from homeassistant.const import (
|
|||
Platform,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import ConfigEntryNotReady
|
||||
from homeassistant.helpers import device_registry as dr
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
from homeassistant.helpers.dispatcher import (
|
||||
|
@ -49,6 +50,13 @@ PLATFORMS = [Platform.MEDIA_PLAYER, Platform.REMOTE]
|
|||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up a config entry for Apple TV."""
|
||||
manager = AppleTVManager(hass, entry)
|
||||
|
||||
if manager.is_on:
|
||||
await manager.connect_once(raise_missing_credentials=True)
|
||||
if not manager.atv:
|
||||
address = entry.data[CONF_ADDRESS]
|
||||
raise ConfigEntryNotReady(f"Not found at {address}, waiting for discovery")
|
||||
|
||||
hass.data.setdefault(DOMAIN, {})[entry.unique_id] = manager
|
||||
|
||||
async def on_hass_stop(event):
|
||||
|
@ -148,14 +156,14 @@ class AppleTVManager:
|
|||
self.config_entry = config_entry
|
||||
self.hass = hass
|
||||
self.atv = None
|
||||
self._is_on = not config_entry.options.get(CONF_START_OFF, False)
|
||||
self.is_on = not config_entry.options.get(CONF_START_OFF, False)
|
||||
self._connection_attempts = 0
|
||||
self._connection_was_lost = False
|
||||
self._task = None
|
||||
|
||||
async def init(self):
|
||||
"""Initialize power management."""
|
||||
if self._is_on:
|
||||
if self.is_on:
|
||||
await self.connect()
|
||||
|
||||
def connection_lost(self, _):
|
||||
|
@ -186,13 +194,13 @@ class AppleTVManager:
|
|||
|
||||
async def connect(self):
|
||||
"""Connect to device."""
|
||||
self._is_on = True
|
||||
self.is_on = True
|
||||
self._start_connect_loop()
|
||||
|
||||
async def disconnect(self):
|
||||
"""Disconnect from device."""
|
||||
_LOGGER.debug("Disconnecting from device")
|
||||
self._is_on = False
|
||||
self.is_on = False
|
||||
try:
|
||||
if self.atv:
|
||||
self.atv.close()
|
||||
|
@ -205,50 +213,53 @@ class AppleTVManager:
|
|||
|
||||
def _start_connect_loop(self):
|
||||
"""Start background connect loop to device."""
|
||||
if not self._task and self.atv is None and self._is_on:
|
||||
if not self._task and self.atv is None and self.is_on:
|
||||
self._task = asyncio.create_task(self._connect_loop())
|
||||
else:
|
||||
_LOGGER.debug(
|
||||
"Not starting connect loop (%s, %s)", self.atv is None, self._is_on
|
||||
"Not starting connect loop (%s, %s)", self.atv is None, self.is_on
|
||||
)
|
||||
|
||||
async def connect_once(self, raise_missing_credentials):
|
||||
"""Try to connect once."""
|
||||
try:
|
||||
if conf := await self._scan():
|
||||
await self._connect(conf, raise_missing_credentials)
|
||||
except exceptions.AuthenticationError:
|
||||
self.config_entry.async_start_reauth(self.hass)
|
||||
asyncio.create_task(self.disconnect())
|
||||
_LOGGER.exception(
|
||||
"Authentication failed for %s, try reconfiguring device",
|
||||
self.config_entry.data[CONF_NAME],
|
||||
)
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Failed to connect")
|
||||
self.atv = None
|
||||
|
||||
async def _connect_loop(self):
|
||||
"""Connect loop background task function."""
|
||||
_LOGGER.debug("Starting connect loop")
|
||||
|
||||
# Try to find device and connect as long as the user has said that
|
||||
# we are allowed to connect and we are not already connected.
|
||||
while self._is_on and self.atv is None:
|
||||
try:
|
||||
conf = await self._scan()
|
||||
if conf:
|
||||
await self._connect(conf)
|
||||
except exceptions.AuthenticationError:
|
||||
self.config_entry.async_start_reauth(self.hass)
|
||||
asyncio.create_task(self.disconnect())
|
||||
_LOGGER.exception(
|
||||
"Authentication failed for %s, try reconfiguring device",
|
||||
self.config_entry.data[CONF_NAME],
|
||||
)
|
||||
while self.is_on and self.atv is None:
|
||||
await self.connect_once(raise_missing_credentials=False)
|
||||
if self.atv is not None:
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Failed to connect")
|
||||
self.atv = None
|
||||
self._connection_attempts += 1
|
||||
backoff = min(
|
||||
max(
|
||||
BACKOFF_TIME_LOWER_LIMIT,
|
||||
randrange(2**self._connection_attempts),
|
||||
),
|
||||
BACKOFF_TIME_UPPER_LIMIT,
|
||||
)
|
||||
|
||||
if self.atv is None:
|
||||
self._connection_attempts += 1
|
||||
backoff = min(
|
||||
max(
|
||||
BACKOFF_TIME_LOWER_LIMIT,
|
||||
randrange(2**self._connection_attempts),
|
||||
),
|
||||
BACKOFF_TIME_UPPER_LIMIT,
|
||||
)
|
||||
|
||||
_LOGGER.debug("Reconnecting in %d seconds", backoff)
|
||||
await asyncio.sleep(backoff)
|
||||
_LOGGER.debug("Reconnecting in %d seconds", backoff)
|
||||
await asyncio.sleep(backoff)
|
||||
|
||||
_LOGGER.debug("Connect loop ended")
|
||||
self._task = None
|
||||
|
@ -287,23 +298,33 @@ class AppleTVManager:
|
|||
# it will update the address and reload the config entry when the device is found.
|
||||
return None
|
||||
|
||||
async def _connect(self, conf):
|
||||
async def _connect(self, conf, raise_missing_credentials):
|
||||
"""Connect to device."""
|
||||
credentials = self.config_entry.data[CONF_CREDENTIALS]
|
||||
session = async_get_clientsession(self.hass)
|
||||
|
||||
name = self.config_entry.data[CONF_NAME]
|
||||
missing_protocols = []
|
||||
for protocol_int, creds in credentials.items():
|
||||
protocol = Protocol(int(protocol_int))
|
||||
if conf.get_service(protocol) is not None:
|
||||
conf.set_credentials(protocol, creds)
|
||||
else:
|
||||
_LOGGER.warning(
|
||||
"Protocol %s not found for %s, functionality will be reduced",
|
||||
protocol.name,
|
||||
self.config_entry.data[CONF_NAME],
|
||||
missing_protocols.append(protocol.name)
|
||||
|
||||
if missing_protocols:
|
||||
missing_protocols_str = ", ".join(missing_protocols)
|
||||
if raise_missing_credentials:
|
||||
raise ConfigEntryNotReady(
|
||||
f"Protocol(s) {missing_protocols_str} not yet found for {name}, waiting for discovery."
|
||||
)
|
||||
_LOGGER.info(
|
||||
"Protocol(s) %s not yet found for %s, trying later",
|
||||
missing_protocols_str,
|
||||
name,
|
||||
)
|
||||
return
|
||||
|
||||
_LOGGER.debug("Connecting to device %s", self.config_entry.data[CONF_NAME])
|
||||
session = async_get_clientsession(self.hass)
|
||||
self.atv = await connect(conf, self.hass.loop, session=session)
|
||||
self.atv.listener = self
|
||||
|
||||
|
|
Loading…
Reference in New Issue