Improve apple_tv typing (#107694)

pull/108008/head
J. Nick Koston 2024-01-13 22:37:04 -10:00 committed by GitHub
parent 4b8d8baa69
commit 93d363ea57
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 30 deletions

View File

@ -2,10 +2,13 @@
import asyncio
import logging
from random import randrange
from typing import TYPE_CHECKING, cast
from pyatv import connect, exceptions, scan
from pyatv.conf import AppleTV
from pyatv.const import DeviceModel, Protocol
from pyatv.convert import model_str
from pyatv.interface import AppleTV as AppleTVInterface, DeviceListener
from homeassistant.components import zeroconf
from homeassistant.config_entries import ConfigEntry
@ -92,10 +95,14 @@ class AppleTVEntity(Entity):
_attr_has_entity_name = True
_attr_name = None
def __init__(self, name, identifier, manager):
def __init__(
self, name: str, identifier: str | None, manager: "AppleTVManager"
) -> None:
"""Initialize device."""
self.atv = None
self.atv: AppleTVInterface = None # type: ignore[assignment]
self.manager = manager
if TYPE_CHECKING:
assert identifier is not None
self._attr_unique_id = identifier
self._attr_device_info = DeviceInfo(
identifiers={(DOMAIN, identifier)},
@ -143,7 +150,7 @@ class AppleTVEntity(Entity):
"""Handle when connection was lost to device."""
class AppleTVManager:
class AppleTVManager(DeviceListener):
"""Connection and power manager for an Apple TV.
An instance is used per device to share the same power state between
@ -151,11 +158,11 @@ class AppleTVManager:
in case of problems.
"""
def __init__(self, hass, config_entry):
def __init__(self, hass: HomeAssistant, config_entry: ConfigEntry) -> None:
"""Initialize power manager."""
self.config_entry = config_entry
self.hass = hass
self.atv = None
self.atv: AppleTVInterface | None = None
self.is_on = not config_entry.options.get(CONF_START_OFF, False)
self._connection_attempts = 0
self._connection_was_lost = False
@ -220,7 +227,7 @@ class AppleTVManager:
"Not starting connect loop (%s, %s)", self.atv is None, self.is_on
)
async def connect_once(self, raise_missing_credentials):
async def connect_once(self, raise_missing_credentials: bool) -> None:
"""Try to connect once."""
try:
if conf := await self._scan():
@ -264,49 +271,51 @@ class AppleTVManager:
_LOGGER.debug("Connect loop ended")
self._task = None
async def _scan(self):
async def _scan(self) -> AppleTV | None:
"""Try to find device by scanning for it."""
identifiers = set(
self.config_entry.data.get(CONF_IDENTIFIERS, [self.config_entry.unique_id])
config_entry = self.config_entry
identifiers: set[str] = set(
config_entry.data.get(CONF_IDENTIFIERS, [config_entry.unique_id])
)
address = self.config_entry.data[CONF_ADDRESS]
address: str = config_entry.data[CONF_ADDRESS]
hass = self.hass
# Only scan for and set up protocols that was successfully paired
protocols = {
Protocol(int(protocol))
for protocol in self.config_entry.data[CONF_CREDENTIALS]
Protocol(int(protocol)) for protocol in config_entry.data[CONF_CREDENTIALS]
}
_LOGGER.debug("Discovering device %s", self.config_entry.title)
aiozc = await zeroconf.async_get_async_instance(self.hass)
_LOGGER.debug("Discovering device %s", config_entry.title)
aiozc = await zeroconf.async_get_async_instance(hass)
atvs = await scan(
self.hass.loop,
hass.loop,
identifier=identifiers,
protocol=protocols,
hosts=[address],
aiozc=aiozc,
)
if atvs:
return atvs[0]
return cast(AppleTV, atvs[0])
_LOGGER.debug(
"Failed to find device %s with address %s",
self.config_entry.title,
config_entry.title,
address,
)
# We no longer multicast scan for the device since as soon as async_step_zeroconf runs,
# it will update the address and reload the config entry when the device is found.
return None
async def _connect(self, conf, raise_missing_credentials):
async def _connect(self, conf: AppleTV, raise_missing_credentials: bool) -> None:
"""Connect to device."""
credentials = self.config_entry.data[CONF_CREDENTIALS]
name = self.config_entry.data[CONF_NAME]
config_entry = self.config_entry
credentials: dict[int, str | None] = config_entry.data[CONF_CREDENTIALS]
name: str = config_entry.data[CONF_NAME]
missing_protocols = []
for protocol_int, creds in credentials.items():
protocol = Protocol(int(protocol_int))
if conf.get_service(protocol) is not None:
conf.set_credentials(protocol, creds)
conf.set_credentials(protocol, creds) # type: ignore[arg-type]
else:
missing_protocols.append(protocol.name)

View File

@ -154,9 +154,9 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
_LOGGER.exception("Failed to update app list")
else:
self._app_list = {
app.name: app.identifier
for app in sorted(apps, key=lambda app: app.name.lower())
if app.name is not None
app_name: app.identifier
for app in sorted(apps, key=lambda app: app_name.lower())
if (app_name := app.name) is not None
}
self.async_write_ha_state()
@ -214,15 +214,19 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
@property
def app_id(self) -> str | None:
"""ID of the current running app."""
if self._is_feature_available(FeatureName.App):
return self.atv.metadata.app.identifier
if self._is_feature_available(FeatureName.App) and (
app := self.atv.metadata.app
):
return app.identifier
return None
@property
def app_name(self) -> str | None:
"""Name of the current running app."""
if self._is_feature_available(FeatureName.App):
return self.atv.metadata.app.name
if self._is_feature_available(FeatureName.App) and (
app := self.atv.metadata.app
):
return app.name
return None
@property
@ -479,7 +483,7 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
async def async_media_seek(self, position: float) -> None:
"""Send seek command."""
if self.atv:
await self.atv.remote_control.set_position(position)
await self.atv.remote_control.set_position(round(position))
async def async_volume_up(self) -> None:
"""Turn volume up for media player."""

View File

@ -81,5 +81,5 @@ class AppleTVRemote(AppleTVEntity, RemoteEntity):
raise ValueError("Command not found. Exiting sequence")
_LOGGER.info("Sending command %s", single_command)
await attr_value()
await attr_value() # type: ignore[operator]
await asyncio.sleep(delay)