"""Config flow for Apple TV integration.""" from __future__ import annotations import asyncio from collections import deque from ipaddress import ip_address import logging from random import randrange from pyatv import exceptions, pair, scan from pyatv.const import DeviceModel, PairingRequirement, Protocol from pyatv.convert import model_str, protocol_str from pyatv.helpers import get_unique_id import voluptuous as vol from homeassistant import config_entries, data_entry_flow from homeassistant.components import zeroconf from homeassistant.const import CONF_ADDRESS, CONF_NAME, CONF_PIN from homeassistant.core import callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.aiohttp_client import async_get_clientsession from .const import CONF_CREDENTIALS, CONF_IDENTIFIERS, CONF_START_OFF, DOMAIN _LOGGER = logging.getLogger(__name__) DEVICE_INPUT = "device_input" INPUT_PIN_SCHEMA = vol.Schema({vol.Required(CONF_PIN, default=None): int}) DEFAULT_START_OFF = False DISCOVERY_AGGREGATION_TIME = 15 # seconds async def device_scan(hass, identifier, loop): """Scan for a specific device using identifier as filter.""" def _filter_device(dev): if identifier is None: return True if identifier == str(dev.address): return True if identifier == dev.name: return True return any(service.identifier == identifier for service in dev.services) def _host_filter(): try: return [ip_address(identifier)] except ValueError: return None # If we have an address, only probe that address to avoid # broadcast traffic on the network aiozc = await zeroconf.async_get_async_instance(hass) scan_result = await scan(loop, timeout=3, hosts=_host_filter(), aiozc=aiozc) matches = [atv for atv in scan_result if _filter_device(atv)] if matches: return matches[0], matches[0].all_identifiers return None, None class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): """Handle a config flow for Apple TV.""" VERSION = 1 @staticmethod @callback def async_get_options_flow(config_entry): """Get options flow for this handler.""" return AppleTVOptionsFlow(config_entry) def __init__(self): """Initialize a new AppleTVConfigFlow.""" self.scan_filter = None self.atv = None self.atv_identifiers = None self.protocol = None self.pairing = None self.credentials = {} # Protocol -> credentials self.protocols_to_pair = deque() @property def device_identifier(self): """Return a identifier for the config entry. A device has multiple unique identifiers, but Home Assistant only supports one per config entry. Normally, a "main identifier" is determined by pyatv by first collecting all identifiers and then picking one in a pre-determine order. Under normal circumstances, this works fine but if a service is missing or removed due to deprecation (which happened with MRP), then another identifier will be calculated instead. To fix this, all identifiers belonging to a device is stored with the config entry and one of them (could be random) is used as unique_id for said entry. When a new (zeroconf) service or device is discovered, the identifier is first used to look up if it belongs to an existing config entry. If that's the case, the unique_id from that entry is re-used, otherwise the newly discovered identifier is used instead. """ all_identifiers = set(self.atv.all_identifiers) if unique_id := self._entry_unique_id_from_identifers(all_identifiers): return unique_id return self.atv.identifier @callback def _entry_unique_id_from_identifers(self, all_identifiers: set[str]) -> str | None: """Search existing entries for an identifier and return the unique id.""" for entry in self._async_current_entries(): if all_identifiers.intersection( entry.data.get(CONF_IDENTIFIERS, [entry.unique_id]) ): return entry.unique_id return None async def async_step_reauth(self, user_input=None): """Handle initial step when updating invalid credentials.""" self.context["title_placeholders"] = { "name": user_input[CONF_NAME], "type": "Apple TV", } self.scan_filter = self.unique_id self.context["identifier"] = self.unique_id return await self.async_step_reconfigure() async def async_step_reconfigure(self, user_input=None): """Inform user that reconfiguration is about to start.""" if user_input is not None: return await self.async_find_device_wrapper( self.async_pair_next_protocol, allow_exist=True ) return self.async_show_form(step_id="reconfigure") async def async_step_user(self, user_input=None): """Handle the initial step.""" errors = {} if user_input is not None: self.scan_filter = user_input[DEVICE_INPUT] try: await self.async_find_device() except DeviceNotFound: errors["base"] = "no_devices_found" except DeviceAlreadyConfigured: errors["base"] = "already_configured" except Exception: # pylint: disable=broad-except _LOGGER.exception("Unexpected exception") errors["base"] = "unknown" else: await self.async_set_unique_id( self.device_identifier, raise_on_progress=False ) self.context["all_identifiers"] = self.atv.all_identifiers return await self.async_step_confirm() return self.async_show_form( step_id="user", data_schema=vol.Schema({vol.Required(DEVICE_INPUT): str}), errors=errors, ) async def async_step_zeroconf( self, discovery_info: zeroconf.ZeroconfServiceInfo ) -> data_entry_flow.FlowResult: """Handle device found via zeroconf.""" host = discovery_info.host self._async_abort_entries_match({CONF_ADDRESS: host}) service_type = discovery_info.type[:-1] # Remove leading . name = discovery_info.name.replace(f".{service_type}.", "") properties = discovery_info.properties # Extract unique identifier from service unique_id = get_unique_id(service_type, name, properties) if unique_id is None: return self.async_abort(reason="unknown") if existing_unique_id := self._entry_unique_id_from_identifers({unique_id}): await self.async_set_unique_id(existing_unique_id) self._abort_if_unique_id_configured(updates={CONF_ADDRESS: host}) self._async_abort_entries_match({CONF_ADDRESS: host}) await self._async_aggregate_discoveries(host, unique_id) # Scan for the device in order to extract _all_ unique identifiers assigned to # it. Not doing it like this will yield multiple config flows for the same # device, one per protocol, which is undesired. self.scan_filter = host return await self.async_find_device_wrapper(self.async_found_zeroconf_device) async def _async_aggregate_discoveries(self, host: str, unique_id: str) -> None: """Wait for multiple zeroconf services to be discovered an aggregate them.""" # # Suppose we have a device with three services: A, B and C. Let's assume # service A is discovered by Zeroconf, triggering a device scan that also finds # service B but *not* C. An identifier is picked from one of the services and # used as unique_id. The select process is deterministic (let's say in order A, # B and C) but in practice that doesn't matter. So, a flow is set up for the # device with unique_id set to "A" for services A and B. # # Now, service C is found and the same thing happens again but only service B # is found. In this case, unique_id will be set to "B" which is problematic # since both flows really represent the same device. They will however end up # as two separate flows. # # To solve this, all identifiers are stored as # "all_identifiers" in the flow context. When a new service is discovered, the # code below will check these identifiers for all active flows and abort if a # match is found. Before aborting, the original flow is updated with any # potentially new identifiers. In the example above, when service C is # discovered, the identifier of service C will be inserted into # "all_identifiers" of the original flow (making the device complete). # # Wait DISCOVERY_AGGREGATION_TIME for multiple services to be # discovered via zeroconf. Once the first service is discovered # this allows other services to be discovered inside the time # window before triggering a scan of the device. This prevents # multiple scans of the device at the same time since each # apple_tv device has multiple services that are discovered by # zeroconf. # self._async_check_and_update_in_progress(host, unique_id) await asyncio.sleep(DISCOVERY_AGGREGATION_TIME) # Check again after sleeping in case another flow # has made progress while we yielded to the event loop self._async_check_and_update_in_progress(host, unique_id) # Host must only be set AFTER checking and updating in progress # flows or we will have a race condition where no flows move forward. self.context[CONF_ADDRESS] = host @callback def _async_check_and_update_in_progress(self, host: str, unique_id: str) -> None: """Check for in-progress flows and update them with identifiers if needed.""" for flow in self._async_in_progress(include_uninitialized=True): context = flow["context"] if ( context.get("source") != config_entries.SOURCE_ZEROCONF or context.get(CONF_ADDRESS) != host ): continue if ( "all_identifiers" in context and unique_id not in context["all_identifiers"] ): # Add potentially new identifiers from this device to the existing flow context["all_identifiers"].append(unique_id) raise data_entry_flow.AbortFlow("already_in_progress") async def async_found_zeroconf_device(self, user_input=None): """Handle device found after Zeroconf discovery.""" self.context["all_identifiers"] = self.atv.all_identifiers # Also abort if an integration with this identifier already exists await self.async_set_unique_id(self.device_identifier) # but be sure to update the address if its changed so the scanner # will probe the new address self._abort_if_unique_id_configured( updates={CONF_ADDRESS: str(self.atv.address)} ) self.context["identifier"] = self.unique_id return await self.async_step_confirm() async def async_find_device_wrapper(self, next_func, allow_exist=False): """Find a specific device and call another function when done. This function will do error handling and bail out when an error occurs. """ try: await self.async_find_device(allow_exist) except DeviceNotFound: return self.async_abort(reason="no_devices_found") except DeviceAlreadyConfigured: return self.async_abort(reason="already_configured") except Exception: # pylint: disable=broad-except _LOGGER.exception("Unexpected exception") return self.async_abort(reason="unknown") return await next_func() async def async_find_device(self, allow_exist=False): """Scan for the selected device to discover services.""" self.atv, self.atv_identifiers = await device_scan( self.hass, self.scan_filter, self.hass.loop ) if not self.atv: raise DeviceNotFound() # Protocols supported by the device are prospects for pairing self.protocols_to_pair = deque( service.protocol for service in self.atv.services if service.enabled ) dev_info = self.atv.device_info self.context["title_placeholders"] = { "name": self.atv.name, "type": ( dev_info.raw_model if dev_info.model == DeviceModel.Unknown and dev_info.raw_model else model_str(dev_info.model) ), } all_identifiers = set(self.atv.all_identifiers) discovered_ip_address = str(self.atv.address) for entry in self._async_current_entries(): if not all_identifiers.intersection( entry.data.get(CONF_IDENTIFIERS, [entry.unique_id]) ): continue if entry.data.get(CONF_ADDRESS) != discovered_ip_address: self.hass.config_entries.async_update_entry( entry, data={**entry.data, CONF_ADDRESS: discovered_ip_address}, ) self.hass.async_create_task( self.hass.config_entries.async_reload(entry.entry_id) ) if not allow_exist: raise DeviceAlreadyConfigured() async def async_step_confirm(self, user_input=None): """Handle user-confirmation of discovered node.""" if user_input is not None: expected_identifier_count = len(self.context["all_identifiers"]) # If number of services found during device scan mismatch number of # identifiers collected during Zeroconf discovery, then trigger a new scan # with hopes of finding all services. if len(self.atv.all_identifiers) != expected_identifier_count: try: await self.async_find_device() except DeviceNotFound: return self.async_abort(reason="device_not_found") # If all services still were not found, bail out with an error if len(self.atv.all_identifiers) != expected_identifier_count: return self.async_abort(reason="inconsistent_device") return await self.async_pair_next_protocol() return self.async_show_form( step_id="confirm", description_placeholders={ "name": self.atv.name, "type": model_str(self.atv.device_info.model), }, ) async def async_pair_next_protocol(self): """Start pairing process for the next available protocol.""" await self._async_cleanup() # Any more protocols to pair? Else bail out here if not self.protocols_to_pair: return await self._async_get_entry() self.protocol = self.protocols_to_pair.popleft() service = self.atv.get_service(self.protocol) # Service requires a password if service.requires_password: return await self.async_step_password() # Figure out, depending on protocol, what kind of pairing is needed if service.pairing == PairingRequirement.Unsupported: _LOGGER.debug("%s does not support pairing", self.protocol) return await self.async_pair_next_protocol() if service.pairing == PairingRequirement.Disabled: return await self.async_step_protocol_disabled() if service.pairing == PairingRequirement.NotNeeded: _LOGGER.debug("%s does not require pairing", self.protocol) self.credentials[self.protocol.value] = None return await self.async_pair_next_protocol() _LOGGER.debug("%s requires pairing", self.protocol) # Protocol specific arguments pair_args = {} if self.protocol == Protocol.DMAP: pair_args["name"] = "Home Assistant" pair_args["zeroconf"] = await zeroconf.async_get_instance(self.hass) # Initiate the pairing process abort_reason = None session = async_get_clientsession(self.hass) self.pairing = await pair( self.atv, self.protocol, self.hass.loop, session=session, **pair_args ) try: await self.pairing.begin() except exceptions.ConnectionFailedError: return await self.async_step_service_problem() except exceptions.BackOffError: abort_reason = "backoff" except exceptions.PairingError: _LOGGER.exception("Authentication problem") abort_reason = "invalid_auth" except Exception: # pylint: disable=broad-except _LOGGER.exception("Unexpected exception") abort_reason = "unknown" if abort_reason: await self._async_cleanup() return self.async_abort(reason=abort_reason) # Choose step depending on if PIN is required from user or not if self.pairing.device_provides_pin: return await self.async_step_pair_with_pin() return await self.async_step_pair_no_pin() async def async_step_protocol_disabled(self, user_input=None): """Inform user that a protocol is disabled and cannot be paired.""" if user_input is not None: return await self.async_pair_next_protocol() return self.async_show_form( step_id="protocol_disabled", description_placeholders={"protocol": protocol_str(self.protocol)}, ) async def async_step_pair_with_pin(self, user_input=None): """Handle pairing step where a PIN is required from the user.""" errors = {} if user_input is not None: try: self.pairing.pin(user_input[CONF_PIN]) await self.pairing.finish() self.credentials[self.protocol.value] = self.pairing.service.credentials return await self.async_pair_next_protocol() except exceptions.PairingError: _LOGGER.exception("Authentication problem") errors["base"] = "invalid_auth" except Exception: # pylint: disable=broad-except _LOGGER.exception("Unexpected exception") errors["base"] = "unknown" return self.async_show_form( step_id="pair_with_pin", data_schema=INPUT_PIN_SCHEMA, errors=errors, description_placeholders={"protocol": protocol_str(self.protocol)}, ) async def async_step_pair_no_pin(self, user_input=None): """Handle step where user has to enter a PIN on the device.""" if user_input is not None: await self.pairing.finish() if self.pairing.has_paired: self.credentials[self.protocol.value] = self.pairing.service.credentials return await self.async_pair_next_protocol() await self.pairing.close() return self.async_abort(reason="device_did_not_pair") pin = randrange(1000, stop=10000) self.pairing.pin(pin) return self.async_show_form( step_id="pair_no_pin", description_placeholders={ "protocol": protocol_str(self.protocol), "pin": pin, }, ) async def async_step_service_problem(self, user_input=None): """Inform user that a service will not be added.""" if user_input is not None: return await self.async_pair_next_protocol() return self.async_show_form( step_id="service_problem", description_placeholders={"protocol": protocol_str(self.protocol)}, ) async def async_step_password(self, user_input=None): """Inform user that password is not supported.""" if user_input is not None: return await self.async_pair_next_protocol() return self.async_show_form( step_id="password", description_placeholders={"protocol": protocol_str(self.protocol)}, ) async def _async_cleanup(self): """Clean up allocated resources.""" if self.pairing is not None: await self.pairing.close() self.pairing = None async def _async_get_entry(self): """Return config entry or update existing config entry.""" # Abort if no protocols were paired if not self.credentials: return self.async_abort(reason="setup_failed") data = { CONF_NAME: self.atv.name, CONF_CREDENTIALS: self.credentials, CONF_ADDRESS: str(self.atv.address), CONF_IDENTIFIERS: self.atv_identifiers, } existing_entry = await self.async_set_unique_id( self.device_identifier, raise_on_progress=False ) # If an existing config entry is updated, then this was a re-auth if existing_entry: self.hass.config_entries.async_update_entry( existing_entry, data=data, unique_id=self.unique_id ) self.hass.async_create_task( self.hass.config_entries.async_reload(existing_entry.entry_id) ) return self.async_abort(reason="reauth_successful") return self.async_create_entry(title=self.atv.name, data=data) class AppleTVOptionsFlow(config_entries.OptionsFlow): """Handle Apple TV options.""" def __init__(self, config_entry): """Initialize Apple TV options flow.""" self.config_entry = config_entry self.options = dict(config_entry.options) async def async_step_init(self, user_input=None): """Manage the Apple TV options.""" if user_input is not None: self.options[CONF_START_OFF] = user_input[CONF_START_OFF] return self.async_create_entry(title="", data=self.options) return self.async_show_form( step_id="init", data_schema=vol.Schema( { vol.Optional( CONF_START_OFF, default=self.config_entry.options.get( CONF_START_OFF, DEFAULT_START_OFF ), ): bool, } ), ) class DeviceNotFound(HomeAssistantError): """Error to indicate device could not be found.""" class DeviceAlreadyConfigured(HomeAssistantError): """Error to indicate device is already configured."""