core/homeassistant/components/apple_tv/config_flow.py

400 lines
14 KiB
Python

"""Config flow for Apple TV integration."""
from ipaddress import ip_address
import logging
from random import randrange
from pyatv import exceptions, pair, scan
from pyatv.const import Protocol
from pyatv.convert import protocol_str
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.components import zeroconf
from homeassistant.const import CONF_ADDRESS, CONF_NAME, CONF_PIN, CONF_PROTOCOL
from homeassistant.core import callback
from homeassistant.data_entry_flow import AbortFlow, FlowResult
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from .const import CONF_CREDENTIALS, CONF_IDENTIFIER, CONF_START_OFF, DOMAIN
_LOGGER = logging.getLogger(__name__)
# Key of the form field in the "user" step that holds the device the user typed in.
DEVICE_INPUT = "device_input"
# Form schema for the "pair_with_pin" step: one required integer PIN field.
INPUT_PIN_SCHEMA = vol.Schema({vol.Required(CONF_PIN, default=None): int})
# Default offered for the CONF_START_OFF option when the entry has not set it.
DEFAULT_START_OFF = False
# Order in which protocols are considered when picking the next one to pair.
PROTOCOL_PRIORITY = [Protocol.MRP, Protocol.DMAP, Protocol.AirPlay]
async def device_scan(identifier, loop, cache=None):
    """Locate a device whose address, name or service identifier matches.

    A non-empty ``cache`` is consulted first; on a hit the cache itself is
    returned together with the first matching device and no scan is made.
    Otherwise up to two scans are performed: one restricted to ``identifier``
    interpreted as an IP address (unrestricted if it is not one), followed by
    an unrestricted scan.

    An ``identifier`` of ``None`` matches every device.

    Returns a ``(scan_result, device)`` tuple where ``device`` is ``None``
    when nothing matched.
    """

    def _matches(device):
        # Short-circuits in the same order as the individual checks:
        # wildcard, address, name, then any service identifier.
        return (
            identifier is None
            or identifier == str(device.address)
            or identifier == device.name
            or any(svc.identifier == identifier for svc in device.services)
        )

    def _select(devices):
        return list(filter(_matches, devices))

    def _unicast_hosts():
        # ip_address() raises ValueError for anything that is not an IP
        # address (including None); in that case no host restriction applies.
        try:
            parsed = ip_address(identifier)
        except ValueError:
            return None
        return [parsed]

    if cache:
        cached_hits = _select(cache)
        if cached_hits:
            return cache, cached_hits[0]

    # NOTE(review): when the identifier is not an IP address both attempts
    # use hosts=None, i.e. the same scan runs twice - presumably a retry.
    attempts = (_unicast_hosts(), None)
    for host_filter in attempts:
        scan_result = await scan(loop, timeout=3, hosts=host_filter)
        found = _select(scan_result)
        if found:
            return scan_result, found[0]

    return scan_result, None
def is_valid_credentials(credentials):
    """Return True when MRP or DMAP credentials are present (not None).

    ``credentials`` is a mapping keyed by protocol value; an entry that is
    missing or explicitly ``None`` does not count.
    """
    required_any_of = (Protocol.MRP, Protocol.DMAP)
    return any(
        credentials.get(protocol.value) is not None for protocol in required_any_of
    )
class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
    """Handle a config flow for Apple TV.

    The flow locates a device (entered by the user, discovered via zeroconf,
    or named by a re-authentication request), then pairs each protocol in
    PROTOCOL_PRIORITY that the device exposes and that has no credentials
    yet, and finally creates the config entry.
    """

    VERSION = 1

    @staticmethod
    @callback
    def async_get_options_flow(config_entry):
        """Get options flow for this handler."""
        return AppleTVOptionsFlow(config_entry)

    def __init__(self):
        """Initialize a new AppleTVConfigFlow."""
        # Identifier (address, name or service id) of the device to look for.
        self.target_device = None
        # Result of the latest scan; reused as a cache by device_scan().
        self.scan_result = None
        # Device object selected from the scan result.
        self.atv = None
        # Protocol currently being paired.
        self.protocol = None
        # Active pairing handler, if any.
        self.pairing = None
        # Protocol value -> credentials. A value of None marks a protocol the
        # user chose to skip after a service problem.
        self.credentials = {}

    async def async_step_reauth(self, info):
        """Handle initial step when updating invalid credentials."""
        await self.async_set_unique_id(info[CONF_IDENTIFIER])
        self.target_device = info[CONF_IDENTIFIER]
        self.context["title_placeholders"] = {"name": info[CONF_NAME]}
        self.context["identifier"] = self.unique_id
        return await self.async_step_reconfigure()

    async def async_step_reconfigure(self, user_input=None):
        """Inform user that reconfiguration is about to start."""
        if user_input is not None:
            # The device is expected to be configured already, so the
            # "already configured" check is skipped.
            return await self.async_find_device_wrapper(
                self.async_begin_pairing, allow_exist=True
            )

        return self.async_show_form(step_id="reconfigure")

    async def async_step_user(self, user_input=None):
        """Handle the initial step."""
        # Be helpful to the user and look for devices
        if self.scan_result is None:
            self.scan_result, _ = await device_scan(None, self.hass.loop)

        errors = {}
        default_suggestion = self._prefill_identifier()
        if user_input is not None:
            self.target_device = user_input[DEVICE_INPUT]
            try:
                await self.async_find_device()
            except DeviceNotFound:
                errors["base"] = "no_devices_found"
            except DeviceAlreadyConfigured:
                errors["base"] = "already_configured"
            except exceptions.NoServiceError:
                errors["base"] = "no_usable_service"
            except Exception:  # pylint: disable=broad-except
                _LOGGER.exception("Unexpected exception")
                errors["base"] = "unknown"
            else:
                await self.async_set_unique_id(
                    self.atv.identifier, raise_on_progress=False
                )
                return await self.async_step_confirm()

        # First visit, or the lookup failed: (re)show the form with errors.
        return self.async_show_form(
            step_id="user",
            data_schema=vol.Schema(
                {vol.Required(DEVICE_INPUT, default=default_suggestion): str}
            ),
            errors=errors,
            description_placeholders={"devices": self._devices_str()},
        )

    async def async_step_zeroconf(
        self, discovery_info: zeroconf.ZeroconfServiceInfo
    ) -> FlowResult:
        """Handle device found via zeroconf."""
        service_type = discovery_info.type
        properties = discovery_info.properties

        # The identifier and display name live in different places depending
        # on which service type announced the device.
        if service_type == "_mediaremotetv._tcp.local.":
            identifier = properties["UniqueIdentifier"]
            name = properties["Name"]
        elif service_type == "_touch-able._tcp.local.":
            identifier = discovery_info.name.split(".")[0]
            name = properties["CtlN"]
        else:
            return self.async_abort(reason="unknown")

        await self.async_set_unique_id(identifier)
        self._abort_if_unique_id_configured()

        self.context["identifier"] = self.unique_id
        self.context["title_placeholders"] = {"name": name}
        self.target_device = identifier
        return await self.async_find_device_wrapper(self.async_step_confirm)

    async def async_find_device_wrapper(self, next_func, allow_exist=False):
        """Find a specific device and call another function when done.

        This function will do error handling and bail out when an error
        occurs.
        """
        try:
            await self.async_find_device(allow_exist)
        except DeviceNotFound:
            return self.async_abort(reason="no_devices_found")
        except DeviceAlreadyConfigured:
            return self.async_abort(reason="already_configured")
        except Exception:  # pylint: disable=broad-except
            _LOGGER.exception("Unexpected exception")
            return self.async_abort(reason="unknown")

        return await next_func()

    async def async_find_device(self, allow_exist=False):
        """Scan for the selected device to discover services.

        Raises DeviceNotFound when no device matches ``self.target_device``
        and DeviceAlreadyConfigured when ``allow_exist`` is false and one of
        the device's identifiers already belongs to a config entry.
        """
        self.scan_result, self.atv = await device_scan(
            self.target_device, self.hass.loop, cache=self.scan_result
        )
        if not self.atv:
            raise DeviceNotFound()

        self.protocol = self.atv.main_service().protocol

        if not allow_exist:
            # Look up the configured ids once instead of once per identifier.
            configured_ids = self._async_current_ids()
            if any(
                identifier in configured_ids
                for identifier in self.atv.all_identifiers
            ):
                raise DeviceAlreadyConfigured()

        # If credentials were found, save them
        for service in self.atv.services:
            if service.credentials:
                self.credentials[service.protocol.value] = service.credentials

    async def async_step_confirm(self, user_input=None):
        """Handle user-confirmation of discovered node."""
        if user_input is not None:
            return await self.async_begin_pairing()

        return self.async_show_form(
            step_id="confirm", description_placeholders={"name": self.atv.name}
        )

    async def async_begin_pairing(self):
        """Start pairing process for the next available protocol."""
        self.protocol = self._next_protocol_to_pair()

        # Dispose previous pairing sessions
        if self.pairing is not None:
            await self.pairing.close()
            self.pairing = None

        # Any more protocols to pair? Else bail out here
        if not self.protocol:
            await self.async_set_unique_id(self.atv.main_service().identifier)
            return self._async_get_entry(
                self.atv.main_service().protocol,
                self.atv.name,
                self.credentials,
                self.atv.address,
            )

        # Initiate the pairing process
        abort_reason = None
        session = async_get_clientsession(self.hass)
        self.pairing = await pair(
            self.atv, self.protocol, self.hass.loop, session=session
        )
        try:
            await self.pairing.begin()
        except exceptions.ConnectionFailedError:
            # Let the user acknowledge that this protocol will be skipped.
            return await self.async_step_service_problem()
        except exceptions.BackOffError:
            abort_reason = "backoff"
        except exceptions.PairingError:
            _LOGGER.exception("Authentication problem")
            abort_reason = "invalid_auth"
        except Exception:  # pylint: disable=broad-except
            _LOGGER.exception("Unexpected exception")
            abort_reason = "unknown"

        if abort_reason:
            if self.pairing:
                await self.pairing.close()
            return self.async_abort(reason=abort_reason)

        # Choose step depending on if PIN is required from user or not
        if self.pairing.device_provides_pin:
            return await self.async_step_pair_with_pin()

        return await self.async_step_pair_no_pin()

    async def async_step_pair_with_pin(self, user_input=None):
        """Handle pairing step where a PIN is required from the user."""
        errors = {}
        if user_input is not None:
            try:
                self.pairing.pin(user_input[CONF_PIN])
                await self.pairing.finish()
                self.credentials[self.protocol.value] = self.pairing.service.credentials
                return await self.async_begin_pairing()
            except exceptions.PairingError:
                _LOGGER.exception("Authentication problem")
                errors["base"] = "invalid_auth"
            except AbortFlow:
                # Flow aborts must propagate rather than be reported as
                # "unknown" by the broad handler below.
                raise
            except Exception:  # pylint: disable=broad-except
                _LOGGER.exception("Unexpected exception")
                errors["base"] = "unknown"

        return self.async_show_form(
            step_id="pair_with_pin",
            data_schema=INPUT_PIN_SCHEMA,
            errors=errors,
            description_placeholders={"protocol": protocol_str(self.protocol)},
        )

    async def async_step_pair_no_pin(self, user_input=None):
        """Handle step where user has to enter a PIN on the device."""
        if user_input is not None:
            await self.pairing.finish()
            if self.pairing.has_paired:
                self.credentials[self.protocol.value] = self.pairing.service.credentials
                return await self.async_begin_pairing()

            await self.pairing.close()
            return self.async_abort(reason="device_did_not_pair")

        # Random four-digit PIN (1000-9999) shown to the user via the form.
        pin = randrange(1000, stop=10000)
        self.pairing.pin(pin)
        return self.async_show_form(
            step_id="pair_no_pin",
            description_placeholders={
                "protocol": protocol_str(self.protocol),
                "pin": pin,
            },
        )

    async def async_step_service_problem(self, user_input=None):
        """Inform user that a service will not be added."""
        if user_input is not None:
            # Record the protocol as handled (without credentials) so that it
            # is not selected for pairing again.
            self.credentials[self.protocol.value] = None
            return await self.async_begin_pairing()

        return self.async_show_form(
            step_id="service_problem",
            description_placeholders={"protocol": protocol_str(self.protocol)},
        )

    def _async_get_entry(self, protocol, name, credentials, address):
        """Create the config entry, or abort if credentials are unusable."""
        if not is_valid_credentials(credentials):
            return self.async_abort(reason="invalid_config")

        data = {
            CONF_PROTOCOL: protocol.value,
            CONF_NAME: name,
            CONF_CREDENTIALS: credentials,
            CONF_ADDRESS: str(address),
        }

        self._abort_if_unique_id_configured(reload_on_update=False, updates=data)

        return self.async_create_entry(title=name, data=data)

    def _next_protocol_to_pair(self):
        """Return the first protocol in priority order that needs pairing.

        A protocol needs pairing when the device exposes a service for it and
        no entry for it exists in ``self.credentials``. Returns None when
        nothing is left to pair.
        """

        def _needs_pairing(protocol):
            if self.atv.get_service(protocol) is None:
                return False
            return protocol.value not in self.credentials

        for protocol in PROTOCOL_PRIORITY:
            if _needs_pairing(protocol):
                return protocol
        return None

    def _devices_str(self):
        """Return a comma-separated list of scanned, not yet configured devices."""
        # Look up the configured ids once instead of once per device.
        configured_ids = self._async_current_ids()
        return ", ".join(
            [
                f"`{atv.name} ({atv.address})`"
                for atv in self.scan_result
                if atv.identifier not in configured_ids
            ]
        )

    def _prefill_identifier(self):
        """Return the address of one not yet configured device, or ""."""
        # Return identifier (address) of one device that has not been paired with
        configured_ids = self._async_current_ids()
        for atv in self.scan_result:
            if atv.identifier not in configured_ids:
                return str(atv.address)
        return ""
class AppleTVOptionsFlow(config_entries.OptionsFlow):
    """Options flow exposing the Apple TV "start off" setting."""

    def __init__(self, config_entry):
        """Remember the config entry and take a mutable copy of its options."""
        self.config_entry = config_entry
        self.options = dict(config_entry.options)

    async def async_step_init(self, user_input=None):
        """Show the options form, or store the submitted value."""
        if user_input is None:
            # Pre-select the stored value, falling back to the default.
            current_start_off = self.config_entry.options.get(
                CONF_START_OFF, DEFAULT_START_OFF
            )
            options_schema = vol.Schema(
                {
                    vol.Optional(CONF_START_OFF, default=current_start_off): bool,
                }
            )
            return self.async_show_form(step_id="init", data_schema=options_schema)

        self.options[CONF_START_OFF] = user_input[CONF_START_OFF]
        return self.async_create_entry(title="", data=self.options)
class DeviceNotFound(HomeAssistantError):
    """Raised when no device matching the requested identifier was found."""
class DeviceAlreadyConfigured(HomeAssistantError):
    """Raised when the located device already has a config entry."""