Refactor Apple TV integration (#31952)
Co-authored-by: Franck Nijhof <git@frenck.dev>pull/43860/head
parent
58648019c6
commit
edb246d696
|
@ -48,7 +48,9 @@ omit =
|
|||
homeassistant/components/anel_pwrctrl/switch.py
|
||||
homeassistant/components/anthemav/media_player.py
|
||||
homeassistant/components/apcupsd/*
|
||||
homeassistant/components/apple_tv/*
|
||||
homeassistant/components/apple_tv/__init__.py
|
||||
homeassistant/components/apple_tv/media_player.py
|
||||
homeassistant/components/apple_tv/remote.py
|
||||
homeassistant/components/aqualogic/*
|
||||
homeassistant/components/aquostv/media_player.py
|
||||
homeassistant/components/arcam_fmj/media_player.py
|
||||
|
|
|
@ -37,6 +37,7 @@ homeassistant/components/amcrest/* @pnbruckner
|
|||
homeassistant/components/androidtv/* @JeffLIrion
|
||||
homeassistant/components/apache_kafka/* @bachya
|
||||
homeassistant/components/api/* @home-assistant/core
|
||||
homeassistant/components/apple_tv/* @postlund
|
||||
homeassistant/components/apprise/* @caronc
|
||||
homeassistant/components/aprs/* @PhilRW
|
||||
homeassistant/components/arcam_fmj/* @elupus
|
||||
|
|
|
@ -1,273 +1,363 @@
|
|||
"""Support for Apple TV."""
|
||||
"""The Apple TV integration."""
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Sequence, TypeVar, Union
|
||||
from random import randrange
|
||||
|
||||
from pyatv import AppleTVDevice, connect_to_apple_tv, scan_for_apple_tvs
|
||||
from pyatv.exceptions import DeviceAuthenticationError
|
||||
import voluptuous as vol
|
||||
from pyatv import connect, exceptions, scan
|
||||
from pyatv.const import Protocol
|
||||
|
||||
from homeassistant.components.discovery import SERVICE_APPLE_TV
|
||||
from homeassistant.const import ATTR_ENTITY_ID, CONF_HOST, CONF_NAME
|
||||
from homeassistant.helpers import discovery
|
||||
from homeassistant.components.media_player import DOMAIN as MP_DOMAIN
|
||||
from homeassistant.components.remote import DOMAIN as REMOTE_DOMAIN
|
||||
from homeassistant.const import (
|
||||
CONF_ADDRESS,
|
||||
CONF_NAME,
|
||||
CONF_PROTOCOL,
|
||||
EVENT_HOMEASSISTANT_STOP,
|
||||
)
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.dispatcher import (
|
||||
async_dispatcher_connect,
|
||||
async_dispatcher_send,
|
||||
)
|
||||
from homeassistant.helpers.entity import Entity
|
||||
|
||||
from .const import CONF_CREDENTIALS, CONF_IDENTIFIER, CONF_START_OFF, DOMAIN
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
DOMAIN = "apple_tv"
|
||||
|
||||
SERVICE_SCAN = "apple_tv_scan"
|
||||
SERVICE_AUTHENTICATE = "apple_tv_authenticate"
|
||||
|
||||
ATTR_ATV = "atv"
|
||||
ATTR_POWER = "power"
|
||||
|
||||
CONF_LOGIN_ID = "login_id"
|
||||
CONF_START_OFF = "start_off"
|
||||
CONF_CREDENTIALS = "credentials"
|
||||
|
||||
DEFAULT_NAME = "Apple TV"
|
||||
|
||||
DATA_APPLE_TV = "data_apple_tv"
|
||||
DATA_ENTITIES = "data_apple_tv_entities"
|
||||
BACKOFF_TIME_UPPER_LIMIT = 300 # Five minutes
|
||||
|
||||
KEY_CONFIG = "apple_tv_configuring"
|
||||
NOTIFICATION_TITLE = "Apple TV Notification"
|
||||
NOTIFICATION_ID = "apple_tv_notification"
|
||||
|
||||
NOTIFICATION_AUTH_ID = "apple_tv_auth_notification"
|
||||
NOTIFICATION_AUTH_TITLE = "Apple TV Authentication"
|
||||
NOTIFICATION_SCAN_ID = "apple_tv_scan_notification"
|
||||
NOTIFICATION_SCAN_TITLE = "Apple TV Scan"
|
||||
SOURCE_REAUTH = "reauth"
|
||||
|
||||
T = TypeVar("T")
|
||||
SIGNAL_CONNECTED = "apple_tv_connected"
|
||||
SIGNAL_DISCONNECTED = "apple_tv_disconnected"
|
||||
|
||||
|
||||
# This version of ensure_list interprets an empty dict as no value
|
||||
def ensure_list(value: Union[T, Sequence[T]]) -> Sequence[T]:
|
||||
"""Wrap value in list if it is not one."""
|
||||
if value is None or (isinstance(value, dict) and not value):
|
||||
return []
|
||||
return value if isinstance(value, list) else [value]
|
||||
|
||||
|
||||
CONFIG_SCHEMA = vol.Schema(
|
||||
{
|
||||
DOMAIN: vol.All(
|
||||
ensure_list,
|
||||
[
|
||||
vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_HOST): cv.string,
|
||||
vol.Required(CONF_LOGIN_ID): cv.string,
|
||||
vol.Optional(CONF_CREDENTIALS): cv.string,
|
||||
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
|
||||
vol.Optional(CONF_START_OFF, default=False): cv.boolean,
|
||||
}
|
||||
)
|
||||
],
|
||||
)
|
||||
},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
)
|
||||
|
||||
# Currently no attributes but it might change later
|
||||
APPLE_TV_SCAN_SCHEMA = vol.Schema({})
|
||||
|
||||
APPLE_TV_AUTHENTICATE_SCHEMA = vol.Schema({ATTR_ENTITY_ID: cv.entity_ids})
|
||||
|
||||
|
||||
def request_configuration(hass, config, atv, credentials):
|
||||
"""Request configuration steps from the user."""
|
||||
configurator = hass.components.configurator
|
||||
|
||||
async def configuration_callback(callback_data):
|
||||
"""Handle the submitted configuration."""
|
||||
|
||||
pin = callback_data.get("pin")
|
||||
|
||||
try:
|
||||
await atv.airplay.finish_authentication(pin)
|
||||
hass.components.persistent_notification.async_create(
|
||||
f"Authentication succeeded!<br /><br />"
|
||||
f"Add the following to credentials: "
|
||||
f"in your apple_tv configuration:<br /><br />{credentials}",
|
||||
title=NOTIFICATION_AUTH_TITLE,
|
||||
notification_id=NOTIFICATION_AUTH_ID,
|
||||
)
|
||||
except DeviceAuthenticationError as ex:
|
||||
hass.components.persistent_notification.async_create(
|
||||
f"Authentication failed! Did you enter correct PIN?<br /><br />Details: {ex}",
|
||||
title=NOTIFICATION_AUTH_TITLE,
|
||||
notification_id=NOTIFICATION_AUTH_ID,
|
||||
)
|
||||
|
||||
hass.async_add_job(configurator.request_done, instance)
|
||||
|
||||
instance = configurator.request_config(
|
||||
"Apple TV Authentication",
|
||||
configuration_callback,
|
||||
description="Please enter PIN code shown on screen.",
|
||||
submit_caption="Confirm",
|
||||
fields=[{"id": "pin", "name": "PIN Code", "type": "password"}],
|
||||
)
|
||||
|
||||
|
||||
async def scan_apple_tvs(hass):
|
||||
"""Scan for devices and present a notification of the ones found."""
|
||||
|
||||
atvs = await scan_for_apple_tvs(hass.loop, timeout=3)
|
||||
|
||||
devices = []
|
||||
for atv in atvs:
|
||||
login_id = atv.login_id
|
||||
if login_id is None:
|
||||
login_id = "Home Sharing disabled"
|
||||
devices.append(
|
||||
f"Name: {atv.name}<br />Host: {atv.address}<br />Login ID: {login_id}"
|
||||
)
|
||||
|
||||
if not devices:
|
||||
devices = ["No device(s) found"]
|
||||
|
||||
found_devices = "<br /><br />".join(devices)
|
||||
|
||||
hass.components.persistent_notification.async_create(
|
||||
f"The following devices were found:<br /><br />{found_devices}",
|
||||
title=NOTIFICATION_SCAN_TITLE,
|
||||
notification_id=NOTIFICATION_SCAN_ID,
|
||||
)
|
||||
PLATFORMS = [MP_DOMAIN, REMOTE_DOMAIN]
|
||||
|
||||
|
||||
async def async_setup(hass, config):
|
||||
"""Set up the Apple TV component."""
|
||||
if DATA_APPLE_TV not in hass.data:
|
||||
hass.data[DATA_APPLE_TV] = {}
|
||||
"""Set up the Apple TV integration."""
|
||||
return True
|
||||
|
||||
async def async_service_handler(service):
|
||||
"""Handle service calls."""
|
||||
entity_ids = service.data.get(ATTR_ENTITY_ID)
|
||||
|
||||
if service.service == SERVICE_SCAN:
|
||||
hass.async_add_job(scan_apple_tvs, hass)
|
||||
return
|
||||
async def async_setup_entry(hass, entry):
|
||||
"""Set up a config entry for Apple TV."""
|
||||
manager = AppleTVManager(hass, entry)
|
||||
hass.data.setdefault(DOMAIN, {})[entry.unique_id] = manager
|
||||
|
||||
if entity_ids:
|
||||
devices = [
|
||||
device
|
||||
for device in hass.data[DATA_ENTITIES]
|
||||
if device.entity_id in entity_ids
|
||||
async def on_hass_stop(event):
|
||||
"""Stop push updates when hass stops."""
|
||||
await manager.disconnect()
|
||||
|
||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, on_hass_stop)
|
||||
|
||||
async def setup_platforms():
|
||||
"""Set up platforms and initiate connection."""
|
||||
await asyncio.gather(
|
||||
*[
|
||||
hass.config_entries.async_forward_entry_setup(entry, component)
|
||||
for component in PLATFORMS
|
||||
]
|
||||
else:
|
||||
devices = hass.data[DATA_ENTITIES]
|
||||
|
||||
for device in devices:
|
||||
if service.service != SERVICE_AUTHENTICATE:
|
||||
continue
|
||||
|
||||
atv = device.atv
|
||||
credentials = await atv.airplay.generate_credentials()
|
||||
await atv.airplay.load_credentials(credentials)
|
||||
_LOGGER.debug("Generated new credentials: %s", credentials)
|
||||
await atv.airplay.start_authentication()
|
||||
hass.async_add_job(request_configuration, hass, config, atv, credentials)
|
||||
|
||||
async def atv_discovered(service, info):
|
||||
"""Set up an Apple TV that was auto discovered."""
|
||||
await _setup_atv(
|
||||
hass,
|
||||
config,
|
||||
{
|
||||
CONF_NAME: info["name"],
|
||||
CONF_HOST: info["host"],
|
||||
CONF_LOGIN_ID: info["properties"]["hG"],
|
||||
CONF_START_OFF: False,
|
||||
},
|
||||
)
|
||||
await manager.init()
|
||||
|
||||
discovery.async_listen(hass, SERVICE_APPLE_TV, atv_discovered)
|
||||
|
||||
tasks = [_setup_atv(hass, config, conf) for conf in config.get(DOMAIN, [])]
|
||||
if tasks:
|
||||
await asyncio.wait(tasks)
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_SCAN, async_service_handler, schema=APPLE_TV_SCAN_SCHEMA
|
||||
)
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN,
|
||||
SERVICE_AUTHENTICATE,
|
||||
async_service_handler,
|
||||
schema=APPLE_TV_AUTHENTICATE_SCHEMA,
|
||||
)
|
||||
hass.async_create_task(setup_platforms())
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def _setup_atv(hass, hass_config, atv_config):
|
||||
"""Set up an Apple TV."""
|
||||
|
||||
name = atv_config.get(CONF_NAME)
|
||||
host = atv_config.get(CONF_HOST)
|
||||
login_id = atv_config.get(CONF_LOGIN_ID)
|
||||
start_off = atv_config.get(CONF_START_OFF)
|
||||
credentials = atv_config.get(CONF_CREDENTIALS)
|
||||
|
||||
if host in hass.data[DATA_APPLE_TV]:
|
||||
return
|
||||
|
||||
details = AppleTVDevice(name, host, login_id)
|
||||
session = async_get_clientsession(hass)
|
||||
atv = connect_to_apple_tv(details, hass.loop, session=session)
|
||||
if credentials:
|
||||
await atv.airplay.load_credentials(credentials)
|
||||
|
||||
power = AppleTVPowerManager(hass, atv, start_off)
|
||||
hass.data[DATA_APPLE_TV][host] = {ATTR_ATV: atv, ATTR_POWER: power}
|
||||
|
||||
hass.async_create_task(
|
||||
discovery.async_load_platform(
|
||||
hass, "media_player", DOMAIN, atv_config, hass_config
|
||||
async def async_unload_entry(hass, entry):
|
||||
"""Unload an Apple TV config entry."""
|
||||
unload_ok = all(
|
||||
await asyncio.gather(
|
||||
*[
|
||||
hass.config_entries.async_forward_entry_unload(entry, platform)
|
||||
for platform in PLATFORMS
|
||||
]
|
||||
)
|
||||
)
|
||||
if unload_ok:
|
||||
manager = hass.data[DOMAIN].pop(entry.unique_id)
|
||||
await manager.disconnect()
|
||||
|
||||
hass.async_create_task(
|
||||
discovery.async_load_platform(hass, "remote", DOMAIN, atv_config, hass_config)
|
||||
)
|
||||
return unload_ok
|
||||
|
||||
|
||||
class AppleTVPowerManager:
|
||||
"""Manager for global power management of an Apple TV.
|
||||
class AppleTVEntity(Entity):
|
||||
"""Device that sends commands to an Apple TV."""
|
||||
|
||||
An instance is used per device to share the same power state between
|
||||
several platforms.
|
||||
"""
|
||||
def __init__(self, name, identifier, manager):
|
||||
"""Initialize device."""
|
||||
self.atv = None
|
||||
self.manager = manager
|
||||
self._name = name
|
||||
self._identifier = identifier
|
||||
|
||||
def __init__(self, hass, atv, is_off):
|
||||
"""Initialize power manager."""
|
||||
self.hass = hass
|
||||
self.atv = atv
|
||||
self.listeners = []
|
||||
self._is_on = not is_off
|
||||
async def async_added_to_hass(self):
|
||||
"""Handle when an entity is about to be added to Home Assistant."""
|
||||
|
||||
def init(self):
|
||||
"""Initialize power management."""
|
||||
if self._is_on:
|
||||
self.atv.push_updater.start()
|
||||
@callback
|
||||
def _async_connected(atv):
|
||||
"""Handle that a connection was made to a device."""
|
||||
self.atv = atv
|
||||
self.async_device_connected(atv)
|
||||
self.async_write_ha_state()
|
||||
|
||||
@callback
|
||||
def _async_disconnected():
|
||||
"""Handle that a connection to a device was lost."""
|
||||
self.async_device_disconnected()
|
||||
self.atv = None
|
||||
self.async_write_ha_state()
|
||||
|
||||
self.async_on_remove(
|
||||
async_dispatcher_connect(
|
||||
self.hass, f"{SIGNAL_CONNECTED}_{self._identifier}", _async_connected
|
||||
)
|
||||
)
|
||||
self.async_on_remove(
|
||||
async_dispatcher_connect(
|
||||
self.hass,
|
||||
f"{SIGNAL_DISCONNECTED}_{self._identifier}",
|
||||
_async_disconnected,
|
||||
)
|
||||
)
|
||||
|
||||
def async_device_connected(self, atv):
|
||||
"""Handle when connection is made to device."""
|
||||
|
||||
def async_device_disconnected(self):
|
||||
"""Handle when connection was lost to device."""
|
||||
|
||||
@property
|
||||
def turned_on(self):
|
||||
"""Return true if device is on or off."""
|
||||
return self._is_on
|
||||
def device_info(self):
|
||||
"""Return the device info."""
|
||||
return {
|
||||
"identifiers": {(DOMAIN, self._identifier)},
|
||||
"manufacturer": "Apple",
|
||||
"name": self.name,
|
||||
}
|
||||
|
||||
def set_power_on(self, value):
|
||||
"""Change if a device is on or off."""
|
||||
if value != self._is_on:
|
||||
self._is_on = value
|
||||
if not self._is_on:
|
||||
@property
|
||||
def name(self):
|
||||
"""Return the name of the device."""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def unique_id(self):
|
||||
"""Return a unique ID."""
|
||||
return self._identifier
|
||||
|
||||
@property
|
||||
def should_poll(self):
|
||||
"""No polling needed for Apple TV."""
|
||||
return False
|
||||
|
||||
|
||||
class AppleTVManager:
|
||||
"""Connection and power manager for an Apple TV.
|
||||
|
||||
An instance is used per device to share the same power state between
|
||||
several platforms. It also manages scanning and connection establishment
|
||||
in case of problems.
|
||||
"""
|
||||
|
||||
def __init__(self, hass, config_entry):
|
||||
"""Initialize power manager."""
|
||||
self.config_entry = config_entry
|
||||
self.hass = hass
|
||||
self.atv = None
|
||||
self._is_on = not config_entry.options.get(CONF_START_OFF, False)
|
||||
self._connection_attempts = 0
|
||||
self._connection_was_lost = False
|
||||
self._task = None
|
||||
|
||||
async def init(self):
|
||||
"""Initialize power management."""
|
||||
if self._is_on:
|
||||
await self.connect()
|
||||
|
||||
def connection_lost(self, _):
|
||||
"""Device was unexpectedly disconnected.
|
||||
|
||||
This is a callback function from pyatv.interface.DeviceListener.
|
||||
"""
|
||||
_LOGGER.warning('Connection lost to Apple TV "%s"', self.atv.name)
|
||||
if self.atv:
|
||||
self.atv.close()
|
||||
self.atv = None
|
||||
self._connection_was_lost = True
|
||||
self._dispatch_send(SIGNAL_DISCONNECTED)
|
||||
self._start_connect_loop()
|
||||
|
||||
def connection_closed(self):
|
||||
"""Device connection was (intentionally) closed.
|
||||
|
||||
This is a callback function from pyatv.interface.DeviceListener.
|
||||
"""
|
||||
if self.atv:
|
||||
self.atv.close()
|
||||
self.atv = None
|
||||
self._dispatch_send(SIGNAL_DISCONNECTED)
|
||||
self._start_connect_loop()
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to device."""
|
||||
self._is_on = True
|
||||
self._start_connect_loop()
|
||||
|
||||
async def disconnect(self):
|
||||
"""Disconnect from device."""
|
||||
_LOGGER.debug("Disconnecting from device")
|
||||
self._is_on = False
|
||||
try:
|
||||
if self.atv:
|
||||
self.atv.push_updater.listener = None
|
||||
self.atv.push_updater.stop()
|
||||
else:
|
||||
self.atv.push_updater.start()
|
||||
self.atv.close()
|
||||
self.atv = None
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
self._task = None
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("An error occurred while disconnecting")
|
||||
|
||||
for listener in self.listeners:
|
||||
self.hass.async_create_task(listener.async_update_ha_state())
|
||||
def _start_connect_loop(self):
|
||||
"""Start background connect loop to device."""
|
||||
if not self._task and self.atv is None and self._is_on:
|
||||
self._task = asyncio.create_task(self._connect_loop())
|
||||
else:
|
||||
_LOGGER.debug(
|
||||
"Not starting connect loop (%s, %s)", self.atv is None, self._is_on
|
||||
)
|
||||
|
||||
async def _connect_loop(self):
|
||||
"""Connect loop background task function."""
|
||||
_LOGGER.debug("Starting connect loop")
|
||||
|
||||
# Try to find device and connect as long as the user has said that
|
||||
# we are allowed to connect and we are not already connected.
|
||||
while self._is_on and self.atv is None:
|
||||
try:
|
||||
conf = await self._scan()
|
||||
if conf:
|
||||
await self._connect(conf)
|
||||
except exceptions.AuthenticationError:
|
||||
self._auth_problem()
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Failed to connect")
|
||||
self.atv = None
|
||||
|
||||
if self.atv is None:
|
||||
self._connection_attempts += 1
|
||||
backoff = min(
|
||||
randrange(2 ** self._connection_attempts), BACKOFF_TIME_UPPER_LIMIT
|
||||
)
|
||||
|
||||
_LOGGER.debug("Reconnecting in %d seconds", backoff)
|
||||
await asyncio.sleep(backoff)
|
||||
|
||||
_LOGGER.debug("Connect loop ended")
|
||||
self._task = None
|
||||
|
||||
def _auth_problem(self):
|
||||
"""Problem to authenticate occurred that needs intervention."""
|
||||
_LOGGER.debug("Authentication error, reconfigure integration")
|
||||
|
||||
name = self.config_entry.data.get(CONF_NAME)
|
||||
identifier = self.config_entry.unique_id
|
||||
|
||||
self.hass.components.persistent_notification.create(
|
||||
"An irrecoverable connection problem occurred when connecting to "
|
||||
f"`f{name}`. Please go to the Integrations page and reconfigure it",
|
||||
title=NOTIFICATION_TITLE,
|
||||
notification_id=NOTIFICATION_ID,
|
||||
)
|
||||
|
||||
# Add to event queue as this function is called from a task being
|
||||
# cancelled from disconnect
|
||||
asyncio.create_task(self.disconnect())
|
||||
|
||||
self.hass.async_create_task(
|
||||
self.hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
context={"source": SOURCE_REAUTH},
|
||||
data={CONF_NAME: name, CONF_IDENTIFIER: identifier},
|
||||
)
|
||||
)
|
||||
|
||||
async def _scan(self):
|
||||
"""Try to find device by scanning for it."""
|
||||
identifier = self.config_entry.unique_id
|
||||
address = self.config_entry.data[CONF_ADDRESS]
|
||||
protocol = Protocol(self.config_entry.data[CONF_PROTOCOL])
|
||||
|
||||
_LOGGER.debug("Discovering device %s", identifier)
|
||||
atvs = await scan(
|
||||
self.hass.loop, identifier=identifier, protocol=protocol, hosts=[address]
|
||||
)
|
||||
if atvs:
|
||||
return atvs[0]
|
||||
|
||||
_LOGGER.debug(
|
||||
"Failed to find device %s with address %s, trying to scan",
|
||||
identifier,
|
||||
address,
|
||||
)
|
||||
|
||||
atvs = await scan(self.hass.loop, identifier=identifier, protocol=protocol)
|
||||
if atvs:
|
||||
return atvs[0]
|
||||
|
||||
_LOGGER.debug("Failed to find device %s, trying later", identifier)
|
||||
|
||||
return None
|
||||
|
||||
async def _connect(self, conf):
|
||||
"""Connect to device."""
|
||||
credentials = self.config_entry.data[CONF_CREDENTIALS]
|
||||
session = async_get_clientsession(self.hass)
|
||||
|
||||
for protocol, creds in credentials.items():
|
||||
conf.set_credentials(Protocol(int(protocol)), creds)
|
||||
|
||||
_LOGGER.debug("Connecting to device %s", self.config_entry.data[CONF_NAME])
|
||||
self.atv = await connect(conf, self.hass.loop, session=session)
|
||||
self.atv.listener = self
|
||||
|
||||
self._dispatch_send(SIGNAL_CONNECTED, self.atv)
|
||||
self._address_updated(str(conf.address))
|
||||
|
||||
self._connection_attempts = 0
|
||||
if self._connection_was_lost:
|
||||
_LOGGER.info(
|
||||
'Connection was re-established to Apple TV "%s"', self.atv.service.name
|
||||
)
|
||||
self._connection_was_lost = False
|
||||
|
||||
@property
|
||||
def is_connecting(self):
|
||||
"""Return true if connection is in progress."""
|
||||
return self._task is not None
|
||||
|
||||
def _address_updated(self, address):
|
||||
"""Update cached address in config entry."""
|
||||
_LOGGER.debug("Changing address to %s", address)
|
||||
self.hass.config_entries.async_update_entry(
|
||||
self.config_entry, data={**self.config_entry.data, CONF_ADDRESS: address}
|
||||
)
|
||||
|
||||
def _dispatch_send(self, signal, *args):
|
||||
"""Dispatch a signal to all entities managed by this manager."""
|
||||
async_dispatcher_send(
|
||||
self.hass, f"{signal}_{self.config_entry.unique_id}", *args
|
||||
)
|
||||
|
|
|
@ -0,0 +1,408 @@
|
|||
"""Config flow for Apple TV integration."""
|
||||
from ipaddress import ip_address
|
||||
import logging
|
||||
from random import randrange
|
||||
|
||||
from pyatv import exceptions, pair, scan
|
||||
from pyatv.const import Protocol
|
||||
from pyatv.convert import protocol_str
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.const import (
|
||||
CONF_ADDRESS,
|
||||
CONF_NAME,
|
||||
CONF_PIN,
|
||||
CONF_PROTOCOL,
|
||||
CONF_TYPE,
|
||||
)
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.data_entry_flow import AbortFlow
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
|
||||
from .const import CONF_CREDENTIALS, CONF_IDENTIFIER, CONF_START_OFF
|
||||
from .const import DOMAIN # pylint: disable=unused-import
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
DEVICE_INPUT = "device_input"
|
||||
|
||||
INPUT_PIN_SCHEMA = vol.Schema({vol.Required(CONF_PIN, default=None): int})
|
||||
|
||||
DEFAULT_START_OFF = False
|
||||
PROTOCOL_PRIORITY = [Protocol.MRP, Protocol.DMAP, Protocol.AirPlay]
|
||||
|
||||
|
||||
async def device_scan(identifier, loop, cache=None):
|
||||
"""Scan for a specific device using identifier as filter."""
|
||||
|
||||
def _filter_device(dev):
|
||||
if identifier is None:
|
||||
return True
|
||||
if identifier == str(dev.address):
|
||||
return True
|
||||
if identifier == dev.name:
|
||||
return True
|
||||
return any([service.identifier == identifier for service in dev.services])
|
||||
|
||||
def _host_filter():
|
||||
try:
|
||||
return [ip_address(identifier)]
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
if cache:
|
||||
matches = [atv for atv in cache if _filter_device(atv)]
|
||||
if matches:
|
||||
return cache, matches[0]
|
||||
|
||||
for hosts in [_host_filter(), None]:
|
||||
scan_result = await scan(loop, timeout=3, hosts=hosts)
|
||||
matches = [atv for atv in scan_result if _filter_device(atv)]
|
||||
|
||||
if matches:
|
||||
return scan_result, matches[0]
|
||||
|
||||
return scan_result, None
|
||||
|
||||
|
||||
def is_valid_credentials(credentials):
|
||||
"""Verify that credentials are valid for establishing a connection."""
|
||||
return (
|
||||
credentials.get(Protocol.MRP.value) is not None
|
||||
or credentials.get(Protocol.DMAP.value) is not None
|
||||
)
|
||||
|
||||
|
||||
class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
"""Handle a config flow for Apple TV."""
|
||||
|
||||
VERSION = 1
|
||||
CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_PUSH
|
||||
|
||||
@staticmethod
|
||||
@callback
|
||||
def async_get_options_flow(config_entry):
|
||||
"""Get options flow for this handler."""
|
||||
return AppleTVOptionsFlow(config_entry)
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize a new AppleTVConfigFlow."""
|
||||
self.target_device = None
|
||||
self.scan_result = None
|
||||
self.atv = None
|
||||
self.protocol = None
|
||||
self.pairing = None
|
||||
self.credentials = {} # Protocol -> credentials
|
||||
|
||||
async def async_step_reauth(self, info):
|
||||
"""Handle initial step when updating invalid credentials."""
|
||||
await self.async_set_unique_id(info[CONF_IDENTIFIER])
|
||||
self.target_device = info[CONF_IDENTIFIER]
|
||||
|
||||
# pylint: disable=no-member # https://github.com/PyCQA/pylint/issues/3167
|
||||
self.context["title_placeholders"] = {"name": info[CONF_NAME]}
|
||||
|
||||
# pylint: disable=no-member # https://github.com/PyCQA/pylint/issues/3167
|
||||
self.context["identifier"] = self.unique_id
|
||||
return await self.async_step_reconfigure()
|
||||
|
||||
async def async_step_reconfigure(self, user_input=None):
|
||||
"""Inform user that reconfiguration is about to start."""
|
||||
if user_input is not None:
|
||||
return await self.async_find_device_wrapper(
|
||||
self.async_begin_pairing, allow_exist=True
|
||||
)
|
||||
|
||||
return self.async_show_form(step_id="reconfigure")
|
||||
|
||||
async def async_step_user(self, user_input=None):
|
||||
"""Handle the initial step."""
|
||||
# Be helpful to the user and look for devices
|
||||
if self.scan_result is None:
|
||||
self.scan_result, _ = await device_scan(None, self.hass.loop)
|
||||
|
||||
errors = {}
|
||||
default_suggestion = self._prefill_identifier()
|
||||
if user_input is not None:
|
||||
self.target_device = user_input[DEVICE_INPUT]
|
||||
try:
|
||||
await self.async_find_device()
|
||||
except DeviceNotFound:
|
||||
errors["base"] = "no_devices_found"
|
||||
except DeviceAlreadyConfigured:
|
||||
errors["base"] = "already_configured"
|
||||
except exceptions.NoServiceError:
|
||||
errors["base"] = "no_usable_service"
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Unexpected exception")
|
||||
errors["base"] = "unknown"
|
||||
else:
|
||||
await self.async_set_unique_id(
|
||||
self.atv.identifier, raise_on_progress=False
|
||||
)
|
||||
return await self.async_step_confirm()
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="user",
|
||||
data_schema=vol.Schema(
|
||||
{vol.Required(DEVICE_INPUT, default=default_suggestion): str}
|
||||
),
|
||||
errors=errors,
|
||||
description_placeholders={"devices": self._devices_str()},
|
||||
)
|
||||
|
||||
async def async_step_zeroconf(self, discovery_info):
|
||||
"""Handle device found via zeroconf."""
|
||||
service_type = discovery_info[CONF_TYPE]
|
||||
properties = discovery_info["properties"]
|
||||
|
||||
if service_type == "_mediaremotetv._tcp.local.":
|
||||
identifier = properties["UniqueIdentifier"]
|
||||
name = properties["Name"]
|
||||
elif service_type == "_touch-able._tcp.local.":
|
||||
identifier = discovery_info["name"].split(".")[0]
|
||||
name = properties["CtlN"]
|
||||
else:
|
||||
return self.async_abort(reason="unknown")
|
||||
|
||||
await self.async_set_unique_id(identifier)
|
||||
self._abort_if_unique_id_configured()
|
||||
|
||||
# pylint: disable=no-member # https://github.com/PyCQA/pylint/issues/3167
|
||||
self.context["identifier"] = self.unique_id
|
||||
self.context["title_placeholders"] = {"name": name}
|
||||
self.target_device = identifier
|
||||
return await self.async_find_device_wrapper(self.async_step_confirm)
|
||||
|
||||
async def async_find_device_wrapper(self, next_func, allow_exist=False):
|
||||
"""Find a specific device and call another function when done.
|
||||
|
||||
This function will do error handling and bail out when an error
|
||||
occurs.
|
||||
"""
|
||||
try:
|
||||
await self.async_find_device(allow_exist)
|
||||
except DeviceNotFound:
|
||||
return self.async_abort(reason="no_devices_found")
|
||||
except DeviceAlreadyConfigured:
|
||||
return self.async_abort(reason="already_configured")
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Unexpected exception")
|
||||
return self.async_abort(reason="unknown")
|
||||
|
||||
return await next_func()
|
||||
|
||||
async def async_find_device(self, allow_exist=False):
|
||||
"""Scan for the selected device to discover services."""
|
||||
self.scan_result, self.atv = await device_scan(
|
||||
self.target_device, self.hass.loop, cache=self.scan_result
|
||||
)
|
||||
if not self.atv:
|
||||
raise DeviceNotFound()
|
||||
|
||||
self.protocol = self.atv.main_service().protocol
|
||||
|
||||
if not allow_exist:
|
||||
for identifier in self.atv.all_identifiers:
|
||||
if identifier in self._async_current_ids():
|
||||
raise DeviceAlreadyConfigured()
|
||||
|
||||
# If credentials were found, save them
|
||||
for service in self.atv.services:
|
||||
if service.credentials:
|
||||
self.credentials[service.protocol.value] = service.credentials
|
||||
|
||||
async def async_step_confirm(self, user_input=None):
|
||||
"""Handle user-confirmation of discovered node."""
|
||||
if user_input is not None:
|
||||
return await self.async_begin_pairing()
|
||||
return self.async_show_form(
|
||||
step_id="confirm", description_placeholders={"name": self.atv.name}
|
||||
)
|
||||
|
||||
async def async_begin_pairing(self):
|
||||
"""Start pairing process for the next available protocol."""
|
||||
self.protocol = self._next_protocol_to_pair()
|
||||
|
||||
# Dispose previous pairing sessions
|
||||
if self.pairing is not None:
|
||||
await self.pairing.close()
|
||||
self.pairing = None
|
||||
|
||||
# Any more protocols to pair? Else bail out here
|
||||
if not self.protocol:
|
||||
await self.async_set_unique_id(self.atv.main_service().identifier)
|
||||
return self._async_get_entry(
|
||||
self.atv.main_service().protocol,
|
||||
self.atv.name,
|
||||
self.credentials,
|
||||
self.atv.address,
|
||||
)
|
||||
|
||||
# Initiate the pairing process
|
||||
abort_reason = None
|
||||
session = async_get_clientsession(self.hass)
|
||||
self.pairing = await pair(
|
||||
self.atv, self.protocol, self.hass.loop, session=session
|
||||
)
|
||||
try:
|
||||
await self.pairing.begin()
|
||||
except exceptions.ConnectionFailedError:
|
||||
return await self.async_step_service_problem()
|
||||
except exceptions.BackOffError:
|
||||
abort_reason = "backoff"
|
||||
except exceptions.PairingError:
|
||||
_LOGGER.exception("Authentication problem")
|
||||
abort_reason = "invalid_auth"
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Unexpected exception")
|
||||
abort_reason = "unknown"
|
||||
|
||||
if abort_reason:
|
||||
if self.pairing:
|
||||
await self.pairing.close()
|
||||
return self.async_abort(reason=abort_reason)
|
||||
|
||||
# Choose step depending on if PIN is required from user or not
|
||||
if self.pairing.device_provides_pin:
|
||||
return await self.async_step_pair_with_pin()
|
||||
|
||||
return await self.async_step_pair_no_pin()
|
||||
|
||||
async def async_step_pair_with_pin(self, user_input=None):
|
||||
"""Handle pairing step where a PIN is required from the user."""
|
||||
errors = {}
|
||||
if user_input is not None:
|
||||
try:
|
||||
self.pairing.pin(user_input[CONF_PIN])
|
||||
await self.pairing.finish()
|
||||
self.credentials[self.protocol.value] = self.pairing.service.credentials
|
||||
return await self.async_begin_pairing()
|
||||
except exceptions.PairingError:
|
||||
_LOGGER.exception("Authentication problem")
|
||||
errors["base"] = "invalid_auth"
|
||||
except AbortFlow:
|
||||
raise
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Unexpected exception")
|
||||
errors["base"] = "unknown"
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="pair_with_pin",
|
||||
data_schema=INPUT_PIN_SCHEMA,
|
||||
errors=errors,
|
||||
description_placeholders={"protocol": protocol_str(self.protocol)},
|
||||
)
|
||||
|
||||
async def async_step_pair_no_pin(self, user_input=None):
|
||||
"""Handle step where user has to enter a PIN on the device."""
|
||||
if user_input is not None:
|
||||
await self.pairing.finish()
|
||||
if self.pairing.has_paired:
|
||||
self.credentials[self.protocol.value] = self.pairing.service.credentials
|
||||
return await self.async_begin_pairing()
|
||||
|
||||
await self.pairing.close()
|
||||
return self.async_abort(reason="device_did_not_pair")
|
||||
|
||||
pin = randrange(1000, stop=10000)
|
||||
self.pairing.pin(pin)
|
||||
return self.async_show_form(
|
||||
step_id="pair_no_pin",
|
||||
description_placeholders={
|
||||
"protocol": protocol_str(self.protocol),
|
||||
"pin": pin,
|
||||
},
|
||||
)
|
||||
|
||||
async def async_step_service_problem(self, user_input=None):
|
||||
"""Inform user that a service will not be added."""
|
||||
if user_input is not None:
|
||||
self.credentials[self.protocol.value] = None
|
||||
return await self.async_begin_pairing()
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="service_problem",
|
||||
description_placeholders={"protocol": protocol_str(self.protocol)},
|
||||
)
|
||||
|
||||
def _async_get_entry(self, protocol, name, credentials, address):
|
||||
if not is_valid_credentials(credentials):
|
||||
return self.async_abort(reason="invalid_config")
|
||||
|
||||
data = {
|
||||
CONF_PROTOCOL: protocol.value,
|
||||
CONF_NAME: name,
|
||||
CONF_CREDENTIALS: credentials,
|
||||
CONF_ADDRESS: str(address),
|
||||
}
|
||||
|
||||
self._abort_if_unique_id_configured(reload_on_update=False, updates=data)
|
||||
|
||||
return self.async_create_entry(title=name, data=data)
|
||||
|
||||
def _next_protocol_to_pair(self):
|
||||
def _needs_pairing(protocol):
|
||||
if self.atv.get_service(protocol) is None:
|
||||
return False
|
||||
return protocol.value not in self.credentials
|
||||
|
||||
for protocol in PROTOCOL_PRIORITY:
|
||||
if _needs_pairing(protocol):
|
||||
return protocol
|
||||
return None
|
||||
|
||||
def _devices_str(self):
|
||||
return ", ".join(
|
||||
[
|
||||
f"`{atv.name} ({atv.address})`"
|
||||
for atv in self.scan_result
|
||||
if atv.identifier not in self._async_current_ids()
|
||||
]
|
||||
)
|
||||
|
||||
def _prefill_identifier(self):
|
||||
# Return identifier (address) of one device that has not been paired with
|
||||
for atv in self.scan_result:
|
||||
if atv.identifier not in self._async_current_ids():
|
||||
return str(atv.address)
|
||||
return ""
|
||||
|
||||
|
||||
class AppleTVOptionsFlow(config_entries.OptionsFlow):
|
||||
"""Handle Apple TV options."""
|
||||
|
||||
def __init__(self, config_entry):
|
||||
"""Initialize Apple TV options flow."""
|
||||
self.config_entry = config_entry
|
||||
self.options = dict(config_entry.options)
|
||||
|
||||
async def async_step_init(self, user_input=None):
|
||||
"""Manage the Apple TV options."""
|
||||
if user_input is not None:
|
||||
self.options[CONF_START_OFF] = user_input[CONF_START_OFF]
|
||||
return self.async_create_entry(title="", data=self.options)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="init",
|
||||
data_schema=vol.Schema(
|
||||
{
|
||||
vol.Optional(
|
||||
CONF_START_OFF,
|
||||
default=self.config_entry.options.get(
|
||||
CONF_START_OFF, DEFAULT_START_OFF
|
||||
),
|
||||
): bool,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class DeviceNotFound(HomeAssistantError):
|
||||
"""Error to indicate device could not be found."""
|
||||
|
||||
|
||||
class DeviceAlreadyConfigured(HomeAssistantError):
|
||||
"""Error to indicate device is already configured."""
|
|
@ -0,0 +1,11 @@
|
|||
"""Constants for the Apple TV integration."""
|
||||
|
||||
DOMAIN = "apple_tv"
|
||||
|
||||
CONF_IDENTIFIER = "identifier"
|
||||
CONF_CREDENTIALS = "credentials"
|
||||
CONF_CREDENTIALS_MRP = "mrp"
|
||||
CONF_CREDENTIALS_DMAP = "dmap"
|
||||
CONF_CREDENTIALS_AIRPLAY = "airplay"
|
||||
|
||||
CONF_START_OFF = "start_off"
|
|
@ -1,9 +1,17 @@
|
|||
{
|
||||
"domain": "apple_tv",
|
||||
"name": "Apple TV",
|
||||
"config_flow": true,
|
||||
"documentation": "https://www.home-assistant.io/integrations/apple_tv",
|
||||
"requirements": ["pyatv==0.3.13"],
|
||||
"dependencies": ["configurator"],
|
||||
"requirements": [
|
||||
"pyatv==0.7.3"
|
||||
],
|
||||
"zeroconf": [
|
||||
"_mediaremotetv._tcp.local.",
|
||||
"_touch-able._tcp.local."
|
||||
],
|
||||
"after_dependencies": ["discovery"],
|
||||
"codeowners": []
|
||||
"codeowners": [
|
||||
"@postlund"
|
||||
]
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
"""Support for Apple TV media player."""
|
||||
import logging
|
||||
|
||||
import pyatv.const as atv_const
|
||||
from pyatv.const import DeviceState, MediaType
|
||||
|
||||
from homeassistant.components.media_player import MediaPlayerEntity
|
||||
from homeassistant.components.media_player.const import (
|
||||
|
@ -19,9 +19,7 @@ from homeassistant.components.media_player.const import (
|
|||
SUPPORT_TURN_ON,
|
||||
)
|
||||
from homeassistant.const import (
|
||||
CONF_HOST,
|
||||
CONF_NAME,
|
||||
EVENT_HOMEASSISTANT_STOP,
|
||||
STATE_IDLE,
|
||||
STATE_OFF,
|
||||
STATE_PAUSED,
|
||||
|
@ -31,10 +29,13 @@ from homeassistant.const import (
|
|||
from homeassistant.core import callback
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
from . import ATTR_ATV, ATTR_POWER, DATA_APPLE_TV, DATA_ENTITIES
|
||||
from . import AppleTVEntity
|
||||
from .const import DOMAIN
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
PARALLEL_UPDATES = 0
|
||||
|
||||
SUPPORT_APPLE_TV = (
|
||||
SUPPORT_TURN_ON
|
||||
| SUPPORT_TURN_OFF
|
||||
|
@ -48,108 +49,61 @@ SUPPORT_APPLE_TV = (
|
|||
)
|
||||
|
||||
|
||||
async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
|
||||
"""Set up the Apple TV platform."""
|
||||
if not discovery_info:
|
||||
return
|
||||
async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||
"""Load Apple TV media player based on a config entry."""
|
||||
name = config_entry.data[CONF_NAME]
|
||||
manager = hass.data[DOMAIN][config_entry.unique_id]
|
||||
async_add_entities([AppleTvMediaPlayer(name, config_entry.unique_id, manager)])
|
||||
|
||||
# Manage entity cache for service handler
|
||||
if DATA_ENTITIES not in hass.data:
|
||||
hass.data[DATA_ENTITIES] = []
|
||||
|
||||
name = discovery_info[CONF_NAME]
|
||||
host = discovery_info[CONF_HOST]
|
||||
atv = hass.data[DATA_APPLE_TV][host][ATTR_ATV]
|
||||
power = hass.data[DATA_APPLE_TV][host][ATTR_POWER]
|
||||
entity = AppleTvDevice(atv, name, power)
|
||||
class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
|
||||
"""Representation of an Apple TV media player."""
|
||||
|
||||
def __init__(self, name, identifier, manager, **kwargs):
|
||||
"""Initialize the Apple TV media player."""
|
||||
super().__init__(name, identifier, manager, **kwargs)
|
||||
self._playing = None
|
||||
|
||||
@callback
|
||||
def on_hass_stop(event):
|
||||
"""Stop push updates when hass stops."""
|
||||
atv.push_updater.stop()
|
||||
|
||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, on_hass_stop)
|
||||
|
||||
if entity not in hass.data[DATA_ENTITIES]:
|
||||
hass.data[DATA_ENTITIES].append(entity)
|
||||
|
||||
async_add_entities([entity])
|
||||
|
||||
|
||||
class AppleTvDevice(MediaPlayerEntity):
|
||||
"""Representation of an Apple TV device."""
|
||||
|
||||
def __init__(self, atv, name, power):
|
||||
"""Initialize the Apple TV device."""
|
||||
self.atv = atv
|
||||
self._name = name
|
||||
self._playing = None
|
||||
self._power = power
|
||||
self._power.listeners.append(self)
|
||||
def async_device_connected(self, atv):
|
||||
"""Handle when connection is made to device."""
|
||||
self.atv.push_updater.listener = self
|
||||
self.atv.push_updater.start()
|
||||
|
||||
async def async_added_to_hass(self):
|
||||
"""Handle when an entity is about to be added to Home Assistant."""
|
||||
self._power.init()
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""Return the name of the device."""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def unique_id(self):
|
||||
"""Return a unique ID."""
|
||||
return self.atv.metadata.device_id
|
||||
|
||||
@property
|
||||
def should_poll(self):
|
||||
"""No polling needed."""
|
||||
return False
|
||||
@callback
|
||||
def async_device_disconnected(self):
|
||||
"""Handle when connection was lost to device."""
|
||||
self.atv.push_updater.stop()
|
||||
self.atv.push_updater.listener = None
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
"""Return the state of the device."""
|
||||
if not self._power.turned_on:
|
||||
if self.manager.is_connecting:
|
||||
return None
|
||||
if self.atv is None:
|
||||
return STATE_OFF
|
||||
|
||||
if self._playing:
|
||||
|
||||
state = self._playing.play_state
|
||||
if state in (
|
||||
atv_const.PLAY_STATE_IDLE,
|
||||
atv_const.PLAY_STATE_NO_MEDIA,
|
||||
atv_const.PLAY_STATE_LOADING,
|
||||
):
|
||||
state = self._playing.device_state
|
||||
if state in (DeviceState.Idle, DeviceState.Loading):
|
||||
return STATE_IDLE
|
||||
if state == atv_const.PLAY_STATE_PLAYING:
|
||||
if state == DeviceState.Playing:
|
||||
return STATE_PLAYING
|
||||
if state in (
|
||||
atv_const.PLAY_STATE_PAUSED,
|
||||
atv_const.PLAY_STATE_FAST_FORWARD,
|
||||
atv_const.PLAY_STATE_FAST_BACKWARD,
|
||||
atv_const.PLAY_STATE_STOPPED,
|
||||
):
|
||||
# Catch fast forward/backward here so "play" is default action
|
||||
if state in (DeviceState.Paused, DeviceState.Seeking, DeviceState.Stopped):
|
||||
return STATE_PAUSED
|
||||
return STATE_STANDBY # Bad or unknown state?
|
||||
return None
|
||||
|
||||
@callback
|
||||
def playstatus_update(self, updater, playing):
|
||||
def playstatus_update(self, _, playing):
|
||||
"""Print what is currently playing when it changes."""
|
||||
self._playing = playing
|
||||
self.async_write_ha_state()
|
||||
|
||||
@callback
|
||||
def playstatus_error(self, updater, exception):
|
||||
def playstatus_error(self, _, exception):
|
||||
"""Inform about an error and restart push updates."""
|
||||
_LOGGER.warning("A %s error occurred: %s", exception.__class__, exception)
|
||||
|
||||
# This will wait 10 seconds before restarting push updates. If the
|
||||
# connection continues to fail, it will flood the log (every 10
|
||||
# seconds) until it succeeds. A better approach should probably be
|
||||
# implemented here later.
|
||||
updater.start(initial_delay=10)
|
||||
self._playing = None
|
||||
self.async_write_ha_state()
|
||||
|
||||
|
@ -157,50 +111,53 @@ class AppleTvDevice(MediaPlayerEntity):
|
|||
def media_content_type(self):
|
||||
"""Content type of current playing media."""
|
||||
if self._playing:
|
||||
|
||||
media_type = self._playing.media_type
|
||||
if media_type == atv_const.MEDIA_TYPE_VIDEO:
|
||||
return MEDIA_TYPE_VIDEO
|
||||
if media_type == atv_const.MEDIA_TYPE_MUSIC:
|
||||
return MEDIA_TYPE_MUSIC
|
||||
if media_type == atv_const.MEDIA_TYPE_TV:
|
||||
return MEDIA_TYPE_TVSHOW
|
||||
return {
|
||||
MediaType.Video: MEDIA_TYPE_VIDEO,
|
||||
MediaType.Music: MEDIA_TYPE_MUSIC,
|
||||
MediaType.TV: MEDIA_TYPE_TVSHOW,
|
||||
}.get(self._playing.media_type)
|
||||
return None
|
||||
|
||||
@property
|
||||
def media_duration(self):
|
||||
"""Duration of current playing media in seconds."""
|
||||
if self._playing:
|
||||
return self._playing.total_time
|
||||
return None
|
||||
|
||||
@property
|
||||
def media_position(self):
|
||||
"""Position of current playing media in seconds."""
|
||||
if self._playing:
|
||||
return self._playing.position
|
||||
return None
|
||||
|
||||
@property
|
||||
def media_position_updated_at(self):
|
||||
"""Last valid time of media position."""
|
||||
state = self.state
|
||||
if state in (STATE_PLAYING, STATE_PAUSED):
|
||||
if self.state in (STATE_PLAYING, STATE_PAUSED):
|
||||
return dt_util.utcnow()
|
||||
return None
|
||||
|
||||
async def async_play_media(self, media_type, media_id, **kwargs):
|
||||
"""Send the play_media command to the media player."""
|
||||
await self.atv.airplay.play_url(media_id)
|
||||
await self.atv.stream.play_url(media_id)
|
||||
|
||||
@property
|
||||
def media_image_hash(self):
|
||||
"""Hash value for media image."""
|
||||
state = self.state
|
||||
if self._playing and state not in [STATE_OFF, STATE_IDLE]:
|
||||
return self._playing.hash
|
||||
if self._playing and state not in [None, STATE_OFF, STATE_IDLE]:
|
||||
return self.atv.metadata.artwork_id
|
||||
return None
|
||||
|
||||
async def async_get_media_image(self):
|
||||
"""Fetch media image of current playing image."""
|
||||
state = self.state
|
||||
if self._playing and state not in [STATE_OFF, STATE_IDLE]:
|
||||
return (await self.atv.metadata.artwork()), "image/png"
|
||||
artwork = await self.atv.metadata.artwork()
|
||||
if artwork:
|
||||
return artwork.bytes, artwork.mimetype
|
||||
|
||||
return None, None
|
||||
|
||||
|
@ -208,12 +165,8 @@ class AppleTvDevice(MediaPlayerEntity):
|
|||
def media_title(self):
|
||||
"""Title of current playing media."""
|
||||
if self._playing:
|
||||
if self.state == STATE_IDLE:
|
||||
return "Nothing playing"
|
||||
title = self._playing.title
|
||||
return title if title else "No title"
|
||||
|
||||
return f"Establishing a connection to {self._name}..."
|
||||
return self._playing.title
|
||||
return None
|
||||
|
||||
@property
|
||||
def supported_features(self):
|
||||
|
@ -222,22 +175,22 @@ class AppleTvDevice(MediaPlayerEntity):
|
|||
|
||||
async def async_turn_on(self):
|
||||
"""Turn the media player on."""
|
||||
self._power.set_power_on(True)
|
||||
await self.manager.connect()
|
||||
|
||||
async def async_turn_off(self):
|
||||
"""Turn the media player off."""
|
||||
self._playing = None
|
||||
self._power.set_power_on(False)
|
||||
await self.manager.disconnect()
|
||||
|
||||
async def async_media_play_pause(self):
|
||||
"""Pause media on media player."""
|
||||
if not self._playing:
|
||||
return
|
||||
state = self.state
|
||||
if state == STATE_PAUSED:
|
||||
await self.atv.remote_control.play()
|
||||
elif state == STATE_PLAYING:
|
||||
await self.atv.remote_control.pause()
|
||||
if self._playing:
|
||||
state = self.state
|
||||
if state == STATE_PAUSED:
|
||||
await self.atv.remote_control.play()
|
||||
elif state == STATE_PLAYING:
|
||||
await self.atv.remote_control.pause()
|
||||
return None
|
||||
|
||||
async def async_media_play(self):
|
||||
"""Play media."""
|
||||
|
|
|
@ -1,46 +1,32 @@
|
|||
"""Remote control support for Apple TV."""
|
||||
from homeassistant.components import remote
|
||||
from homeassistant.const import CONF_HOST, CONF_NAME
|
||||
|
||||
from . import ATTR_ATV, ATTR_POWER, DATA_APPLE_TV
|
||||
import logging
|
||||
|
||||
from homeassistant.components.remote import RemoteEntity
|
||||
from homeassistant.const import CONF_NAME
|
||||
|
||||
from . import AppleTVEntity
|
||||
from .const import DOMAIN
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
PARALLEL_UPDATES = 0
|
||||
|
||||
|
||||
async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
|
||||
"""Set up the Apple TV remote platform."""
|
||||
if not discovery_info:
|
||||
return
|
||||
|
||||
name = discovery_info[CONF_NAME]
|
||||
host = discovery_info[CONF_HOST]
|
||||
atv = hass.data[DATA_APPLE_TV][host][ATTR_ATV]
|
||||
power = hass.data[DATA_APPLE_TV][host][ATTR_POWER]
|
||||
async_add_entities([AppleTVRemote(atv, power, name)])
|
||||
async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||
"""Load Apple TV remote based on a config entry."""
|
||||
name = config_entry.data[CONF_NAME]
|
||||
manager = hass.data[DOMAIN][config_entry.unique_id]
|
||||
async_add_entities([AppleTVRemote(name, config_entry.unique_id, manager)])
|
||||
|
||||
|
||||
class AppleTVRemote(remote.RemoteEntity):
|
||||
class AppleTVRemote(AppleTVEntity, RemoteEntity):
|
||||
"""Device that sends commands to an Apple TV."""
|
||||
|
||||
def __init__(self, atv, power, name):
|
||||
"""Initialize device."""
|
||||
self._atv = atv
|
||||
self._name = name
|
||||
self._power = power
|
||||
self._power.listeners.append(self)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""Return the name of the device."""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def unique_id(self):
|
||||
"""Return a unique ID."""
|
||||
return self._atv.metadata.device_id
|
||||
|
||||
@property
|
||||
def is_on(self):
|
||||
"""Return true if device is on."""
|
||||
return self._power.turned_on
|
||||
return self.atv is not None
|
||||
|
||||
@property
|
||||
def should_poll(self):
|
||||
|
@ -48,23 +34,21 @@ class AppleTVRemote(remote.RemoteEntity):
|
|||
return False
|
||||
|
||||
async def async_turn_on(self, **kwargs):
|
||||
"""Turn the device on.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
self._power.set_power_on(True)
|
||||
"""Turn the device on."""
|
||||
await self.manager.connect()
|
||||
|
||||
async def async_turn_off(self, **kwargs):
|
||||
"""Turn the device off.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
self._power.set_power_on(False)
|
||||
"""Turn the device off."""
|
||||
await self.manager.disconnect()
|
||||
|
||||
async def async_send_command(self, command, **kwargs):
|
||||
"""Send a command to one device."""
|
||||
if not self.is_on:
|
||||
_LOGGER.error("Unable to send commands, not connected to %s", self._name)
|
||||
return
|
||||
|
||||
for single_command in command:
|
||||
if not hasattr(self._atv.remote_control, single_command):
|
||||
if not hasattr(self.atv.remote_control, single_command):
|
||||
continue
|
||||
|
||||
await getattr(self._atv.remote_control, single_command)()
|
||||
await getattr(self.atv.remote_control, single_command)()
|
||||
|
|
|
@ -1,8 +0,0 @@
|
|||
apple_tv_authenticate:
|
||||
description: Start AirPlay device authentication.
|
||||
fields:
|
||||
entity_id:
|
||||
description: Name(s) of entities to authenticate with.
|
||||
example: media_player.apple_tv
|
||||
apple_tv_scan:
|
||||
description: Scan for Apple TV devices.
|
|
@ -0,0 +1,64 @@
|
|||
{
|
||||
"title": "Apple TV",
|
||||
"config": {
|
||||
"flow_title": "Apple TV: {name}",
|
||||
"step": {
|
||||
"user": {
|
||||
"title": "Setup a new Apple TV",
|
||||
"description": "Start by entering the device name (e.g. Kitchen or Bedroom) or IP address of the Apple TV you want to add. If any devices were automatically found on your network, they are shown below.\n\nIf you cannot see your device or experience any issues, try specifying the device IP address.\n\n{devices}",
|
||||
"data": {
|
||||
"device_input": "Device"
|
||||
}
|
||||
},
|
||||
"reconfigure": {
|
||||
"title": "Device reconfiguration",
|
||||
"description": "This Apple TV is experiencing some connection difficulties and must be reconfigured."
|
||||
},
|
||||
"pair_with_pin": {
|
||||
"title": "Pairing",
|
||||
"description": "Pairing is required for the `{protocol}` protocol. Please enter the PIN code displayed on screen. Leading zeros shall be omitted, i.e. enter 123 if the displayed code is 0123.",
|
||||
"data": {
|
||||
"pin": "[%key:common::config_flow::data::pin%]"
|
||||
}
|
||||
},
|
||||
"pair_no_pin": {
|
||||
"title": "Pairing",
|
||||
"description": "Pairing is required for the `{protocol}` service. Please enter PIN {pin} on your Apple TV to continue."
|
||||
},
|
||||
"service_problem": {
|
||||
"title": "Failed to add service",
|
||||
"description": "A problem occurred while pairing protocol `{protocol}`. It will be ignored."
|
||||
},
|
||||
"confirm": {
|
||||
"title": "Confirm adding Apple TV",
|
||||
"description": "You are about to add the Apple TV named `{name}` to Home Assistant.\n\n**To complete the process, you may have to enter multiple PIN codes.**\n\nPlease note that you will *not* be able to power off your Apple TV with this integration. Only the media player in Home Assistant will turn off!"
|
||||
}
|
||||
},
|
||||
"error": {
|
||||
"no_devices_found": "[%key:common::config_flow::abort::no_devices_found%]",
|
||||
"already_configured": "[%key:common::config_flow::abort::already_configured_device%]",
|
||||
"no_usable_service": "A device was found but could not identify any way to establish a connection to it. If you keep seeing this message, try specifying its IP address or restarting your Apple TV.",
|
||||
"unknown": "[%key:common::config_flow::error::unknown%]",
|
||||
"invalid_auth": "[%key:common::config_flow::error::invalid_auth%]"
|
||||
},
|
||||
"abort": {
|
||||
"no_devices_found": "[%key:common::config_flow::abort::no_devices_found%]",
|
||||
"already_configured_device": "[%key:common::config_flow::abort::already_configured_device%]",
|
||||
"device_did_not_pair": "No attempt to finish pairing process was made from the device.",
|
||||
"backoff": "Device does not accept pairing reqests at this time (you might have entered an invalid PIN code too many times), try again later.",
|
||||
"invalid_config": "The configuration for this device is incomplete. Please try adding it again.",
|
||||
"already_in_progress": "[%key:common::config_flow::abort::already_in_progress%]",
|
||||
"unknown": "[%key:common::config_flow::error::unknown%]"
|
||||
}
|
||||
},
|
||||
"options": {
|
||||
"step": {
|
||||
"init": {
|
||||
"description": "Configure general device settings",
|
||||
"data": {
|
||||
"start_off": "Do not turn device on when starting Home Assistant"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
{
|
||||
"title": "Apple TV",
|
||||
"config": {
|
||||
"flow_title": "Apple TV: {name}",
|
||||
"step": {
|
||||
"user": {
|
||||
"title": "Setup a new Apple TV",
|
||||
"description": "Start by entering the device name (e.g. Kitchen or Bedroom) or IP address of the Apple TV you want to add. If any devices were automatically found on your network, they are shown below.\n\nIf you cannot see your device or experience any issues, try specifying the device IP address.\n\n{devices}",
|
||||
"data": {
|
||||
"device_input": "Device"
|
||||
}
|
||||
},
|
||||
"reconfigure": {
|
||||
"title": "Device reconfiguration",
|
||||
"description": "This Apple TV is experiencing some connection difficulties and must be reconfigured."
|
||||
},
|
||||
"pair_with_pin": {
|
||||
"title": "Pairing",
|
||||
"description": "Pairing is required for the `{protocol}` protocol. Please enter the PIN code displayed on screen. Leading zeros shall be omitted, i.e. enter 123 if the displayed code is 0123.",
|
||||
"data": {
|
||||
"pin": "PIN Code"
|
||||
}
|
||||
},
|
||||
"pair_no_pin": {
|
||||
"title": "Pairing",
|
||||
"description": "Pairing is required for the `{protocol}` service. Please enter PIN {pin} on your Apple TV to continue."
|
||||
},
|
||||
"service_problem": {
|
||||
"title": "Failed to add service",
|
||||
"description": "A problem occurred while pairing protocol `{protocol}`. It will be ignored."
|
||||
},
|
||||
"confirm": {
|
||||
"title": "Confirm adding Apple TV",
|
||||
"description": "You are about to add the Apple TV named `{name}` to Home Assistant.\n\n**To complete the process, you may have to enter multiple PIN codes.**\n\nPlease note that you will *not* be able to power off your Apple TV with this integration. Only the media player in Home Assistant will turn off!"
|
||||
}
|
||||
},
|
||||
"error": {
|
||||
"no_devices_found": "No devices found on the network",
|
||||
"already_configured": "Device is already configured",
|
||||
"no_usable_service": "A device was found but could not identify any way to establish a connection to it. If you keep seeing this message, try specifying its IP address or restarting your Apple TV.",
|
||||
"unknown": "Unexpected error",
|
||||
"invalid_auth": "Invalid authentication"
|
||||
},
|
||||
"abort": {
|
||||
"no_devices_found": "No devices found on the network",
|
||||
"already_configured_device": "Device is already configured",
|
||||
"device_did_not_pair": "No attempt to finish pairing process was made from the device.",
|
||||
"backoff": "Device does not accept pairing reqests at this time (you might have entered an invalid PIN code too many times), try again later.",
|
||||
"invalid_config": "The configuration for this device is incomplete. Please try adding it again.",
|
||||
"already_in_progress": "Configuration flow is already in progress",
|
||||
"unknown": "Unexpected error"
|
||||
}
|
||||
},
|
||||
"options": {
|
||||
"step": {
|
||||
"init": {
|
||||
"description": "Configure general device settings",
|
||||
"data": {
|
||||
"start_off": "Do not turn device on when starting Home Assistant"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -18,6 +18,7 @@ FLOWS = [
|
|||
"almond",
|
||||
"ambiclimate",
|
||||
"ambient_station",
|
||||
"apple_tv",
|
||||
"arcam_fmj",
|
||||
"atag",
|
||||
"august",
|
||||
|
|
|
@ -90,6 +90,11 @@ ZEROCONF = {
|
|||
"domain": "ipp"
|
||||
}
|
||||
],
|
||||
"_mediaremotetv._tcp.local.": [
|
||||
{
|
||||
"domain": "apple_tv"
|
||||
}
|
||||
],
|
||||
"_miio._udp.local.": [
|
||||
{
|
||||
"domain": "xiaomi_aqara"
|
||||
|
@ -129,6 +134,11 @@ ZEROCONF = {
|
|||
"name": "smappee2*"
|
||||
}
|
||||
],
|
||||
"_touch-able._tcp.local.": [
|
||||
{
|
||||
"domain": "apple_tv"
|
||||
}
|
||||
],
|
||||
"_viziocast._tcp.local.": [
|
||||
{
|
||||
"domain": "vizio"
|
||||
|
|
|
@ -1283,7 +1283,7 @@ pyatmo==4.2.1
|
|||
pyatome==0.1.1
|
||||
|
||||
# homeassistant.components.apple_tv
|
||||
pyatv==0.3.13
|
||||
pyatv==0.7.3
|
||||
|
||||
# homeassistant.components.bbox
|
||||
pybbox==0.0.5-alpha
|
||||
|
|
|
@ -648,6 +648,9 @@ pyatag==0.3.4.4
|
|||
# homeassistant.components.netatmo
|
||||
pyatmo==4.2.1
|
||||
|
||||
# homeassistant.components.apple_tv
|
||||
pyatv==0.7.3
|
||||
|
||||
# homeassistant.components.blackbird
|
||||
pyblackbird==0.5
|
||||
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
"""Tests for Apple TV."""
|
||||
import pytest
|
||||
|
||||
# Make asserts in the common module display differences
|
||||
pytest.register_assert_rewrite("tests.components.apple_tv.common")
|
|
@ -0,0 +1,49 @@
|
|||
"""Test code shared between test files."""
|
||||
|
||||
from pyatv import conf, interface
|
||||
from pyatv.const import Protocol
|
||||
|
||||
|
||||
class MockPairingHandler(interface.PairingHandler):
|
||||
"""Mock for PairingHandler in pyatv."""
|
||||
|
||||
def __init__(self, *args):
|
||||
"""Initialize a new MockPairingHandler."""
|
||||
super().__init__(*args)
|
||||
self.pin_code = None
|
||||
self.paired = False
|
||||
self.always_fail = False
|
||||
|
||||
def pin(self, pin):
|
||||
"""Pin code used for pairing."""
|
||||
self.pin_code = pin
|
||||
self.paired = False
|
||||
|
||||
@property
|
||||
def device_provides_pin(self):
|
||||
"""Return True if remote device presents PIN code, else False."""
|
||||
return self.service.protocol in [Protocol.MRP, Protocol.AirPlay]
|
||||
|
||||
@property
|
||||
def has_paired(self):
|
||||
"""If a successful pairing has been performed.
|
||||
|
||||
The value will be reset when stop() is called.
|
||||
"""
|
||||
return not self.always_fail and self.paired
|
||||
|
||||
async def begin(self):
|
||||
"""Start pairing process."""
|
||||
|
||||
async def finish(self):
|
||||
"""Stop pairing process."""
|
||||
self.paired = True
|
||||
self.service.credentials = self.service.protocol.name.lower() + "_creds"
|
||||
|
||||
|
||||
def create_conf(name, address, *services):
|
||||
"""Create an Apple TV configuration."""
|
||||
atv = conf.AppleTV(name, address)
|
||||
for service in services:
|
||||
atv.add_service(service)
|
||||
return atv
|
|
@ -0,0 +1,131 @@
|
|||
"""Fixtures for component."""
|
||||
|
||||
from pyatv import conf, net
|
||||
import pytest
|
||||
|
||||
from .common import MockPairingHandler, create_conf
|
||||
|
||||
from tests.async_mock import patch
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, name="mock_scan")
|
||||
def mock_scan_fixture():
|
||||
"""Mock pyatv.scan."""
|
||||
with patch("homeassistant.components.apple_tv.config_flow.scan") as mock_scan:
|
||||
|
||||
async def _scan(loop, timeout=5, identifier=None, protocol=None, hosts=None):
|
||||
if not mock_scan.hosts:
|
||||
mock_scan.hosts = hosts
|
||||
return mock_scan.result
|
||||
|
||||
mock_scan.result = []
|
||||
mock_scan.hosts = None
|
||||
mock_scan.side_effect = _scan
|
||||
yield mock_scan
|
||||
|
||||
|
||||
@pytest.fixture(name="dmap_pin")
|
||||
def dmap_pin_fixture():
|
||||
"""Mock pyatv.scan."""
|
||||
with patch("homeassistant.components.apple_tv.config_flow.randrange") as mock_pin:
|
||||
mock_pin.side_effect = lambda start, stop: 1111
|
||||
yield mock_pin
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pairing():
|
||||
"""Mock pyatv.scan."""
|
||||
with patch("homeassistant.components.apple_tv.config_flow.pair") as mock_pair:
|
||||
|
||||
async def _pair(config, protocol, loop, session=None, **kwargs):
|
||||
handler = MockPairingHandler(
|
||||
await net.create_session(session), config.get_service(protocol)
|
||||
)
|
||||
handler.always_fail = mock_pair.always_fail
|
||||
return handler
|
||||
|
||||
mock_pair.always_fail = False
|
||||
mock_pair.side_effect = _pair
|
||||
yield mock_pair
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pairing_mock():
|
||||
"""Mock pyatv.scan."""
|
||||
with patch("homeassistant.components.apple_tv.config_flow.pair") as mock_pair:
|
||||
|
||||
async def _pair(config, protocol, loop, session=None, **kwargs):
|
||||
return mock_pair
|
||||
|
||||
async def _begin():
|
||||
pass
|
||||
|
||||
async def _close():
|
||||
pass
|
||||
|
||||
mock_pair.close.side_effect = _close
|
||||
mock_pair.begin.side_effect = _begin
|
||||
mock_pair.pin = lambda pin: None
|
||||
mock_pair.side_effect = _pair
|
||||
yield mock_pair
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def full_device(mock_scan, dmap_pin):
|
||||
"""Mock pyatv.scan."""
|
||||
mock_scan.result.append(
|
||||
create_conf(
|
||||
"127.0.0.1",
|
||||
"MRP Device",
|
||||
conf.MrpService("mrpid", 5555),
|
||||
conf.DmapService("dmapid", None, port=6666),
|
||||
conf.AirPlayService("airplayid", port=7777),
|
||||
)
|
||||
)
|
||||
yield mock_scan
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mrp_device(mock_scan):
|
||||
"""Mock pyatv.scan."""
|
||||
mock_scan.result.append(
|
||||
create_conf("127.0.0.1", "MRP Device", conf.MrpService("mrpid", 5555))
|
||||
)
|
||||
yield mock_scan
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dmap_device(mock_scan):
|
||||
"""Mock pyatv.scan."""
|
||||
mock_scan.result.append(
|
||||
create_conf(
|
||||
"127.0.0.1",
|
||||
"DMAP Device",
|
||||
conf.DmapService("dmapid", None, port=6666),
|
||||
)
|
||||
)
|
||||
yield mock_scan
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dmap_device_with_credentials(mock_scan):
|
||||
"""Mock pyatv.scan."""
|
||||
mock_scan.result.append(
|
||||
create_conf(
|
||||
"127.0.0.1",
|
||||
"DMAP Device",
|
||||
conf.DmapService("dmapid", "dummy_creds", port=6666),
|
||||
)
|
||||
)
|
||||
yield mock_scan
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def airplay_device(mock_scan):
|
||||
"""Mock pyatv.scan."""
|
||||
mock_scan.result.append(
|
||||
create_conf(
|
||||
"127.0.0.1", "AirPlay Device", conf.AirPlayService("airplayid", port=7777)
|
||||
)
|
||||
)
|
||||
yield mock_scan
|
|
@ -0,0 +1,582 @@
|
|||
"""Test config flow."""
|
||||
|
||||
from pyatv import exceptions
|
||||
from pyatv.const import Protocol
|
||||
import pytest
|
||||
|
||||
from homeassistant import config_entries, data_entry_flow
|
||||
from homeassistant.components.apple_tv.const import CONF_START_OFF, DOMAIN
|
||||
|
||||
from tests.async_mock import patch
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
DMAP_SERVICE = {
|
||||
"type": "_touch-able._tcp.local.",
|
||||
"name": "dmapid.something",
|
||||
"properties": {"CtlN": "Apple TV"},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_setup_entry():
|
||||
"""Mock setting up a config entry."""
|
||||
with patch(
|
||||
"homeassistant.components.apple_tv.async_setup_entry", return_value=True
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
# User Flows
|
||||
|
||||
|
||||
async def test_user_input_device_not_found(hass, mrp_device):
|
||||
"""Test when user specifies a non-existing device."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result["description_placeholders"] == {"devices": "`MRP Device (127.0.0.1)`"}
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{"device_input": "none"},
|
||||
)
|
||||
|
||||
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result2["errors"] == {"base": "no_devices_found"}
|
||||
|
||||
|
||||
async def test_user_input_unexpected_error(hass, mock_scan):
|
||||
"""Test that unexpected error yields an error message."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
mock_scan.side_effect = Exception
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{"device_input": "dummy"},
|
||||
)
|
||||
|
||||
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result2["errors"] == {"base": "unknown"}
|
||||
|
||||
|
||||
async def test_user_adds_full_device(hass, full_device, pairing):
|
||||
"""Test adding device with all services."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result["errors"] == {}
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{"device_input": "MRP Device"},
|
||||
)
|
||||
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result2["description_placeholders"] == {"name": "MRP Device"}
|
||||
|
||||
result3 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
|
||||
assert result3["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result3["description_placeholders"] == {"protocol": "MRP"}
|
||||
|
||||
result4 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {"pin": 1111}
|
||||
)
|
||||
assert result4["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result4["description_placeholders"] == {"protocol": "DMAP", "pin": 1111}
|
||||
|
||||
result5 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
|
||||
assert result5["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result5["description_placeholders"] == {"protocol": "AirPlay"}
|
||||
|
||||
result6 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {"pin": 1234}
|
||||
)
|
||||
assert result6["type"] == "create_entry"
|
||||
assert result6["data"] == {
|
||||
"address": "127.0.0.1",
|
||||
"credentials": {
|
||||
Protocol.DMAP.value: "dmap_creds",
|
||||
Protocol.MRP.value: "mrp_creds",
|
||||
Protocol.AirPlay.value: "airplay_creds",
|
||||
},
|
||||
"name": "MRP Device",
|
||||
"protocol": Protocol.MRP.value,
|
||||
}
|
||||
|
||||
|
||||
async def test_user_adds_dmap_device(hass, dmap_device, dmap_pin, pairing):
|
||||
"""Test adding device with only DMAP service."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{"device_input": "DMAP Device"},
|
||||
)
|
||||
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result2["description_placeholders"] == {"name": "DMAP Device"}
|
||||
|
||||
result3 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
|
||||
assert result3["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result3["description_placeholders"] == {"pin": 1111, "protocol": "DMAP"}
|
||||
|
||||
result6 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {"pin": 1234}
|
||||
)
|
||||
assert result6["type"] == "create_entry"
|
||||
assert result6["data"] == {
|
||||
"address": "127.0.0.1",
|
||||
"credentials": {Protocol.DMAP.value: "dmap_creds"},
|
||||
"name": "DMAP Device",
|
||||
"protocol": Protocol.DMAP.value,
|
||||
}
|
||||
|
||||
|
||||
async def test_user_adds_dmap_device_failed(hass, dmap_device, dmap_pin, pairing):
|
||||
"""Test adding DMAP device where remote device did not attempt to pair."""
|
||||
pairing.always_fail = True
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{"device_input": "DMAP Device"},
|
||||
)
|
||||
|
||||
await hass.config_entries.flow.async_configure(result["flow_id"], {})
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
|
||||
assert result2["type"] == data_entry_flow.RESULT_TYPE_ABORT
|
||||
assert result2["reason"] == "device_did_not_pair"
|
||||
|
||||
|
||||
async def test_user_adds_device_with_credentials(hass, dmap_device_with_credentials):
|
||||
"""Test adding DMAP device with existing credentials (home sharing)."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{"device_input": "DMAP Device"},
|
||||
)
|
||||
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result2["description_placeholders"] == {"name": "DMAP Device"}
|
||||
|
||||
result3 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
|
||||
assert result3["type"] == "create_entry"
|
||||
assert result3["data"] == {
|
||||
"address": "127.0.0.1",
|
||||
"credentials": {Protocol.DMAP.value: "dummy_creds"},
|
||||
"name": "DMAP Device",
|
||||
"protocol": Protocol.DMAP.value,
|
||||
}
|
||||
|
||||
|
||||
async def test_user_adds_device_with_ip_filter(
|
||||
hass, dmap_device_with_credentials, mock_scan
|
||||
):
|
||||
"""Test add device filtering by IP."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{"device_input": "127.0.0.1"},
|
||||
)
|
||||
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result2["description_placeholders"] == {"name": "DMAP Device"}
|
||||
|
||||
result3 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
|
||||
assert result3["type"] == "create_entry"
|
||||
assert result3["data"] == {
|
||||
"address": "127.0.0.1",
|
||||
"credentials": {Protocol.DMAP.value: "dummy_creds"},
|
||||
"name": "DMAP Device",
|
||||
"protocol": Protocol.DMAP.value,
|
||||
}
|
||||
|
||||
|
||||
async def test_user_adds_device_by_ip_uses_unicast_scan(hass, mock_scan):
|
||||
"""Test add device by IP-address, verify unicast scan is used."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{"device_input": "127.0.0.1"},
|
||||
)
|
||||
|
||||
assert str(mock_scan.hosts[0]) == "127.0.0.1"
|
||||
|
||||
|
||||
async def test_user_adds_existing_device(hass, mrp_device):
|
||||
"""Test that it is not possible to add existing device."""
|
||||
MockConfigEntry(domain="apple_tv", unique_id="mrpid").add_to_hass(hass)
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{"device_input": "127.0.0.1"},
|
||||
)
|
||||
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result2["errors"] == {"base": "already_configured"}
|
||||
|
||||
|
||||
async def test_user_adds_unusable_device(hass, airplay_device):
|
||||
"""Test that it is not possible to add pure AirPlay device."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{"device_input": "AirPlay Device"},
|
||||
)
|
||||
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result2["errors"] == {"base": "no_usable_service"}
|
||||
|
||||
|
||||
async def test_user_connection_failed(hass, mrp_device, pairing_mock):
|
||||
"""Test error message when connection to device fails."""
|
||||
pairing_mock.begin.side_effect = exceptions.ConnectionFailedError
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{"device_input": "MRP Device"},
|
||||
)
|
||||
|
||||
await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{},
|
||||
)
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{},
|
||||
)
|
||||
assert result2["type"] == data_entry_flow.RESULT_TYPE_ABORT
|
||||
assert result2["reason"] == "invalid_config"
|
||||
|
||||
|
||||
async def test_user_start_pair_error_failed(hass, mrp_device, pairing_mock):
|
||||
"""Test initiating pairing fails."""
|
||||
pairing_mock.begin.side_effect = exceptions.PairingError
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{"device_input": "MRP Device"},
|
||||
)
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{},
|
||||
)
|
||||
assert result2["type"] == data_entry_flow.RESULT_TYPE_ABORT
|
||||
assert result2["reason"] == "invalid_auth"
|
||||
|
||||
|
||||
async def test_user_pair_invalid_pin(hass, mrp_device, pairing_mock):
|
||||
"""Test pairing with invalid pin."""
|
||||
pairing_mock.finish.side_effect = exceptions.PairingError
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{"device_input": "MRP Device"},
|
||||
)
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{},
|
||||
)
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{"pin": 1111},
|
||||
)
|
||||
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result2["errors"] == {"base": "invalid_auth"}
|
||||
|
||||
|
||||
async def test_user_pair_unexpected_error(hass, mrp_device, pairing_mock):
|
||||
"""Test unexpected error when entering PIN code."""
|
||||
|
||||
pairing_mock.finish.side_effect = Exception
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{"device_input": "MRP Device"},
|
||||
)
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{},
|
||||
)
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{"pin": 1111},
|
||||
)
|
||||
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result2["errors"] == {"base": "unknown"}
|
||||
|
||||
|
||||
async def test_user_pair_backoff_error(hass, mrp_device, pairing_mock):
|
||||
"""Test that backoff error is displayed in case device requests it."""
|
||||
pairing_mock.begin.side_effect = exceptions.BackOffError
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{"device_input": "MRP Device"},
|
||||
)
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{},
|
||||
)
|
||||
assert result2["type"] == data_entry_flow.RESULT_TYPE_ABORT
|
||||
assert result2["reason"] == "backoff"
|
||||
|
||||
|
||||
async def test_user_pair_begin_unexpected_error(hass, mrp_device, pairing_mock):
|
||||
"""Test unexpected error during start of pairing."""
|
||||
pairing_mock.begin.side_effect = Exception
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{"device_input": "MRP Device"},
|
||||
)
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{},
|
||||
)
|
||||
assert result2["type"] == data_entry_flow.RESULT_TYPE_ABORT
|
||||
assert result2["reason"] == "unknown"
|
||||
|
||||
|
||||
# Zeroconf
|
||||
|
||||
|
||||
async def test_zeroconf_unsupported_service_aborts(hass):
|
||||
"""Test discovering unsupported zeroconf service."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
context={"source": config_entries.SOURCE_ZEROCONF},
|
||||
data={
|
||||
"type": "_dummy._tcp.local.",
|
||||
"properties": {},
|
||||
},
|
||||
)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
|
||||
assert result["reason"] == "unknown"
|
||||
|
||||
|
||||
async def test_zeroconf_add_mrp_device(hass, mrp_device, pairing):
|
||||
"""Test add MRP device discovered by zeroconf."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
context={"source": config_entries.SOURCE_ZEROCONF},
|
||||
data={
|
||||
"type": "_mediaremotetv._tcp.local.",
|
||||
"properties": {"UniqueIdentifier": "mrpid", "Name": "Kitchen"},
|
||||
},
|
||||
)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result["description_placeholders"] == {"name": "MRP Device"}
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{},
|
||||
)
|
||||
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result2["description_placeholders"] == {"protocol": "MRP"}
|
||||
|
||||
result3 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {"pin": 1111}
|
||||
)
|
||||
assert result3["type"] == "create_entry"
|
||||
assert result3["data"] == {
|
||||
"address": "127.0.0.1",
|
||||
"credentials": {Protocol.MRP.value: "mrp_creds"},
|
||||
"name": "MRP Device",
|
||||
"protocol": Protocol.MRP.value,
|
||||
}
|
||||
|
||||
|
||||
async def test_zeroconf_add_dmap_device(hass, dmap_device, dmap_pin, pairing):
|
||||
"""Test add DMAP device discovered by zeroconf."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF}, data=DMAP_SERVICE
|
||||
)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result["description_placeholders"] == {"name": "DMAP Device"}
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{},
|
||||
)
|
||||
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result2["description_placeholders"] == {"protocol": "DMAP", "pin": 1111}
|
||||
|
||||
result3 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
|
||||
assert result3["type"] == "create_entry"
|
||||
assert result3["data"] == {
|
||||
"address": "127.0.0.1",
|
||||
"credentials": {Protocol.DMAP.value: "dmap_creds"},
|
||||
"name": "DMAP Device",
|
||||
"protocol": Protocol.DMAP.value,
|
||||
}
|
||||
|
||||
|
||||
async def test_zeroconf_add_existing_aborts(hass, dmap_device):
|
||||
"""Test start new zeroconf flow while existing flow is active aborts."""
|
||||
await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF}, data=DMAP_SERVICE
|
||||
)
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF}, data=DMAP_SERVICE
|
||||
)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
|
||||
assert result["reason"] == "already_in_progress"
|
||||
|
||||
|
||||
async def test_zeroconf_add_but_device_not_found(hass, mock_scan):
|
||||
"""Test add device which is not found with another scan."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF}, data=DMAP_SERVICE
|
||||
)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
|
||||
assert result["reason"] == "no_devices_found"
|
||||
|
||||
|
||||
async def test_zeroconf_add_existing_device(hass, dmap_device):
|
||||
"""Test add already existing device from zeroconf."""
|
||||
MockConfigEntry(domain="apple_tv", unique_id="dmapid").add_to_hass(hass)
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF}, data=DMAP_SERVICE
|
||||
)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
|
||||
assert result["reason"] == "already_configured"
|
||||
|
||||
|
||||
async def test_zeroconf_unexpected_error(hass, mock_scan):
|
||||
"""Test unexpected error aborts in zeroconf."""
|
||||
mock_scan.side_effect = Exception
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF}, data=DMAP_SERVICE
|
||||
)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
|
||||
assert result["reason"] == "unknown"
|
||||
|
||||
|
||||
# Re-configuration
|
||||
|
||||
|
||||
async def test_reconfigure_update_credentials(hass, mrp_device, pairing):
|
||||
"""Test that reconfigure flow updates config entry."""
|
||||
config_entry = MockConfigEntry(domain="apple_tv", unique_id="mrpid")
|
||||
config_entry.add_to_hass(hass)
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
context={"source": "reauth"},
|
||||
data={"identifier": "mrpid", "name": "apple tv"},
|
||||
)
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{},
|
||||
)
|
||||
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result2["description_placeholders"] == {"protocol": "MRP"}
|
||||
|
||||
result3 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {"pin": 1111}
|
||||
)
|
||||
assert result3["type"] == data_entry_flow.RESULT_TYPE_ABORT
|
||||
assert result3["reason"] == "already_configured"
|
||||
|
||||
assert config_entry.data == {
|
||||
"address": "127.0.0.1",
|
||||
"protocol": Protocol.MRP.value,
|
||||
"name": "MRP Device",
|
||||
"credentials": {Protocol.MRP.value: "mrp_creds"},
|
||||
}
|
||||
|
||||
|
||||
async def test_reconfigure_ongoing_aborts(hass, mrp_device):
|
||||
"""Test start additional reconfigure flow aborts."""
|
||||
data = {
|
||||
"identifier": "mrpid",
|
||||
"name": "Apple TV",
|
||||
}
|
||||
|
||||
await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": "reauth"}, data=data
|
||||
)
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": "reauth"}, data=data
|
||||
)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
|
||||
assert result["reason"] == "already_in_progress"
|
||||
|
||||
|
||||
# Options
|
||||
|
||||
|
||||
async def test_option_start_off(hass):
|
||||
"""Test start off-option flag."""
|
||||
config_entry = MockConfigEntry(
|
||||
domain=DOMAIN, unique_id="dmapid", options={"start_off": False}
|
||||
)
|
||||
config_entry.add_to_hass(hass)
|
||||
|
||||
result = await hass.config_entries.options.async_init(config_entry.entry_id)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
|
||||
result2 = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"], user_input={CONF_START_OFF: True}
|
||||
)
|
||||
assert result2["type"] == "create_entry"
|
||||
|
||||
assert config_entry.options[CONF_START_OFF]
|
Loading…
Reference in New Issue