Refactor Apple TV integration (#31952)

Co-authored-by: Franck Nijhof <git@frenck.dev>
pull/43860/head
Pierre Ståhl 2020-12-02 17:01:55 +01:00 committed by GitHub
parent 58648019c6
commit edb246d696
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 1758 additions and 400 deletions

View File

@ -48,7 +48,9 @@ omit =
homeassistant/components/anel_pwrctrl/switch.py
homeassistant/components/anthemav/media_player.py
homeassistant/components/apcupsd/*
homeassistant/components/apple_tv/*
homeassistant/components/apple_tv/__init__.py
homeassistant/components/apple_tv/media_player.py
homeassistant/components/apple_tv/remote.py
homeassistant/components/aqualogic/*
homeassistant/components/aquostv/media_player.py
homeassistant/components/arcam_fmj/media_player.py

View File

@ -37,6 +37,7 @@ homeassistant/components/amcrest/* @pnbruckner
homeassistant/components/androidtv/* @JeffLIrion
homeassistant/components/apache_kafka/* @bachya
homeassistant/components/api/* @home-assistant/core
homeassistant/components/apple_tv/* @postlund
homeassistant/components/apprise/* @caronc
homeassistant/components/aprs/* @PhilRW
homeassistant/components/arcam_fmj/* @elupus

View File

@ -1,273 +1,363 @@
"""Support for Apple TV."""
"""The Apple TV integration."""
import asyncio
import logging
from typing import Sequence, TypeVar, Union
from random import randrange
from pyatv import AppleTVDevice, connect_to_apple_tv, scan_for_apple_tvs
from pyatv.exceptions import DeviceAuthenticationError
import voluptuous as vol
from pyatv import connect, exceptions, scan
from pyatv.const import Protocol
from homeassistant.components.discovery import SERVICE_APPLE_TV
from homeassistant.const import ATTR_ENTITY_ID, CONF_HOST, CONF_NAME
from homeassistant.helpers import discovery
from homeassistant.components.media_player import DOMAIN as MP_DOMAIN
from homeassistant.components.remote import DOMAIN as REMOTE_DOMAIN
from homeassistant.const import (
CONF_ADDRESS,
CONF_NAME,
CONF_PROTOCOL,
EVENT_HOMEASSISTANT_STOP,
)
from homeassistant.core import callback
from homeassistant.helpers.aiohttp_client import async_get_clientsession
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
async_dispatcher_send,
)
from homeassistant.helpers.entity import Entity
from .const import CONF_CREDENTIALS, CONF_IDENTIFIER, CONF_START_OFF, DOMAIN
_LOGGER = logging.getLogger(__name__)
DOMAIN = "apple_tv"
SERVICE_SCAN = "apple_tv_scan"
SERVICE_AUTHENTICATE = "apple_tv_authenticate"
ATTR_ATV = "atv"
ATTR_POWER = "power"
CONF_LOGIN_ID = "login_id"
CONF_START_OFF = "start_off"
CONF_CREDENTIALS = "credentials"
DEFAULT_NAME = "Apple TV"
DATA_APPLE_TV = "data_apple_tv"
DATA_ENTITIES = "data_apple_tv_entities"
BACKOFF_TIME_UPPER_LIMIT = 300 # Five minutes
KEY_CONFIG = "apple_tv_configuring"
NOTIFICATION_TITLE = "Apple TV Notification"
NOTIFICATION_ID = "apple_tv_notification"
NOTIFICATION_AUTH_ID = "apple_tv_auth_notification"
NOTIFICATION_AUTH_TITLE = "Apple TV Authentication"
NOTIFICATION_SCAN_ID = "apple_tv_scan_notification"
NOTIFICATION_SCAN_TITLE = "Apple TV Scan"
SOURCE_REAUTH = "reauth"
T = TypeVar("T")
SIGNAL_CONNECTED = "apple_tv_connected"
SIGNAL_DISCONNECTED = "apple_tv_disconnected"
# This version of ensure_list interprets an empty dict as no value
def ensure_list(value: Union[T, Sequence[T]]) -> Sequence[T]:
"""Wrap value in list if it is not one."""
if value is None or (isinstance(value, dict) and not value):
return []
return value if isinstance(value, list) else [value]
CONFIG_SCHEMA = vol.Schema(
{
DOMAIN: vol.All(
ensure_list,
[
vol.Schema(
{
vol.Required(CONF_HOST): cv.string,
vol.Required(CONF_LOGIN_ID): cv.string,
vol.Optional(CONF_CREDENTIALS): cv.string,
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
vol.Optional(CONF_START_OFF, default=False): cv.boolean,
}
)
],
)
},
extra=vol.ALLOW_EXTRA,
)
# Currently no attributes but it might change later
APPLE_TV_SCAN_SCHEMA = vol.Schema({})
APPLE_TV_AUTHENTICATE_SCHEMA = vol.Schema({ATTR_ENTITY_ID: cv.entity_ids})
def request_configuration(hass, config, atv, credentials):
"""Request configuration steps from the user."""
configurator = hass.components.configurator
async def configuration_callback(callback_data):
"""Handle the submitted configuration."""
pin = callback_data.get("pin")
try:
await atv.airplay.finish_authentication(pin)
hass.components.persistent_notification.async_create(
f"Authentication succeeded!<br /><br />"
f"Add the following to credentials: "
f"in your apple_tv configuration:<br /><br />{credentials}",
title=NOTIFICATION_AUTH_TITLE,
notification_id=NOTIFICATION_AUTH_ID,
)
except DeviceAuthenticationError as ex:
hass.components.persistent_notification.async_create(
f"Authentication failed! Did you enter correct PIN?<br /><br />Details: {ex}",
title=NOTIFICATION_AUTH_TITLE,
notification_id=NOTIFICATION_AUTH_ID,
)
hass.async_add_job(configurator.request_done, instance)
instance = configurator.request_config(
"Apple TV Authentication",
configuration_callback,
description="Please enter PIN code shown on screen.",
submit_caption="Confirm",
fields=[{"id": "pin", "name": "PIN Code", "type": "password"}],
)
async def scan_apple_tvs(hass):
"""Scan for devices and present a notification of the ones found."""
atvs = await scan_for_apple_tvs(hass.loop, timeout=3)
devices = []
for atv in atvs:
login_id = atv.login_id
if login_id is None:
login_id = "Home Sharing disabled"
devices.append(
f"Name: {atv.name}<br />Host: {atv.address}<br />Login ID: {login_id}"
)
if not devices:
devices = ["No device(s) found"]
found_devices = "<br /><br />".join(devices)
hass.components.persistent_notification.async_create(
f"The following devices were found:<br /><br />{found_devices}",
title=NOTIFICATION_SCAN_TITLE,
notification_id=NOTIFICATION_SCAN_ID,
)
PLATFORMS = [MP_DOMAIN, REMOTE_DOMAIN]
async def async_setup(hass, config):
"""Set up the Apple TV component."""
if DATA_APPLE_TV not in hass.data:
hass.data[DATA_APPLE_TV] = {}
"""Set up the Apple TV integration."""
return True
async def async_service_handler(service):
"""Handle service calls."""
entity_ids = service.data.get(ATTR_ENTITY_ID)
if service.service == SERVICE_SCAN:
hass.async_add_job(scan_apple_tvs, hass)
return
async def async_setup_entry(hass, entry):
"""Set up a config entry for Apple TV."""
manager = AppleTVManager(hass, entry)
hass.data.setdefault(DOMAIN, {})[entry.unique_id] = manager
if entity_ids:
devices = [
device
for device in hass.data[DATA_ENTITIES]
if device.entity_id in entity_ids
async def on_hass_stop(event):
"""Stop push updates when hass stops."""
await manager.disconnect()
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, on_hass_stop)
async def setup_platforms():
"""Set up platforms and initiate connection."""
await asyncio.gather(
*[
hass.config_entries.async_forward_entry_setup(entry, component)
for component in PLATFORMS
]
else:
devices = hass.data[DATA_ENTITIES]
for device in devices:
if service.service != SERVICE_AUTHENTICATE:
continue
atv = device.atv
credentials = await atv.airplay.generate_credentials()
await atv.airplay.load_credentials(credentials)
_LOGGER.debug("Generated new credentials: %s", credentials)
await atv.airplay.start_authentication()
hass.async_add_job(request_configuration, hass, config, atv, credentials)
async def atv_discovered(service, info):
"""Set up an Apple TV that was auto discovered."""
await _setup_atv(
hass,
config,
{
CONF_NAME: info["name"],
CONF_HOST: info["host"],
CONF_LOGIN_ID: info["properties"]["hG"],
CONF_START_OFF: False,
},
)
await manager.init()
discovery.async_listen(hass, SERVICE_APPLE_TV, atv_discovered)
tasks = [_setup_atv(hass, config, conf) for conf in config.get(DOMAIN, [])]
if tasks:
await asyncio.wait(tasks)
hass.services.async_register(
DOMAIN, SERVICE_SCAN, async_service_handler, schema=APPLE_TV_SCAN_SCHEMA
)
hass.services.async_register(
DOMAIN,
SERVICE_AUTHENTICATE,
async_service_handler,
schema=APPLE_TV_AUTHENTICATE_SCHEMA,
)
hass.async_create_task(setup_platforms())
return True
async def _setup_atv(hass, hass_config, atv_config):
"""Set up an Apple TV."""
name = atv_config.get(CONF_NAME)
host = atv_config.get(CONF_HOST)
login_id = atv_config.get(CONF_LOGIN_ID)
start_off = atv_config.get(CONF_START_OFF)
credentials = atv_config.get(CONF_CREDENTIALS)
if host in hass.data[DATA_APPLE_TV]:
return
details = AppleTVDevice(name, host, login_id)
session = async_get_clientsession(hass)
atv = connect_to_apple_tv(details, hass.loop, session=session)
if credentials:
await atv.airplay.load_credentials(credentials)
power = AppleTVPowerManager(hass, atv, start_off)
hass.data[DATA_APPLE_TV][host] = {ATTR_ATV: atv, ATTR_POWER: power}
hass.async_create_task(
discovery.async_load_platform(
hass, "media_player", DOMAIN, atv_config, hass_config
async def async_unload_entry(hass, entry):
"""Unload an Apple TV config entry."""
unload_ok = all(
await asyncio.gather(
*[
hass.config_entries.async_forward_entry_unload(entry, platform)
for platform in PLATFORMS
]
)
)
if unload_ok:
manager = hass.data[DOMAIN].pop(entry.unique_id)
await manager.disconnect()
hass.async_create_task(
discovery.async_load_platform(hass, "remote", DOMAIN, atv_config, hass_config)
)
return unload_ok
class AppleTVPowerManager:
"""Manager for global power management of an Apple TV.
class AppleTVEntity(Entity):
"""Device that sends commands to an Apple TV."""
An instance is used per device to share the same power state between
several platforms.
"""
def __init__(self, name, identifier, manager):
"""Initialize device."""
self.atv = None
self.manager = manager
self._name = name
self._identifier = identifier
def __init__(self, hass, atv, is_off):
"""Initialize power manager."""
self.hass = hass
self.atv = atv
self.listeners = []
self._is_on = not is_off
async def async_added_to_hass(self):
"""Handle when an entity is about to be added to Home Assistant."""
def init(self):
"""Initialize power management."""
if self._is_on:
self.atv.push_updater.start()
@callback
def _async_connected(atv):
"""Handle that a connection was made to a device."""
self.atv = atv
self.async_device_connected(atv)
self.async_write_ha_state()
@callback
def _async_disconnected():
"""Handle that a connection to a device was lost."""
self.async_device_disconnected()
self.atv = None
self.async_write_ha_state()
self.async_on_remove(
async_dispatcher_connect(
self.hass, f"{SIGNAL_CONNECTED}_{self._identifier}", _async_connected
)
)
self.async_on_remove(
async_dispatcher_connect(
self.hass,
f"{SIGNAL_DISCONNECTED}_{self._identifier}",
_async_disconnected,
)
)
def async_device_connected(self, atv):
"""Handle when connection is made to device."""
def async_device_disconnected(self):
"""Handle when connection was lost to device."""
@property
def turned_on(self):
"""Return true if device is on or off."""
return self._is_on
def device_info(self):
"""Return the device info."""
return {
"identifiers": {(DOMAIN, self._identifier)},
"manufacturer": "Apple",
"name": self.name,
}
def set_power_on(self, value):
"""Change if a device is on or off."""
if value != self._is_on:
self._is_on = value
if not self._is_on:
@property
def name(self):
"""Return the name of the device."""
return self._name
@property
def unique_id(self):
"""Return a unique ID."""
return self._identifier
@property
def should_poll(self):
"""No polling needed for Apple TV."""
return False
class AppleTVManager:
"""Connection and power manager for an Apple TV.
An instance is used per device to share the same power state between
several platforms. It also manages scanning and connection establishment
in case of problems.
"""
def __init__(self, hass, config_entry):
"""Initialize power manager."""
self.config_entry = config_entry
self.hass = hass
self.atv = None
self._is_on = not config_entry.options.get(CONF_START_OFF, False)
self._connection_attempts = 0
self._connection_was_lost = False
self._task = None
async def init(self):
"""Initialize power management."""
if self._is_on:
await self.connect()
def connection_lost(self, _):
"""Device was unexpectedly disconnected.
This is a callback function from pyatv.interface.DeviceListener.
"""
_LOGGER.warning('Connection lost to Apple TV "%s"', self.atv.name)
if self.atv:
self.atv.close()
self.atv = None
self._connection_was_lost = True
self._dispatch_send(SIGNAL_DISCONNECTED)
self._start_connect_loop()
def connection_closed(self):
"""Device connection was (intentionally) closed.
This is a callback function from pyatv.interface.DeviceListener.
"""
if self.atv:
self.atv.close()
self.atv = None
self._dispatch_send(SIGNAL_DISCONNECTED)
self._start_connect_loop()
async def connect(self):
"""Connect to device."""
self._is_on = True
self._start_connect_loop()
async def disconnect(self):
"""Disconnect from device."""
_LOGGER.debug("Disconnecting from device")
self._is_on = False
try:
if self.atv:
self.atv.push_updater.listener = None
self.atv.push_updater.stop()
else:
self.atv.push_updater.start()
self.atv.close()
self.atv = None
if self._task:
self._task.cancel()
self._task = None
except Exception: # pylint: disable=broad-except
_LOGGER.exception("An error occurred while disconnecting")
for listener in self.listeners:
self.hass.async_create_task(listener.async_update_ha_state())
def _start_connect_loop(self):
"""Start background connect loop to device."""
if not self._task and self.atv is None and self._is_on:
self._task = asyncio.create_task(self._connect_loop())
else:
_LOGGER.debug(
"Not starting connect loop (%s, %s)", self.atv is None, self._is_on
)
async def _connect_loop(self):
"""Connect loop background task function."""
_LOGGER.debug("Starting connect loop")
# Try to find device and connect as long as the user has said that
# we are allowed to connect and we are not already connected.
while self._is_on and self.atv is None:
try:
conf = await self._scan()
if conf:
await self._connect(conf)
except exceptions.AuthenticationError:
self._auth_problem()
break
except asyncio.CancelledError:
pass
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Failed to connect")
self.atv = None
if self.atv is None:
self._connection_attempts += 1
backoff = min(
randrange(2 ** self._connection_attempts), BACKOFF_TIME_UPPER_LIMIT
)
_LOGGER.debug("Reconnecting in %d seconds", backoff)
await asyncio.sleep(backoff)
_LOGGER.debug("Connect loop ended")
self._task = None
def _auth_problem(self):
"""Problem to authenticate occurred that needs intervention."""
_LOGGER.debug("Authentication error, reconfigure integration")
name = self.config_entry.data.get(CONF_NAME)
identifier = self.config_entry.unique_id
self.hass.components.persistent_notification.create(
"An irrecoverable connection problem occurred when connecting to "
f"`f{name}`. Please go to the Integrations page and reconfigure it",
title=NOTIFICATION_TITLE,
notification_id=NOTIFICATION_ID,
)
# Add to event queue as this function is called from a task being
# cancelled from disconnect
asyncio.create_task(self.disconnect())
self.hass.async_create_task(
self.hass.config_entries.flow.async_init(
DOMAIN,
context={"source": SOURCE_REAUTH},
data={CONF_NAME: name, CONF_IDENTIFIER: identifier},
)
)
async def _scan(self):
"""Try to find device by scanning for it."""
identifier = self.config_entry.unique_id
address = self.config_entry.data[CONF_ADDRESS]
protocol = Protocol(self.config_entry.data[CONF_PROTOCOL])
_LOGGER.debug("Discovering device %s", identifier)
atvs = await scan(
self.hass.loop, identifier=identifier, protocol=protocol, hosts=[address]
)
if atvs:
return atvs[0]
_LOGGER.debug(
"Failed to find device %s with address %s, trying to scan",
identifier,
address,
)
atvs = await scan(self.hass.loop, identifier=identifier, protocol=protocol)
if atvs:
return atvs[0]
_LOGGER.debug("Failed to find device %s, trying later", identifier)
return None
async def _connect(self, conf):
"""Connect to device."""
credentials = self.config_entry.data[CONF_CREDENTIALS]
session = async_get_clientsession(self.hass)
for protocol, creds in credentials.items():
conf.set_credentials(Protocol(int(protocol)), creds)
_LOGGER.debug("Connecting to device %s", self.config_entry.data[CONF_NAME])
self.atv = await connect(conf, self.hass.loop, session=session)
self.atv.listener = self
self._dispatch_send(SIGNAL_CONNECTED, self.atv)
self._address_updated(str(conf.address))
self._connection_attempts = 0
if self._connection_was_lost:
_LOGGER.info(
'Connection was re-established to Apple TV "%s"', self.atv.service.name
)
self._connection_was_lost = False
@property
def is_connecting(self):
"""Return true if connection is in progress."""
return self._task is not None
def _address_updated(self, address):
"""Update cached address in config entry."""
_LOGGER.debug("Changing address to %s", address)
self.hass.config_entries.async_update_entry(
self.config_entry, data={**self.config_entry.data, CONF_ADDRESS: address}
)
def _dispatch_send(self, signal, *args):
"""Dispatch a signal to all entities managed by this manager."""
async_dispatcher_send(
self.hass, f"{signal}_{self.config_entry.unique_id}", *args
)

View File

@ -0,0 +1,408 @@
"""Config flow for Apple TV integration."""
from ipaddress import ip_address
import logging
from random import randrange
from pyatv import exceptions, pair, scan
from pyatv.const import Protocol
from pyatv.convert import protocol_str
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.const import (
CONF_ADDRESS,
CONF_NAME,
CONF_PIN,
CONF_PROTOCOL,
CONF_TYPE,
)
from homeassistant.core import callback
from homeassistant.data_entry_flow import AbortFlow
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from .const import CONF_CREDENTIALS, CONF_IDENTIFIER, CONF_START_OFF
from .const import DOMAIN # pylint: disable=unused-import
_LOGGER = logging.getLogger(__name__)
DEVICE_INPUT = "device_input"
INPUT_PIN_SCHEMA = vol.Schema({vol.Required(CONF_PIN, default=None): int})
DEFAULT_START_OFF = False
PROTOCOL_PRIORITY = [Protocol.MRP, Protocol.DMAP, Protocol.AirPlay]
async def device_scan(identifier, loop, cache=None):
"""Scan for a specific device using identifier as filter."""
def _filter_device(dev):
if identifier is None:
return True
if identifier == str(dev.address):
return True
if identifier == dev.name:
return True
return any([service.identifier == identifier for service in dev.services])
def _host_filter():
try:
return [ip_address(identifier)]
except ValueError:
return None
if cache:
matches = [atv for atv in cache if _filter_device(atv)]
if matches:
return cache, matches[0]
for hosts in [_host_filter(), None]:
scan_result = await scan(loop, timeout=3, hosts=hosts)
matches = [atv for atv in scan_result if _filter_device(atv)]
if matches:
return scan_result, matches[0]
return scan_result, None
def is_valid_credentials(credentials):
"""Verify that credentials are valid for establishing a connection."""
return (
credentials.get(Protocol.MRP.value) is not None
or credentials.get(Protocol.DMAP.value) is not None
)
class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Apple TV."""
VERSION = 1
CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_PUSH
@staticmethod
@callback
def async_get_options_flow(config_entry):
"""Get options flow for this handler."""
return AppleTVOptionsFlow(config_entry)
def __init__(self):
"""Initialize a new AppleTVConfigFlow."""
self.target_device = None
self.scan_result = None
self.atv = None
self.protocol = None
self.pairing = None
self.credentials = {} # Protocol -> credentials
async def async_step_reauth(self, info):
"""Handle initial step when updating invalid credentials."""
await self.async_set_unique_id(info[CONF_IDENTIFIER])
self.target_device = info[CONF_IDENTIFIER]
# pylint: disable=no-member # https://github.com/PyCQA/pylint/issues/3167
self.context["title_placeholders"] = {"name": info[CONF_NAME]}
# pylint: disable=no-member # https://github.com/PyCQA/pylint/issues/3167
self.context["identifier"] = self.unique_id
return await self.async_step_reconfigure()
async def async_step_reconfigure(self, user_input=None):
"""Inform user that reconfiguration is about to start."""
if user_input is not None:
return await self.async_find_device_wrapper(
self.async_begin_pairing, allow_exist=True
)
return self.async_show_form(step_id="reconfigure")
async def async_step_user(self, user_input=None):
"""Handle the initial step."""
# Be helpful to the user and look for devices
if self.scan_result is None:
self.scan_result, _ = await device_scan(None, self.hass.loop)
errors = {}
default_suggestion = self._prefill_identifier()
if user_input is not None:
self.target_device = user_input[DEVICE_INPUT]
try:
await self.async_find_device()
except DeviceNotFound:
errors["base"] = "no_devices_found"
except DeviceAlreadyConfigured:
errors["base"] = "already_configured"
except exceptions.NoServiceError:
errors["base"] = "no_usable_service"
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception")
errors["base"] = "unknown"
else:
await self.async_set_unique_id(
self.atv.identifier, raise_on_progress=False
)
return await self.async_step_confirm()
return self.async_show_form(
step_id="user",
data_schema=vol.Schema(
{vol.Required(DEVICE_INPUT, default=default_suggestion): str}
),
errors=errors,
description_placeholders={"devices": self._devices_str()},
)
async def async_step_zeroconf(self, discovery_info):
"""Handle device found via zeroconf."""
service_type = discovery_info[CONF_TYPE]
properties = discovery_info["properties"]
if service_type == "_mediaremotetv._tcp.local.":
identifier = properties["UniqueIdentifier"]
name = properties["Name"]
elif service_type == "_touch-able._tcp.local.":
identifier = discovery_info["name"].split(".")[0]
name = properties["CtlN"]
else:
return self.async_abort(reason="unknown")
await self.async_set_unique_id(identifier)
self._abort_if_unique_id_configured()
# pylint: disable=no-member # https://github.com/PyCQA/pylint/issues/3167
self.context["identifier"] = self.unique_id
self.context["title_placeholders"] = {"name": name}
self.target_device = identifier
return await self.async_find_device_wrapper(self.async_step_confirm)
async def async_find_device_wrapper(self, next_func, allow_exist=False):
"""Find a specific device and call another function when done.
This function will do error handling and bail out when an error
occurs.
"""
try:
await self.async_find_device(allow_exist)
except DeviceNotFound:
return self.async_abort(reason="no_devices_found")
except DeviceAlreadyConfigured:
return self.async_abort(reason="already_configured")
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception")
return self.async_abort(reason="unknown")
return await next_func()
async def async_find_device(self, allow_exist=False):
"""Scan for the selected device to discover services."""
self.scan_result, self.atv = await device_scan(
self.target_device, self.hass.loop, cache=self.scan_result
)
if not self.atv:
raise DeviceNotFound()
self.protocol = self.atv.main_service().protocol
if not allow_exist:
for identifier in self.atv.all_identifiers:
if identifier in self._async_current_ids():
raise DeviceAlreadyConfigured()
# If credentials were found, save them
for service in self.atv.services:
if service.credentials:
self.credentials[service.protocol.value] = service.credentials
async def async_step_confirm(self, user_input=None):
"""Handle user-confirmation of discovered node."""
if user_input is not None:
return await self.async_begin_pairing()
return self.async_show_form(
step_id="confirm", description_placeholders={"name": self.atv.name}
)
async def async_begin_pairing(self):
"""Start pairing process for the next available protocol."""
self.protocol = self._next_protocol_to_pair()
# Dispose previous pairing sessions
if self.pairing is not None:
await self.pairing.close()
self.pairing = None
# Any more protocols to pair? Else bail out here
if not self.protocol:
await self.async_set_unique_id(self.atv.main_service().identifier)
return self._async_get_entry(
self.atv.main_service().protocol,
self.atv.name,
self.credentials,
self.atv.address,
)
# Initiate the pairing process
abort_reason = None
session = async_get_clientsession(self.hass)
self.pairing = await pair(
self.atv, self.protocol, self.hass.loop, session=session
)
try:
await self.pairing.begin()
except exceptions.ConnectionFailedError:
return await self.async_step_service_problem()
except exceptions.BackOffError:
abort_reason = "backoff"
except exceptions.PairingError:
_LOGGER.exception("Authentication problem")
abort_reason = "invalid_auth"
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception")
abort_reason = "unknown"
if abort_reason:
if self.pairing:
await self.pairing.close()
return self.async_abort(reason=abort_reason)
# Choose step depending on if PIN is required from user or not
if self.pairing.device_provides_pin:
return await self.async_step_pair_with_pin()
return await self.async_step_pair_no_pin()
async def async_step_pair_with_pin(self, user_input=None):
"""Handle pairing step where a PIN is required from the user."""
errors = {}
if user_input is not None:
try:
self.pairing.pin(user_input[CONF_PIN])
await self.pairing.finish()
self.credentials[self.protocol.value] = self.pairing.service.credentials
return await self.async_begin_pairing()
except exceptions.PairingError:
_LOGGER.exception("Authentication problem")
errors["base"] = "invalid_auth"
except AbortFlow:
raise
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception")
errors["base"] = "unknown"
return self.async_show_form(
step_id="pair_with_pin",
data_schema=INPUT_PIN_SCHEMA,
errors=errors,
description_placeholders={"protocol": protocol_str(self.protocol)},
)
async def async_step_pair_no_pin(self, user_input=None):
"""Handle step where user has to enter a PIN on the device."""
if user_input is not None:
await self.pairing.finish()
if self.pairing.has_paired:
self.credentials[self.protocol.value] = self.pairing.service.credentials
return await self.async_begin_pairing()
await self.pairing.close()
return self.async_abort(reason="device_did_not_pair")
pin = randrange(1000, stop=10000)
self.pairing.pin(pin)
return self.async_show_form(
step_id="pair_no_pin",
description_placeholders={
"protocol": protocol_str(self.protocol),
"pin": pin,
},
)
async def async_step_service_problem(self, user_input=None):
"""Inform user that a service will not be added."""
if user_input is not None:
self.credentials[self.protocol.value] = None
return await self.async_begin_pairing()
return self.async_show_form(
step_id="service_problem",
description_placeholders={"protocol": protocol_str(self.protocol)},
)
def _async_get_entry(self, protocol, name, credentials, address):
if not is_valid_credentials(credentials):
return self.async_abort(reason="invalid_config")
data = {
CONF_PROTOCOL: protocol.value,
CONF_NAME: name,
CONF_CREDENTIALS: credentials,
CONF_ADDRESS: str(address),
}
self._abort_if_unique_id_configured(reload_on_update=False, updates=data)
return self.async_create_entry(title=name, data=data)
def _next_protocol_to_pair(self):
def _needs_pairing(protocol):
if self.atv.get_service(protocol) is None:
return False
return protocol.value not in self.credentials
for protocol in PROTOCOL_PRIORITY:
if _needs_pairing(protocol):
return protocol
return None
def _devices_str(self):
return ", ".join(
[
f"`{atv.name} ({atv.address})`"
for atv in self.scan_result
if atv.identifier not in self._async_current_ids()
]
)
def _prefill_identifier(self):
# Return identifier (address) of one device that has not been paired with
for atv in self.scan_result:
if atv.identifier not in self._async_current_ids():
return str(atv.address)
return ""
class AppleTVOptionsFlow(config_entries.OptionsFlow):
"""Handle Apple TV options."""
def __init__(self, config_entry):
"""Initialize Apple TV options flow."""
self.config_entry = config_entry
self.options = dict(config_entry.options)
async def async_step_init(self, user_input=None):
"""Manage the Apple TV options."""
if user_input is not None:
self.options[CONF_START_OFF] = user_input[CONF_START_OFF]
return self.async_create_entry(title="", data=self.options)
return self.async_show_form(
step_id="init",
data_schema=vol.Schema(
{
vol.Optional(
CONF_START_OFF,
default=self.config_entry.options.get(
CONF_START_OFF, DEFAULT_START_OFF
),
): bool,
}
),
)
class DeviceNotFound(HomeAssistantError):
"""Error to indicate device could not be found."""
class DeviceAlreadyConfigured(HomeAssistantError):
"""Error to indicate device is already configured."""

View File

@ -0,0 +1,11 @@
"""Constants for the Apple TV integration."""
DOMAIN = "apple_tv"
CONF_IDENTIFIER = "identifier"
CONF_CREDENTIALS = "credentials"
CONF_CREDENTIALS_MRP = "mrp"
CONF_CREDENTIALS_DMAP = "dmap"
CONF_CREDENTIALS_AIRPLAY = "airplay"
CONF_START_OFF = "start_off"

View File

@ -1,9 +1,17 @@
{
"domain": "apple_tv",
"name": "Apple TV",
"config_flow": true,
"documentation": "https://www.home-assistant.io/integrations/apple_tv",
"requirements": ["pyatv==0.3.13"],
"dependencies": ["configurator"],
"requirements": [
"pyatv==0.7.3"
],
"zeroconf": [
"_mediaremotetv._tcp.local.",
"_touch-able._tcp.local."
],
"after_dependencies": ["discovery"],
"codeowners": []
"codeowners": [
"@postlund"
]
}

View File

@ -1,7 +1,7 @@
"""Support for Apple TV media player."""
import logging
import pyatv.const as atv_const
from pyatv.const import DeviceState, MediaType
from homeassistant.components.media_player import MediaPlayerEntity
from homeassistant.components.media_player.const import (
@ -19,9 +19,7 @@ from homeassistant.components.media_player.const import (
SUPPORT_TURN_ON,
)
from homeassistant.const import (
CONF_HOST,
CONF_NAME,
EVENT_HOMEASSISTANT_STOP,
STATE_IDLE,
STATE_OFF,
STATE_PAUSED,
@ -31,10 +29,13 @@ from homeassistant.const import (
from homeassistant.core import callback
import homeassistant.util.dt as dt_util
from . import ATTR_ATV, ATTR_POWER, DATA_APPLE_TV, DATA_ENTITIES
from . import AppleTVEntity
from .const import DOMAIN
_LOGGER = logging.getLogger(__name__)
PARALLEL_UPDATES = 0
SUPPORT_APPLE_TV = (
SUPPORT_TURN_ON
| SUPPORT_TURN_OFF
@ -48,108 +49,61 @@ SUPPORT_APPLE_TV = (
)
async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
"""Set up the Apple TV platform."""
if not discovery_info:
return
async def async_setup_entry(hass, config_entry, async_add_entities):
"""Load Apple TV media player based on a config entry."""
name = config_entry.data[CONF_NAME]
manager = hass.data[DOMAIN][config_entry.unique_id]
async_add_entities([AppleTvMediaPlayer(name, config_entry.unique_id, manager)])
# Manage entity cache for service handler
if DATA_ENTITIES not in hass.data:
hass.data[DATA_ENTITIES] = []
name = discovery_info[CONF_NAME]
host = discovery_info[CONF_HOST]
atv = hass.data[DATA_APPLE_TV][host][ATTR_ATV]
power = hass.data[DATA_APPLE_TV][host][ATTR_POWER]
entity = AppleTvDevice(atv, name, power)
class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
"""Representation of an Apple TV media player."""
def __init__(self, name, identifier, manager, **kwargs):
"""Initialize the Apple TV media player."""
super().__init__(name, identifier, manager, **kwargs)
self._playing = None
@callback
def on_hass_stop(event):
"""Stop push updates when hass stops."""
atv.push_updater.stop()
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, on_hass_stop)
if entity not in hass.data[DATA_ENTITIES]:
hass.data[DATA_ENTITIES].append(entity)
async_add_entities([entity])
class AppleTvDevice(MediaPlayerEntity):
"""Representation of an Apple TV device."""
def __init__(self, atv, name, power):
"""Initialize the Apple TV device."""
self.atv = atv
self._name = name
self._playing = None
self._power = power
self._power.listeners.append(self)
def async_device_connected(self, atv):
"""Handle when connection is made to device."""
self.atv.push_updater.listener = self
self.atv.push_updater.start()
async def async_added_to_hass(self):
"""Handle when an entity is about to be added to Home Assistant."""
self._power.init()
@property
def name(self):
"""Return the name of the device."""
return self._name
@property
def unique_id(self):
"""Return a unique ID."""
return self.atv.metadata.device_id
@property
def should_poll(self):
"""No polling needed."""
return False
@callback
def async_device_disconnected(self):
"""Handle when connection was lost to device."""
self.atv.push_updater.stop()
self.atv.push_updater.listener = None
@property
def state(self):
"""Return the state of the device."""
if not self._power.turned_on:
if self.manager.is_connecting:
return None
if self.atv is None:
return STATE_OFF
if self._playing:
state = self._playing.play_state
if state in (
atv_const.PLAY_STATE_IDLE,
atv_const.PLAY_STATE_NO_MEDIA,
atv_const.PLAY_STATE_LOADING,
):
state = self._playing.device_state
if state in (DeviceState.Idle, DeviceState.Loading):
return STATE_IDLE
if state == atv_const.PLAY_STATE_PLAYING:
if state == DeviceState.Playing:
return STATE_PLAYING
if state in (
atv_const.PLAY_STATE_PAUSED,
atv_const.PLAY_STATE_FAST_FORWARD,
atv_const.PLAY_STATE_FAST_BACKWARD,
atv_const.PLAY_STATE_STOPPED,
):
# Catch fast forward/backward here so "play" is default action
if state in (DeviceState.Paused, DeviceState.Seeking, DeviceState.Stopped):
return STATE_PAUSED
return STATE_STANDBY # Bad or unknown state?
return None
@callback
def playstatus_update(self, updater, playing):
def playstatus_update(self, _, playing):
"""Print what is currently playing when it changes."""
self._playing = playing
self.async_write_ha_state()
@callback
def playstatus_error(self, updater, exception):
def playstatus_error(self, _, exception):
"""Inform about an error and restart push updates."""
_LOGGER.warning("A %s error occurred: %s", exception.__class__, exception)
# This will wait 10 seconds before restarting push updates. If the
# connection continues to fail, it will flood the log (every 10
# seconds) until it succeeds. A better approach should probably be
# implemented here later.
updater.start(initial_delay=10)
self._playing = None
self.async_write_ha_state()
@ -157,50 +111,53 @@ class AppleTvDevice(MediaPlayerEntity):
def media_content_type(self):
"""Content type of current playing media."""
if self._playing:
media_type = self._playing.media_type
if media_type == atv_const.MEDIA_TYPE_VIDEO:
return MEDIA_TYPE_VIDEO
if media_type == atv_const.MEDIA_TYPE_MUSIC:
return MEDIA_TYPE_MUSIC
if media_type == atv_const.MEDIA_TYPE_TV:
return MEDIA_TYPE_TVSHOW
return {
MediaType.Video: MEDIA_TYPE_VIDEO,
MediaType.Music: MEDIA_TYPE_MUSIC,
MediaType.TV: MEDIA_TYPE_TVSHOW,
}.get(self._playing.media_type)
return None
@property
def media_duration(self):
"""Duration of current playing media in seconds."""
if self._playing:
return self._playing.total_time
return None
@property
def media_position(self):
"""Position of current playing media in seconds."""
if self._playing:
return self._playing.position
return None
@property
def media_position_updated_at(self):
"""Last valid time of media position."""
state = self.state
if state in (STATE_PLAYING, STATE_PAUSED):
if self.state in (STATE_PLAYING, STATE_PAUSED):
return dt_util.utcnow()
return None
async def async_play_media(self, media_type, media_id, **kwargs):
"""Send the play_media command to the media player."""
await self.atv.airplay.play_url(media_id)
await self.atv.stream.play_url(media_id)
@property
def media_image_hash(self):
"""Hash value for media image."""
state = self.state
if self._playing and state not in [STATE_OFF, STATE_IDLE]:
return self._playing.hash
if self._playing and state not in [None, STATE_OFF, STATE_IDLE]:
return self.atv.metadata.artwork_id
return None
async def async_get_media_image(self):
"""Fetch media image of current playing image."""
state = self.state
if self._playing and state not in [STATE_OFF, STATE_IDLE]:
return (await self.atv.metadata.artwork()), "image/png"
artwork = await self.atv.metadata.artwork()
if artwork:
return artwork.bytes, artwork.mimetype
return None, None
@ -208,12 +165,8 @@ class AppleTvDevice(MediaPlayerEntity):
def media_title(self):
"""Title of current playing media."""
if self._playing:
if self.state == STATE_IDLE:
return "Nothing playing"
title = self._playing.title
return title if title else "No title"
return f"Establishing a connection to {self._name}..."
return self._playing.title
return None
@property
def supported_features(self):
@ -222,22 +175,22 @@ class AppleTvDevice(MediaPlayerEntity):
async def async_turn_on(self):
"""Turn the media player on."""
self._power.set_power_on(True)
await self.manager.connect()
async def async_turn_off(self):
"""Turn the media player off."""
self._playing = None
self._power.set_power_on(False)
await self.manager.disconnect()
async def async_media_play_pause(self):
"""Pause media on media player."""
if not self._playing:
return
state = self.state
if state == STATE_PAUSED:
await self.atv.remote_control.play()
elif state == STATE_PLAYING:
await self.atv.remote_control.pause()
if self._playing:
state = self.state
if state == STATE_PAUSED:
await self.atv.remote_control.play()
elif state == STATE_PLAYING:
await self.atv.remote_control.pause()
return None
async def async_media_play(self):
"""Play media."""

View File

@ -1,46 +1,32 @@
"""Remote control support for Apple TV."""
from homeassistant.components import remote
from homeassistant.const import CONF_HOST, CONF_NAME
from . import ATTR_ATV, ATTR_POWER, DATA_APPLE_TV
import logging
from homeassistant.components.remote import RemoteEntity
from homeassistant.const import CONF_NAME
from . import AppleTVEntity
from .const import DOMAIN
_LOGGER = logging.getLogger(__name__)
PARALLEL_UPDATES = 0
async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
"""Set up the Apple TV remote platform."""
if not discovery_info:
return
name = discovery_info[CONF_NAME]
host = discovery_info[CONF_HOST]
atv = hass.data[DATA_APPLE_TV][host][ATTR_ATV]
power = hass.data[DATA_APPLE_TV][host][ATTR_POWER]
async_add_entities([AppleTVRemote(atv, power, name)])
async def async_setup_entry(hass, config_entry, async_add_entities):
"""Load Apple TV remote based on a config entry."""
name = config_entry.data[CONF_NAME]
manager = hass.data[DOMAIN][config_entry.unique_id]
async_add_entities([AppleTVRemote(name, config_entry.unique_id, manager)])
class AppleTVRemote(remote.RemoteEntity):
class AppleTVRemote(AppleTVEntity, RemoteEntity):
"""Device that sends commands to an Apple TV."""
def __init__(self, atv, power, name):
"""Initialize device."""
self._atv = atv
self._name = name
self._power = power
self._power.listeners.append(self)
@property
def name(self):
"""Return the name of the device."""
return self._name
@property
def unique_id(self):
"""Return a unique ID."""
return self._atv.metadata.device_id
@property
def is_on(self):
"""Return true if device is on."""
return self._power.turned_on
return self.atv is not None
@property
def should_poll(self):
@ -48,23 +34,21 @@ class AppleTVRemote(remote.RemoteEntity):
return False
async def async_turn_on(self, **kwargs):
"""Turn the device on.
This method is a coroutine.
"""
self._power.set_power_on(True)
"""Turn the device on."""
await self.manager.connect()
async def async_turn_off(self, **kwargs):
"""Turn the device off.
This method is a coroutine.
"""
self._power.set_power_on(False)
"""Turn the device off."""
await self.manager.disconnect()
async def async_send_command(self, command, **kwargs):
"""Send a command to one device."""
if not self.is_on:
_LOGGER.error("Unable to send commands, not connected to %s", self._name)
return
for single_command in command:
if not hasattr(self._atv.remote_control, single_command):
if not hasattr(self.atv.remote_control, single_command):
continue
await getattr(self._atv.remote_control, single_command)()
await getattr(self.atv.remote_control, single_command)()

View File

@ -1,8 +0,0 @@
apple_tv_authenticate:
description: Start AirPlay device authentication.
fields:
entity_id:
description: Name(s) of entities to authenticate with.
example: media_player.apple_tv
apple_tv_scan:
description: Scan for Apple TV devices.

View File

@ -0,0 +1,64 @@
{
"title": "Apple TV",
"config": {
"flow_title": "Apple TV: {name}",
"step": {
"user": {
"title": "Setup a new Apple TV",
"description": "Start by entering the device name (e.g. Kitchen or Bedroom) or IP address of the Apple TV you want to add. If any devices were automatically found on your network, they are shown below.\n\nIf you cannot see your device or experience any issues, try specifying the device IP address.\n\n{devices}",
"data": {
"device_input": "Device"
}
},
"reconfigure": {
"title": "Device reconfiguration",
"description": "This Apple TV is experiencing some connection difficulties and must be reconfigured."
},
"pair_with_pin": {
"title": "Pairing",
"description": "Pairing is required for the `{protocol}` protocol. Please enter the PIN code displayed on screen. Leading zeros shall be omitted, i.e. enter 123 if the displayed code is 0123.",
"data": {
"pin": "[%key:common::config_flow::data::pin%]"
}
},
"pair_no_pin": {
"title": "Pairing",
"description": "Pairing is required for the `{protocol}` service. Please enter PIN {pin} on your Apple TV to continue."
},
"service_problem": {
"title": "Failed to add service",
"description": "A problem occurred while pairing protocol `{protocol}`. It will be ignored."
},
"confirm": {
"title": "Confirm adding Apple TV",
"description": "You are about to add the Apple TV named `{name}` to Home Assistant.\n\n**To complete the process, you may have to enter multiple PIN codes.**\n\nPlease note that you will *not* be able to power off your Apple TV with this integration. Only the media player in Home Assistant will turn off!"
}
},
"error": {
"no_devices_found": "[%key:common::config_flow::abort::no_devices_found%]",
"already_configured": "[%key:common::config_flow::abort::already_configured_device%]",
"no_usable_service": "A device was found but could not identify any way to establish a connection to it. If you keep seeing this message, try specifying its IP address or restarting your Apple TV.",
"unknown": "[%key:common::config_flow::error::unknown%]",
"invalid_auth": "[%key:common::config_flow::error::invalid_auth%]"
},
"abort": {
"no_devices_found": "[%key:common::config_flow::abort::no_devices_found%]",
"already_configured_device": "[%key:common::config_flow::abort::already_configured_device%]",
"device_did_not_pair": "No attempt to finish pairing process was made from the device.",
"backoff": "Device does not accept pairing reqests at this time (you might have entered an invalid PIN code too many times), try again later.",
"invalid_config": "The configuration for this device is incomplete. Please try adding it again.",
"already_in_progress": "[%key:common::config_flow::abort::already_in_progress%]",
"unknown": "[%key:common::config_flow::error::unknown%]"
}
},
"options": {
"step": {
"init": {
"description": "Configure general device settings",
"data": {
"start_off": "Do not turn device on when starting Home Assistant"
}
}
}
}
}

View File

@ -0,0 +1,64 @@
{
"title": "Apple TV",
"config": {
"flow_title": "Apple TV: {name}",
"step": {
"user": {
"title": "Setup a new Apple TV",
"description": "Start by entering the device name (e.g. Kitchen or Bedroom) or IP address of the Apple TV you want to add. If any devices were automatically found on your network, they are shown below.\n\nIf you cannot see your device or experience any issues, try specifying the device IP address.\n\n{devices}",
"data": {
"device_input": "Device"
}
},
"reconfigure": {
"title": "Device reconfiguration",
"description": "This Apple TV is experiencing some connection difficulties and must be reconfigured."
},
"pair_with_pin": {
"title": "Pairing",
"description": "Pairing is required for the `{protocol}` protocol. Please enter the PIN code displayed on screen. Leading zeros shall be omitted, i.e. enter 123 if the displayed code is 0123.",
"data": {
"pin": "PIN Code"
}
},
"pair_no_pin": {
"title": "Pairing",
"description": "Pairing is required for the `{protocol}` service. Please enter PIN {pin} on your Apple TV to continue."
},
"service_problem": {
"title": "Failed to add service",
"description": "A problem occurred while pairing protocol `{protocol}`. It will be ignored."
},
"confirm": {
"title": "Confirm adding Apple TV",
"description": "You are about to add the Apple TV named `{name}` to Home Assistant.\n\n**To complete the process, you may have to enter multiple PIN codes.**\n\nPlease note that you will *not* be able to power off your Apple TV with this integration. Only the media player in Home Assistant will turn off!"
}
},
"error": {
"no_devices_found": "No devices found on the network",
"already_configured": "Device is already configured",
"no_usable_service": "A device was found but could not identify any way to establish a connection to it. If you keep seeing this message, try specifying its IP address or restarting your Apple TV.",
"unknown": "Unexpected error",
"invalid_auth": "Invalid authentication"
},
"abort": {
"no_devices_found": "No devices found on the network",
"already_configured_device": "Device is already configured",
"device_did_not_pair": "No attempt to finish pairing process was made from the device.",
"backoff": "Device does not accept pairing reqests at this time (you might have entered an invalid PIN code too many times), try again later.",
"invalid_config": "The configuration for this device is incomplete. Please try adding it again.",
"already_in_progress": "Configuration flow is already in progress",
"unknown": "Unexpected error"
}
},
"options": {
"step": {
"init": {
"description": "Configure general device settings",
"data": {
"start_off": "Do not turn device on when starting Home Assistant"
}
}
}
}
}

View File

@ -18,6 +18,7 @@ FLOWS = [
"almond",
"ambiclimate",
"ambient_station",
"apple_tv",
"arcam_fmj",
"atag",
"august",

View File

@ -90,6 +90,11 @@ ZEROCONF = {
"domain": "ipp"
}
],
"_mediaremotetv._tcp.local.": [
{
"domain": "apple_tv"
}
],
"_miio._udp.local.": [
{
"domain": "xiaomi_aqara"
@ -129,6 +134,11 @@ ZEROCONF = {
"name": "smappee2*"
}
],
"_touch-able._tcp.local.": [
{
"domain": "apple_tv"
}
],
"_viziocast._tcp.local.": [
{
"domain": "vizio"

View File

@ -1283,7 +1283,7 @@ pyatmo==4.2.1
pyatome==0.1.1
# homeassistant.components.apple_tv
pyatv==0.3.13
pyatv==0.7.3
# homeassistant.components.bbox
pybbox==0.0.5-alpha

View File

@ -648,6 +648,9 @@ pyatag==0.3.4.4
# homeassistant.components.netatmo
pyatmo==4.2.1
# homeassistant.components.apple_tv
pyatv==0.7.3
# homeassistant.components.blackbird
pyblackbird==0.5

View File

@ -0,0 +1,5 @@
"""Tests for Apple TV."""
import pytest
# Make asserts in the common module display differences
pytest.register_assert_rewrite("tests.components.apple_tv.common")

View File

@ -0,0 +1,49 @@
"""Test code shared between test files."""
from pyatv import conf, interface
from pyatv.const import Protocol
class MockPairingHandler(interface.PairingHandler):
"""Mock for PairingHandler in pyatv."""
def __init__(self, *args):
"""Initialize a new MockPairingHandler."""
super().__init__(*args)
self.pin_code = None
self.paired = False
self.always_fail = False
def pin(self, pin):
"""Pin code used for pairing."""
self.pin_code = pin
self.paired = False
@property
def device_provides_pin(self):
"""Return True if remote device presents PIN code, else False."""
return self.service.protocol in [Protocol.MRP, Protocol.AirPlay]
@property
def has_paired(self):
"""If a successful pairing has been performed.
The value will be reset when stop() is called.
"""
return not self.always_fail and self.paired
async def begin(self):
"""Start pairing process."""
async def finish(self):
"""Stop pairing process."""
self.paired = True
self.service.credentials = self.service.protocol.name.lower() + "_creds"
def create_conf(name, address, *services):
"""Create an Apple TV configuration."""
atv = conf.AppleTV(name, address)
for service in services:
atv.add_service(service)
return atv

View File

@ -0,0 +1,131 @@
"""Fixtures for component."""
from pyatv import conf, net
import pytest
from .common import MockPairingHandler, create_conf
from tests.async_mock import patch
@pytest.fixture(autouse=True, name="mock_scan")
def mock_scan_fixture():
"""Mock pyatv.scan."""
with patch("homeassistant.components.apple_tv.config_flow.scan") as mock_scan:
async def _scan(loop, timeout=5, identifier=None, protocol=None, hosts=None):
if not mock_scan.hosts:
mock_scan.hosts = hosts
return mock_scan.result
mock_scan.result = []
mock_scan.hosts = None
mock_scan.side_effect = _scan
yield mock_scan
@pytest.fixture(name="dmap_pin")
def dmap_pin_fixture():
"""Mock pyatv.scan."""
with patch("homeassistant.components.apple_tv.config_flow.randrange") as mock_pin:
mock_pin.side_effect = lambda start, stop: 1111
yield mock_pin
@pytest.fixture
def pairing():
"""Mock pyatv.scan."""
with patch("homeassistant.components.apple_tv.config_flow.pair") as mock_pair:
async def _pair(config, protocol, loop, session=None, **kwargs):
handler = MockPairingHandler(
await net.create_session(session), config.get_service(protocol)
)
handler.always_fail = mock_pair.always_fail
return handler
mock_pair.always_fail = False
mock_pair.side_effect = _pair
yield mock_pair
@pytest.fixture
def pairing_mock():
"""Mock pyatv.scan."""
with patch("homeassistant.components.apple_tv.config_flow.pair") as mock_pair:
async def _pair(config, protocol, loop, session=None, **kwargs):
return mock_pair
async def _begin():
pass
async def _close():
pass
mock_pair.close.side_effect = _close
mock_pair.begin.side_effect = _begin
mock_pair.pin = lambda pin: None
mock_pair.side_effect = _pair
yield mock_pair
@pytest.fixture
def full_device(mock_scan, dmap_pin):
"""Mock pyatv.scan."""
mock_scan.result.append(
create_conf(
"127.0.0.1",
"MRP Device",
conf.MrpService("mrpid", 5555),
conf.DmapService("dmapid", None, port=6666),
conf.AirPlayService("airplayid", port=7777),
)
)
yield mock_scan
@pytest.fixture
def mrp_device(mock_scan):
"""Mock pyatv.scan."""
mock_scan.result.append(
create_conf("127.0.0.1", "MRP Device", conf.MrpService("mrpid", 5555))
)
yield mock_scan
@pytest.fixture
def dmap_device(mock_scan):
"""Mock pyatv.scan."""
mock_scan.result.append(
create_conf(
"127.0.0.1",
"DMAP Device",
conf.DmapService("dmapid", None, port=6666),
)
)
yield mock_scan
@pytest.fixture
def dmap_device_with_credentials(mock_scan):
"""Mock pyatv.scan."""
mock_scan.result.append(
create_conf(
"127.0.0.1",
"DMAP Device",
conf.DmapService("dmapid", "dummy_creds", port=6666),
)
)
yield mock_scan
@pytest.fixture
def airplay_device(mock_scan):
"""Mock pyatv.scan."""
mock_scan.result.append(
create_conf(
"127.0.0.1", "AirPlay Device", conf.AirPlayService("airplayid", port=7777)
)
)
yield mock_scan

View File

@ -0,0 +1,582 @@
"""Test config flow."""
from pyatv import exceptions
from pyatv.const import Protocol
import pytest
from homeassistant import config_entries, data_entry_flow
from homeassistant.components.apple_tv.const import CONF_START_OFF, DOMAIN
from tests.async_mock import patch
from tests.common import MockConfigEntry
DMAP_SERVICE = {
"type": "_touch-able._tcp.local.",
"name": "dmapid.something",
"properties": {"CtlN": "Apple TV"},
}
@pytest.fixture(autouse=True)
def mock_setup_entry():
"""Mock setting up a config entry."""
with patch(
"homeassistant.components.apple_tv.async_setup_entry", return_value=True
):
yield
# User Flows
async def test_user_input_device_not_found(hass, mrp_device):
"""Test when user specifies a non-existing device."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["description_placeholders"] == {"devices": "`MRP Device (127.0.0.1)`"}
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{"device_input": "none"},
)
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result2["errors"] == {"base": "no_devices_found"}
async def test_user_input_unexpected_error(hass, mock_scan):
"""Test that unexpected error yields an error message."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
mock_scan.side_effect = Exception
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{"device_input": "dummy"},
)
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result2["errors"] == {"base": "unknown"}
async def test_user_adds_full_device(hass, full_device, pairing):
"""Test adding device with all services."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["errors"] == {}
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{"device_input": "MRP Device"},
)
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result2["description_placeholders"] == {"name": "MRP Device"}
result3 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
assert result3["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result3["description_placeholders"] == {"protocol": "MRP"}
result4 = await hass.config_entries.flow.async_configure(
result["flow_id"], {"pin": 1111}
)
assert result4["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result4["description_placeholders"] == {"protocol": "DMAP", "pin": 1111}
result5 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
assert result5["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result5["description_placeholders"] == {"protocol": "AirPlay"}
result6 = await hass.config_entries.flow.async_configure(
result["flow_id"], {"pin": 1234}
)
assert result6["type"] == "create_entry"
assert result6["data"] == {
"address": "127.0.0.1",
"credentials": {
Protocol.DMAP.value: "dmap_creds",
Protocol.MRP.value: "mrp_creds",
Protocol.AirPlay.value: "airplay_creds",
},
"name": "MRP Device",
"protocol": Protocol.MRP.value,
}
async def test_user_adds_dmap_device(hass, dmap_device, dmap_pin, pairing):
"""Test adding device with only DMAP service."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{"device_input": "DMAP Device"},
)
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result2["description_placeholders"] == {"name": "DMAP Device"}
result3 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
assert result3["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result3["description_placeholders"] == {"pin": 1111, "protocol": "DMAP"}
result6 = await hass.config_entries.flow.async_configure(
result["flow_id"], {"pin": 1234}
)
assert result6["type"] == "create_entry"
assert result6["data"] == {
"address": "127.0.0.1",
"credentials": {Protocol.DMAP.value: "dmap_creds"},
"name": "DMAP Device",
"protocol": Protocol.DMAP.value,
}
async def test_user_adds_dmap_device_failed(hass, dmap_device, dmap_pin, pairing):
"""Test adding DMAP device where remote device did not attempt to pair."""
pairing.always_fail = True
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
await hass.config_entries.flow.async_configure(
result["flow_id"],
{"device_input": "DMAP Device"},
)
await hass.config_entries.flow.async_configure(result["flow_id"], {})
result2 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
assert result2["type"] == data_entry_flow.RESULT_TYPE_ABORT
assert result2["reason"] == "device_did_not_pair"
async def test_user_adds_device_with_credentials(hass, dmap_device_with_credentials):
"""Test adding DMAP device with existing credentials (home sharing)."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{"device_input": "DMAP Device"},
)
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result2["description_placeholders"] == {"name": "DMAP Device"}
result3 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
assert result3["type"] == "create_entry"
assert result3["data"] == {
"address": "127.0.0.1",
"credentials": {Protocol.DMAP.value: "dummy_creds"},
"name": "DMAP Device",
"protocol": Protocol.DMAP.value,
}
async def test_user_adds_device_with_ip_filter(
hass, dmap_device_with_credentials, mock_scan
):
"""Test add device filtering by IP."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{"device_input": "127.0.0.1"},
)
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result2["description_placeholders"] == {"name": "DMAP Device"}
result3 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
assert result3["type"] == "create_entry"
assert result3["data"] == {
"address": "127.0.0.1",
"credentials": {Protocol.DMAP.value: "dummy_creds"},
"name": "DMAP Device",
"protocol": Protocol.DMAP.value,
}
async def test_user_adds_device_by_ip_uses_unicast_scan(hass, mock_scan):
"""Test add device by IP-address, verify unicast scan is used."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
await hass.config_entries.flow.async_configure(
result["flow_id"],
{"device_input": "127.0.0.1"},
)
assert str(mock_scan.hosts[0]) == "127.0.0.1"
async def test_user_adds_existing_device(hass, mrp_device):
"""Test that it is not possible to add existing device."""
MockConfigEntry(domain="apple_tv", unique_id="mrpid").add_to_hass(hass)
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{"device_input": "127.0.0.1"},
)
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result2["errors"] == {"base": "already_configured"}
async def test_user_adds_unusable_device(hass, airplay_device):
"""Test that it is not possible to add pure AirPlay device."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{"device_input": "AirPlay Device"},
)
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result2["errors"] == {"base": "no_usable_service"}
async def test_user_connection_failed(hass, mrp_device, pairing_mock):
"""Test error message when connection to device fails."""
pairing_mock.begin.side_effect = exceptions.ConnectionFailedError
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
await hass.config_entries.flow.async_configure(
result["flow_id"],
{"device_input": "MRP Device"},
)
await hass.config_entries.flow.async_configure(
result["flow_id"],
{},
)
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{},
)
assert result2["type"] == data_entry_flow.RESULT_TYPE_ABORT
assert result2["reason"] == "invalid_config"
async def test_user_start_pair_error_failed(hass, mrp_device, pairing_mock):
"""Test initiating pairing fails."""
pairing_mock.begin.side_effect = exceptions.PairingError
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
await hass.config_entries.flow.async_configure(
result["flow_id"],
{"device_input": "MRP Device"},
)
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{},
)
assert result2["type"] == data_entry_flow.RESULT_TYPE_ABORT
assert result2["reason"] == "invalid_auth"
async def test_user_pair_invalid_pin(hass, mrp_device, pairing_mock):
"""Test pairing with invalid pin."""
pairing_mock.finish.side_effect = exceptions.PairingError
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
await hass.config_entries.flow.async_configure(
result["flow_id"],
{"device_input": "MRP Device"},
)
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{},
)
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{"pin": 1111},
)
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result2["errors"] == {"base": "invalid_auth"}
async def test_user_pair_unexpected_error(hass, mrp_device, pairing_mock):
"""Test unexpected error when entering PIN code."""
pairing_mock.finish.side_effect = Exception
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
await hass.config_entries.flow.async_configure(
result["flow_id"],
{"device_input": "MRP Device"},
)
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{},
)
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{"pin": 1111},
)
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result2["errors"] == {"base": "unknown"}
async def test_user_pair_backoff_error(hass, mrp_device, pairing_mock):
"""Test that backoff error is displayed in case device requests it."""
pairing_mock.begin.side_effect = exceptions.BackOffError
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
await hass.config_entries.flow.async_configure(
result["flow_id"],
{"device_input": "MRP Device"},
)
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{},
)
assert result2["type"] == data_entry_flow.RESULT_TYPE_ABORT
assert result2["reason"] == "backoff"
async def test_user_pair_begin_unexpected_error(hass, mrp_device, pairing_mock):
"""Test unexpected error during start of pairing."""
pairing_mock.begin.side_effect = Exception
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
await hass.config_entries.flow.async_configure(
result["flow_id"],
{"device_input": "MRP Device"},
)
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{},
)
assert result2["type"] == data_entry_flow.RESULT_TYPE_ABORT
assert result2["reason"] == "unknown"
# Zeroconf
async def test_zeroconf_unsupported_service_aborts(hass):
"""Test discovering unsupported zeroconf service."""
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_ZEROCONF},
data={
"type": "_dummy._tcp.local.",
"properties": {},
},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
assert result["reason"] == "unknown"
async def test_zeroconf_add_mrp_device(hass, mrp_device, pairing):
"""Test add MRP device discovered by zeroconf."""
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_ZEROCONF},
data={
"type": "_mediaremotetv._tcp.local.",
"properties": {"UniqueIdentifier": "mrpid", "Name": "Kitchen"},
},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["description_placeholders"] == {"name": "MRP Device"}
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{},
)
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result2["description_placeholders"] == {"protocol": "MRP"}
result3 = await hass.config_entries.flow.async_configure(
result["flow_id"], {"pin": 1111}
)
assert result3["type"] == "create_entry"
assert result3["data"] == {
"address": "127.0.0.1",
"credentials": {Protocol.MRP.value: "mrp_creds"},
"name": "MRP Device",
"protocol": Protocol.MRP.value,
}
async def test_zeroconf_add_dmap_device(hass, dmap_device, dmap_pin, pairing):
"""Test add DMAP device discovered by zeroconf."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF}, data=DMAP_SERVICE
)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["description_placeholders"] == {"name": "DMAP Device"}
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{},
)
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result2["description_placeholders"] == {"protocol": "DMAP", "pin": 1111}
result3 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
assert result3["type"] == "create_entry"
assert result3["data"] == {
"address": "127.0.0.1",
"credentials": {Protocol.DMAP.value: "dmap_creds"},
"name": "DMAP Device",
"protocol": Protocol.DMAP.value,
}
async def test_zeroconf_add_existing_aborts(hass, dmap_device):
"""Test start new zeroconf flow while existing flow is active aborts."""
await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF}, data=DMAP_SERVICE
)
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF}, data=DMAP_SERVICE
)
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
assert result["reason"] == "already_in_progress"
async def test_zeroconf_add_but_device_not_found(hass, mock_scan):
"""Test add device which is not found with another scan."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF}, data=DMAP_SERVICE
)
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
assert result["reason"] == "no_devices_found"
async def test_zeroconf_add_existing_device(hass, dmap_device):
"""Test add already existing device from zeroconf."""
MockConfigEntry(domain="apple_tv", unique_id="dmapid").add_to_hass(hass)
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF}, data=DMAP_SERVICE
)
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
assert result["reason"] == "already_configured"
async def test_zeroconf_unexpected_error(hass, mock_scan):
"""Test unexpected error aborts in zeroconf."""
mock_scan.side_effect = Exception
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF}, data=DMAP_SERVICE
)
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
assert result["reason"] == "unknown"
# Re-configuration
async def test_reconfigure_update_credentials(hass, mrp_device, pairing):
"""Test that reconfigure flow updates config entry."""
config_entry = MockConfigEntry(domain="apple_tv", unique_id="mrpid")
config_entry.add_to_hass(hass)
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": "reauth"},
data={"identifier": "mrpid", "name": "apple tv"},
)
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{},
)
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result2["description_placeholders"] == {"protocol": "MRP"}
result3 = await hass.config_entries.flow.async_configure(
result["flow_id"], {"pin": 1111}
)
assert result3["type"] == data_entry_flow.RESULT_TYPE_ABORT
assert result3["reason"] == "already_configured"
assert config_entry.data == {
"address": "127.0.0.1",
"protocol": Protocol.MRP.value,
"name": "MRP Device",
"credentials": {Protocol.MRP.value: "mrp_creds"},
}
async def test_reconfigure_ongoing_aborts(hass, mrp_device):
"""Test start additional reconfigure flow aborts."""
data = {
"identifier": "mrpid",
"name": "Apple TV",
}
await hass.config_entries.flow.async_init(
DOMAIN, context={"source": "reauth"}, data=data
)
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": "reauth"}, data=data
)
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
assert result["reason"] == "already_in_progress"
# Options
async def test_option_start_off(hass):
"""Test start off-option flag."""
config_entry = MockConfigEntry(
domain=DOMAIN, unique_id="dmapid", options={"start_off": False}
)
config_entry.add_to_hass(hass)
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
result2 = await hass.config_entries.options.async_configure(
result["flow_id"], user_input={CONF_START_OFF: True}
)
assert result2["type"] == "create_entry"
assert config_entry.options[CONF_START_OFF]