Reduce cpu requirements for apple_tv mdns and discovery (#61346)

Co-authored-by: jjlawren <jjlawren@users.noreply.github.com>
pull/61538/head
J. Nick Koston 2021-12-11 19:57:11 -10:00 committed by GitHub
parent 9a1109949f
commit 388fcac689
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 228 additions and 76 deletions

View File

@ -36,6 +36,7 @@ _LOGGER = logging.getLogger(__name__)
DEFAULT_NAME = "Apple TV"
BACKOFF_TIME_LOWER_LIMIT = 15 # seconds
BACKOFF_TIME_UPPER_LIMIT = 300 # Five minutes
SIGNAL_CONNECTED = "apple_tv_connected"
@ -241,7 +242,11 @@ class AppleTVManager:
if self.atv is None:
self._connection_attempts += 1
backoff = min(
randrange(2 ** self._connection_attempts), BACKOFF_TIME_UPPER_LIMIT
max(
BACKOFF_TIME_LOWER_LIMIT,
randrange(2 ** self._connection_attempts),
),
BACKOFF_TIME_UPPER_LIMIT,
)
_LOGGER.debug("Reconnecting in %d seconds", backoff)
@ -271,17 +276,12 @@ class AppleTVManager:
return atvs[0]
_LOGGER.debug(
"Failed to find device %s with address %s, trying to scan",
"Failed to find device %s with address %s",
self.config_entry.title,
address,
)
atvs = await scan(self.hass.loop, identifier=identifiers, protocol=protocols)
if atvs:
return atvs[0]
_LOGGER.debug("Failed to find device %s, trying later", self.config_entry.title)
# We no longer multicast scan for the device since as soon as async_step_zeroconf runs,
# it will update the address and reload the config entry when the device is found.
return None
async def _connect(self, conf):

View File

@ -1,4 +1,5 @@
"""Config flow for Apple TV integration."""
import asyncio
from collections import deque
from ipaddress import ip_address
import logging
@ -27,6 +28,8 @@ INPUT_PIN_SCHEMA = vol.Schema({vol.Required(CONF_PIN, default=None): int})
DEFAULT_START_OFF = False
DISCOVERY_AGGREGATION_TIME = 15 # seconds
async def device_scan(identifier, loop):
"""Scan for a specific device using identifier as filter."""
@ -46,12 +49,13 @@ async def device_scan(identifier, loop):
except ValueError:
return None
for hosts in (_host_filter(), None):
scan_result = await scan(loop, timeout=3, hosts=hosts)
matches = [atv for atv in scan_result if _filter_device(atv)]
# If we have an address, only probe that address to avoid
# broadcast traffic on the network
scan_result = await scan(loop, timeout=3, hosts=_host_filter())
matches = [atv for atv in scan_result if _filter_device(atv)]
if matches:
return matches[0], matches[0].all_identifiers
if matches:
return matches[0], matches[0].all_identifiers
return None, None
@ -93,10 +97,12 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
existing config entry. If that's the case, the unique_id from that entry is
re-used, otherwise the newly discovered identifier is used instead.
"""
all_identifiers = set(self.atv.all_identifiers)
for entry in self._async_current_entries():
for identifier in self.atv.all_identifiers:
if identifier in entry.data.get(CONF_IDENTIFIERS, [entry.unique_id]):
return entry.unique_id
if all_identifiers.intersection(
entry.data.get(CONF_IDENTIFIERS, [entry.unique_id])
):
return entry.unique_id
return self.atv.identifier
async def async_step_reauth(self, user_input=None):
@ -149,22 +155,18 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
self, discovery_info: zeroconf.ZeroconfServiceInfo
) -> data_entry_flow.FlowResult:
"""Handle device found via zeroconf."""
host = discovery_info.host
self._async_abort_entries_match({CONF_ADDRESS: host})
service_type = discovery_info.type[:-1] # Remove leading .
name = discovery_info.name.replace(f".{service_type}.", "")
properties = discovery_info.properties
# Extract unique identifier from service
self.scan_filter = get_unique_id(service_type, name, properties)
if self.scan_filter is None:
unique_id = get_unique_id(service_type, name, properties)
if unique_id is None:
return self.async_abort(reason="unknown")
# Scan for the device in order to extract _all_ unique identifiers assigned to
# it. Not doing it like this will yield multiple config flows for the same
# device, one per protocol, which is undesired.
return await self.async_find_device_wrapper(self.async_found_zeroconf_device)
async def async_found_zeroconf_device(self, user_input=None):
"""Handle device found after Zeroconf discovery."""
#
# Suppose we have a device with three services: A, B and C. Let's assume
# service A is discovered by Zeroconf, triggering a device scan that also finds
# service B but *not* C. An identifier is picked from one of the services and
@ -177,31 +179,63 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
# since both flows really represent the same device. They will however end up
# as two separate flows.
#
# To solve this, all identifiers found during a device scan is stored as
# To solve this, all identifiers are stored as
# "all_identifiers" in the flow context. When a new service is discovered, the
# code below will check these identifiers for all active flows and abort if a
# match is found. Before aborting, the original flow is updated with any
# potentially new identifiers. In the example above, when service C is
# discovered, the identifier of service C will be inserted into
# "all_identifiers" of the original flow (making the device complete).
for flow in self._async_in_progress():
for identifier in self.atv.all_identifiers:
if identifier not in flow["context"].get("all_identifiers", []):
continue
#
# Wait DISCOVERY_AGGREGATION_TIME for multiple services to be
# discovered via zeroconf. Once the first service is discovered
# this allows other services to be discovered inside the time
# window before triggering a scan of the device. This prevents
# multiple scans of the device at the same time since each
# apple_tv device has multiple services that are discovered by
# zeroconf.
#
await asyncio.sleep(DISCOVERY_AGGREGATION_TIME)
self._async_check_in_progress_and_set_address(host, unique_id)
# Scan for the device in order to extract _all_ unique identifiers assigned to
# it. Not doing it like this will yield multiple config flows for the same
# device, one per protocol, which is undesired.
self.scan_filter = host
return await self.async_find_device_wrapper(self.async_found_zeroconf_device)
@callback
def _async_check_in_progress_and_set_address(self, host: str, unique_id: str):
"""Check for in-progress flows and update them with identifiers if needed.
This code must not await between checking in progress and setting the host
or it will have a race condition where no flows move forward.
"""
for flow in self._async_in_progress(include_uninitialized=True):
context = flow["context"]
if (
context.get("source") != config_entries.SOURCE_ZEROCONF
or context.get(CONF_ADDRESS) != host
):
continue
if (
"all_identifiers" in context
and unique_id not in context["all_identifiers"]
):
# Add potentially new identifiers from this device to the existing flow
identifiers = set(flow["context"]["all_identifiers"])
identifiers.update(self.atv.all_identifiers)
flow["context"]["all_identifiers"] = list(identifiers)
raise data_entry_flow.AbortFlow("already_in_progress")
context["all_identifiers"].append(unique_id)
raise data_entry_flow.AbortFlow("already_in_progress")
self.context[CONF_ADDRESS] = host
async def async_found_zeroconf_device(self, user_input=None):
"""Handle device found after Zeroconf discovery."""
self.context["all_identifiers"] = self.atv.all_identifiers
# Also abort if an integration with this identifier already exists
await self.async_set_unique_id(self.device_identifier)
self._abort_if_unique_id_configured()
# but be sure to update the address if its changed so the scanner
# will probe the new address
self._abort_if_unique_id_configured(updates={CONF_ADDRESS: self.atv.address})
self.context["identifier"] = self.unique_id
return await self.async_step_confirm()
@ -245,14 +279,22 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
else model_str(dev_info.model)
),
}
if not allow_exist:
for identifier in self.atv.all_identifiers:
for entry in self._async_current_entries():
if identifier in entry.data.get(
CONF_IDENTIFIERS, [entry.unique_id]
):
raise DeviceAlreadyConfigured()
all_identifiers = set(self.atv.all_identifiers)
for entry in self._async_current_entries():
if not all_identifiers.intersection(
entry.data.get(CONF_IDENTIFIERS, [entry.unique_id])
):
continue
if entry.data.get(CONF_ADDRESS) != self.atv.address:
self.hass.config_entries.async_update_entry(
entry,
data={**entry.data, CONF_ADDRESS: self.atv.address},
)
self.hass.async_create_task(
self.hass.config_entries.async_reload(entry.entry_id)
)
if not allow_exist:
raise DeviceAlreadyConfigured()
async def async_step_confirm(self, user_input=None):
"""Handle user-confirmation of discovered node."""

View File

@ -49,10 +49,10 @@ def create_conf(name, address, *services):
return atv
def mrp_service(enabled=True):
def mrp_service(enabled=True, unique_id="mrpid"):
"""Create example MRP service."""
return conf.ManualService(
"mrpid",
unique_id,
Protocol.MRP,
5555,
{},
@ -70,3 +70,14 @@ def airplay_service():
{},
pairing_requirement=const.PairingRequirement.Mandatory,
)
def raop_service():
"""Create example RAOP service."""
return conf.ManualService(
"AABBCCDDEEFF",
Protocol.RAOP,
7000,
{},
pairing_requirement=const.PairingRequirement.Mandatory,
)

View File

@ -96,12 +96,19 @@ def full_device(mock_scan, dmap_pin):
@pytest.fixture
def mrp_device(mock_scan):
"""Mock pyatv.scan."""
mock_scan.result.append(
create_conf(
"127.0.0.1",
"MRP Device",
mrp_service(),
)
mock_scan.result.extend(
[
create_conf(
"127.0.0.1",
"MRP Device",
mrp_service(),
),
create_conf(
"127.0.0.2",
"MRP Device 2",
mrp_service(unique_id="unrelated"),
),
]
)
yield mock_scan

View File

@ -1,6 +1,6 @@
"""Test config flow."""
from unittest.mock import patch
from unittest.mock import ANY, patch
from pyatv import exceptions
from pyatv.const import PairingRequirement, Protocol
@ -8,14 +8,15 @@ import pytest
from homeassistant import config_entries, data_entry_flow
from homeassistant.components import zeroconf
from homeassistant.components.apple_tv import CONF_ADDRESS, config_flow
from homeassistant.components.apple_tv.const import CONF_START_OFF, DOMAIN
from .common import airplay_service, create_conf, mrp_service
from .common import airplay_service, create_conf, mrp_service, raop_service
from tests.common import MockConfigEntry
DMAP_SERVICE = zeroconf.ZeroconfServiceInfo(
host="mock_host",
host="127.0.0.1",
hostname="mock_hostname",
port=None,
type="_touch-able._tcp.local.",
@ -24,6 +25,23 @@ DMAP_SERVICE = zeroconf.ZeroconfServiceInfo(
)
RAOP_SERVICE = zeroconf.ZeroconfServiceInfo(
host="127.0.0.1",
hostname="mock_hostname",
port=None,
type="_raop._tcp.local.",
name="AABBCCDDEEFF@Master Bed._raop._tcp.local.",
properties={"am": "AppleTV11,1"},
)
@pytest.fixture(autouse=True)
def zero_aggregation_time():
"""Prevent the aggregation time from delaying the tests."""
with patch.object(config_flow, "DISCOVERY_AGGREGATION_TIME", 0):
yield
@pytest.fixture(autouse=True)
def use_mocked_zeroconf(mock_async_zeroconf):
"""Mock zeroconf in all tests."""
@ -507,7 +525,7 @@ async def test_zeroconf_unsupported_service_aborts(hass):
DOMAIN,
context={"source": config_entries.SOURCE_ZEROCONF},
data=zeroconf.ZeroconfServiceInfo(
host="mock_host",
host="127.0.0.1",
hostname="mock_hostname",
name="mock_name",
port=None,
@ -521,11 +539,25 @@ async def test_zeroconf_unsupported_service_aborts(hass):
async def test_zeroconf_add_mrp_device(hass, mrp_device, pairing):
"""Test add MRP device discovered by zeroconf."""
unrelated_result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_ZEROCONF},
data=zeroconf.ZeroconfServiceInfo(
host="127.0.0.2",
hostname="mock_hostname",
port=None,
name="Kitchen",
properties={"UniqueIdentifier": "unrelated", "Name": "Kitchen"},
type="_mediaremotetv._tcp.local.",
),
)
assert unrelated_result["type"] == data_entry_flow.RESULT_TYPE_FORM
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_ZEROCONF},
data=zeroconf.ZeroconfServiceInfo(
host="mock_host",
host="127.0.0.1",
hostname="mock_hostname",
port=None,
name="Kitchen",
@ -586,6 +618,37 @@ async def test_zeroconf_add_dmap_device(hass, dmap_device, dmap_pin, pairing):
}
async def test_zeroconf_ip_change(hass, mock_scan):
"""Test that the config entry gets updated when the ip changes and reloads."""
entry = MockConfigEntry(
domain="apple_tv", unique_id="mrpid", data={CONF_ADDRESS: "127.0.0.2"}
)
unrelated_entry = MockConfigEntry(
domain="apple_tv", unique_id="unrelated", data={CONF_ADDRESS: "127.0.0.2"}
)
unrelated_entry.add_to_hass(hass)
entry.add_to_hass(hass)
mock_scan.result = [
create_conf("127.0.0.1", "Device", mrp_service(), airplay_service())
]
with patch(
"homeassistant.components.apple_tv.async_setup_entry", return_value=True
) as mock_async_setup:
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_ZEROCONF},
data=DMAP_SERVICE,
)
await hass.async_block_till_done()
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
assert result["reason"] == "already_configured"
assert len(mock_async_setup.mock_calls) == 2
assert entry.data[CONF_ADDRESS] == "127.0.0.1"
assert unrelated_entry.data[CONF_ADDRESS] == "127.0.0.2"
async def test_zeroconf_add_existing_aborts(hass, dmap_device):
"""Test start new zeroconf flow while existing flow is active aborts."""
await hass.config_entries.flow.async_init(
@ -638,7 +701,7 @@ async def test_zeroconf_abort_if_other_in_progress(hass, mock_scan):
DOMAIN,
context={"source": config_entries.SOURCE_ZEROCONF},
data=zeroconf.ZeroconfServiceInfo(
host="mock_host",
host="127.0.0.1",
hostname="mock_hostname",
port=None,
type="_airplay._tcp.local.",
@ -658,7 +721,7 @@ async def test_zeroconf_abort_if_other_in_progress(hass, mock_scan):
DOMAIN,
context={"source": config_entries.SOURCE_ZEROCONF},
data=zeroconf.ZeroconfServiceInfo(
host="mock_host",
host="127.0.0.1",
hostname="mock_hostname",
port=None,
type="_mediaremotetv._tcp.local.",
@ -681,7 +744,7 @@ async def test_zeroconf_missing_device_during_protocol_resolve(
DOMAIN,
context={"source": config_entries.SOURCE_ZEROCONF},
data=zeroconf.ZeroconfServiceInfo(
host="mock_host",
host="127.0.0.1",
hostname="mock_hostname",
port=None,
type="_airplay._tcp.local.",
@ -700,7 +763,7 @@ async def test_zeroconf_missing_device_during_protocol_resolve(
DOMAIN,
context={"source": config_entries.SOURCE_ZEROCONF},
data=zeroconf.ZeroconfServiceInfo(
host="mock_host",
host="127.0.0.1",
hostname="mock_hostname",
port=None,
type="_mediaremotetv._tcp.local.",
@ -733,7 +796,7 @@ async def test_zeroconf_additional_protocol_resolve_failure(
DOMAIN,
context={"source": config_entries.SOURCE_ZEROCONF},
data=zeroconf.ZeroconfServiceInfo(
host="mock_host",
host="127.0.0.1",
hostname="mock_hostname",
port=None,
type="_airplay._tcp.local.",
@ -752,7 +815,7 @@ async def test_zeroconf_additional_protocol_resolve_failure(
DOMAIN,
context={"source": config_entries.SOURCE_ZEROCONF},
data=zeroconf.ZeroconfServiceInfo(
host="mock_host",
host="127.0.0.1",
hostname="mock_hostname",
port=None,
type="_mediaremotetv._tcp.local.",
@ -785,7 +848,7 @@ async def test_zeroconf_pair_additionally_found_protocols(
DOMAIN,
context={"source": config_entries.SOURCE_ZEROCONF},
data=zeroconf.ZeroconfServiceInfo(
host="mock_host",
host="127.0.0.1",
hostname="mock_hostname",
port=None,
type="_airplay._tcp.local.",
@ -793,9 +856,26 @@ async def test_zeroconf_pair_additionally_found_protocols(
properties={"deviceid": "airplayid"},
),
)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
await hass.async_block_till_done()
mock_scan.result = [
create_conf("127.0.0.1", "Device", mrp_service(), airplay_service())
create_conf("127.0.0.1", "Device", raop_service(), airplay_service())
]
# Find the same device again, but now also with RAOP service. The first flow should
# be updated with the RAOP service.
await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_ZEROCONF},
data=RAOP_SERVICE,
)
await hass.async_block_till_done()
mock_scan.result = [
create_conf(
"127.0.0.1", "Device", raop_service(), mrp_service(), airplay_service()
)
]
# Find the same device again, but now also with MRP service. The first flow should
@ -804,7 +884,7 @@ async def test_zeroconf_pair_additionally_found_protocols(
DOMAIN,
context={"source": config_entries.SOURCE_ZEROCONF},
data=zeroconf.ZeroconfServiceInfo(
host="mock_host",
host="127.0.0.1",
hostname="mock_hostname",
port=None,
type="_mediaremotetv._tcp.local.",
@ -812,29 +892,41 @@ async def test_zeroconf_pair_additionally_found_protocols(
properties={"UniqueIdentifier": "mrpid", "Name": "Kitchen"},
),
)
await hass.async_block_till_done()
# Verify that _both_ protocols are paired
# Verify that all protocols are paired
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{},
)
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result2["step_id"] == "pair_with_pin"
assert result2["description_placeholders"] == {"protocol": "MRP"}
assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result2["step_id"] == "pair_no_pin"
assert result2["description_placeholders"] == {"pin": ANY, "protocol": "RAOP"}
# Verify that all protocols are paired
result3 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{"pin": 1234},
{},
)
assert result3["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result3["step_id"] == "pair_with_pin"
assert result3["description_placeholders"] == {"protocol": "AirPlay"}
assert result3["description_placeholders"] == {"protocol": "MRP"}
result4 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{"pin": 1234},
)
assert result4["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result4["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result4["step_id"] == "pair_with_pin"
assert result4["description_placeholders"] == {"protocol": "AirPlay"}
result5 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{"pin": 1234},
)
assert result5["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
# Re-configuration